Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 7 additions & 23 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 8 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ members = [
exclude = [
"examples/notebook",
"examples/raspberry-pi-pico",
"examples/dqn-agent", # gym-rs
"examples/dqn-agent", # gym-rs
]

[workspace.package]
Expand Down Expand Up @@ -179,13 +179,14 @@ portable-atomic = { version = "1.13.1" }
portable-atomic-util = { version = "0.2.5", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" }
cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b121ed35e19bf9e0935dd563e97cab6e7a76005e" }
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" }
# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b121ed35e19bf9e0935dd563e97cab6e7a76005e" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubek = { path = "../cubek/crates/cubek", default-features = false }
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false }
cubek = { path = "../cubek/crates/cubek", default-features = false }
### For the release. ###
# cubecl = { version = "=0.10.0-pre.1", default-features = false }
# cubecl-common = { version = "=0.10.0-pre.1", default-features = false }
Expand Down
18 changes: 9 additions & 9 deletions crates/burn-backend/src/backend/ops/modules/conv.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#![allow(clippy::single_range_in_vec_init)]
use super::{ConvOptions, ConvTransposeOptions};
use crate::{Backend, TensorMetadata, tensor::FloatTensor};
use burn_std::{Shape, ShapeError, Slice};
use burn_std::{MetadataError, Shape, Slice};

use alloc::{vec, vec::Vec};
#[cfg(not(feature = "std"))]
Expand All @@ -16,9 +16,9 @@ pub fn calculate_pool_output_shape<const N: usize>(
padding: &[usize; N],
dilation: &[usize; N],
ceil_mode: bool,
) -> Result<Shape, ShapeError> {
) -> Result<Shape, MetadataError> {
if in_shape.rank() != N + 2 {
return Err(ShapeError::RankMismatch {
return Err(MetadataError::RankMismatch {
left: in_shape.rank(),
right: N + 2,
});
Expand Down Expand Up @@ -47,16 +47,16 @@ pub fn calculate_conv_output_shape<const N: usize>(
stride: &[usize; N],
padding: &[usize; N],
dilation: &[usize; N],
) -> Result<Shape, ShapeError> {
) -> Result<Shape, MetadataError> {
if weight_shape.rank() != N + 2 {
return Err(ShapeError::RankMismatch {
return Err(MetadataError::RankMismatch {
left: weight_shape.rank(),
right: N + 2,
});
}

if in_shape.rank() != N + 2 {
return Err(ShapeError::RankMismatch {
return Err(MetadataError::RankMismatch {
left: in_shape.rank(),
right: N + 2,
});
Expand Down Expand Up @@ -85,16 +85,16 @@ pub fn calculate_conv_transpose_output_shape<const N: usize>(
padding_out: &[usize; N],
dilation: &[usize; N],
groups: usize,
) -> Result<Shape, ShapeError> {
) -> Result<Shape, MetadataError> {
if weight_shape.rank() != N + 2 {
return Err(ShapeError::RankMismatch {
return Err(MetadataError::RankMismatch {
left: weight_shape.rank(),
right: N + 2,
});
}

if in_shape.rank() != N + 2 {
return Err(ShapeError::RankMismatch {
return Err(MetadataError::RankMismatch {
left: in_shape.rank(),
right: N + 2,
});
Expand Down
10 changes: 5 additions & 5 deletions crates/burn-cubecl-fusion/src/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn_fusion::stream::Context;
use burn_std::{DType, quantization::QParamTensor};
use burn_std::{DType, Strides, quantization::QParamTensor, strides};
use cubecl::{
CubeElement, Runtime,
client::ComputeClient,
Expand Down Expand Up @@ -33,7 +33,7 @@ pub struct CubeFusionHandle<R: Runtime> {
/// The element type of the tensor.
pub dtype: DType,
/// The strides of the tensor.
pub strides: Vec<usize>,
pub strides: Strides,
/// Quantization runtime parameters, if applicable
pub qparams: Option<QParams>,
}
Expand Down Expand Up @@ -121,14 +121,14 @@ impl<R: Runtime> CubeFusionHandle<R> {
QuantParam::BF16 => DType::BF16,
QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!("Not yet supported"),
},
strides: qparams.scales.strides.clone(),
strides: qparams.scales.metadata.strides().clone(),
qparams: None,
})
}
}

pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![0; shape.len()];
pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides {
let mut strides = strides![0; shape.len()];

let mut current = 1;
shape.iter().enumerate().rev().for_each(|(index, val)| {
Expand Down
17 changes: 9 additions & 8 deletions crates/burn-cubecl-fusion/src/engine/codegen/ir.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use super::tensor::GlobalTensor;
use crate::engine::codegen::DYN_ELEM_ID;
use burn_std::{
DType, bf16, f16,
DType, Shape, Strides, bf16, f16,
quantization::{QuantScheme, QuantStore, QuantValue},
strides,
};
use core::fmt::Display;
use cubecl::{
Expand Down Expand Up @@ -416,15 +417,15 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
/// # Panics
///
/// If the argument doesn't have an handle.
pub fn shape(&self, arg: &FuseArg) -> Vec<usize> {
pub fn shape(&self, arg: &FuseArg) -> Shape {
match self.resolve_arg(arg) {
TensorArg::Handle { handle, .. } => handle.shape.to_vec(),
TensorArg::Handle { handle, .. } => handle.shape.into(),
TensorArg::Alias { .. } => panic!("Unsupported yet"),
}
}

/// Shape used by the reference tensor.
pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Vec<usize> {
pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape {
match ref_layout {
RefLayout::Concrete(arg) => self.shape(arg),
RefLayout::Virtual(layout) => match layout {
Expand Down Expand Up @@ -459,20 +460,20 @@ impl<R: Runtime> GlobalArgsLaunch<'_, R> {
/// # Panics
///
/// If the argument doesn't have an handle.
pub fn strides(&self, arg: &FuseArg) -> Vec<usize> {
pub fn strides(&self, arg: &FuseArg) -> Strides {
match self.resolve_arg(arg) {
TensorArg::Handle { handle, .. } => handle.strides.to_vec(),
TensorArg::Handle { handle, .. } => handle.strides.into(),
TensorArg::Alias { .. } => panic!("Unsupported yet"),
}
}

pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Vec<usize> {
pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides {
match ref_layout {
RefLayout::Concrete(arg) => self.strides(arg),
// When not concrete, we operate on the contiguous layout.
_ => {
let shape = self.shape_ref(ref_layout, rank);
let mut strides = vec![0; shape.len()];
let mut strides = strides![0; shape.len()];

let mut current = 1;
shape.iter().enumerate().rev().for_each(|(index, val)| {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-cubecl-fusion/src/engine/launch/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> {
for s in layout.shape.iter() {
inputs.runtime_layouts.push(ScalarArg::new(*s));
}
for s in layout.strides {
inputs.runtime_layouts.push(ScalarArg::new(s));
for s in layout.strides.iter() {
inputs.runtime_layouts.push(ScalarArg::new(*s));
}
}

Expand Down
Loading
Loading