From 9acf08b36ea329074d049eb9ac741af392a7f032 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:53:32 +0100 Subject: [PATCH 1/8] Remove R::supported_line_sizes --- Cargo.lock | 27 +++---------------- Cargo.toml | 14 +++++----- .../engine/launch/vectorization/planner.rs | 6 ++--- crates/burn-cubecl/src/kernel/conv/direct.rs | 2 +- .../src/kernel/index/slice_assign.rs | 9 +++---- crates/burn-cubecl/src/ops/base.rs | 8 ++---- 6 files changed, 21 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dfb80ee2e..9a79d7b547 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2077,7 +2077,6 @@ dependencies = [ [[package]] name = "cubecl" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "cubecl-core", "cubecl-cpu", @@ -2093,7 +2092,6 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "backtrace", "bincode", @@ -2131,7 +2129,6 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2158,7 +2155,6 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bytemuck", "cubecl-common", @@ -2174,7 +2170,6 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bytemuck", "cubecl-common", @@ -2195,7 +2190,6 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bytemuck", "cubecl-common", @@ -2213,7 +2207,6 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bytemuck", "cubecl-common", @@ -2242,7 +2235,6 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2263,7 +2255,6 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "cubecl-common", "darling 0.23.0", @@ -2278,7 +2269,6 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -2289,7 +2279,6 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "cubecl-common", "cubecl-core", @@ -2306,7 +2295,6 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "async-channel", "bytemuck", @@ -2335,7 +2323,6 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2351,7 +2338,6 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "cubecl-common", "cubecl-core", @@ -2367,7 +2353,6 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" dependencies = [ "ash", "async-channel", @@ -2394,12 +2379,14 @@ dependencies = [ [[package]] name = "cubecl-zspace" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=11c32460621652077898d3309430711ab3d7944d#11c32460621652077898d3309430711ab3d7944d" +dependencies = [ + "serde", + "smallvec", +] [[package]] name = "cubek" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "cubecl", "cubek-attention", @@ -2413,7 +2400,6 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "bytemuck", "cubecl", @@ -2427,7 +2413,6 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "bytemuck", "cubecl", @@ -2442,7 +2427,6 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "bytemuck", "cubecl", @@ -2454,7 +2438,6 @@ dependencies = [ [[package]] name = "cubek-quant" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "cubecl", "cubecl-common", @@ -2465,7 +2448,6 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "cubecl", "cubecl-common", @@ -2478,7 +2460,6 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=b121ed35e19bf9e0935dd563e97cab6e7a76005e#b121ed35e19bf9e0935dd563e97cab6e7a76005e" dependencies = [ "cubecl", "half", diff --git a/Cargo.toml b/Cargo.toml index b5d68a7d03..2a2242da31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ exclude = [ "examples/notebook", "examples/raspberry-pi-pico", - "examples/dqn-agent", # gym-rs + "examples/dqn-agent", # gym-rs ] [workspace.package] @@ -179,13 +179,13 @@ portable-atomic = { version = "1.13.1" } portable-atomic-util = { version = "0.2.5", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } -cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b121ed35e19bf9e0935dd563e97cab6e7a76005e" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } +# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b121ed35e19bf9e0935dd563e97cab6e7a76005e" } ### For local development. ### -# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } -# cubek = { path = "../cubek/crates/cubek", default-features = false } +cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.10.0-pre.1", default-features = false } # cubecl-common = { version = "=0.10.0-pre.1", default-features = false } diff --git a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs index 09f2919445..8697ad0a0f 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/vectorization/planner.rs @@ -126,7 +126,7 @@ impl<'a, R: Runtime> VectorizationPlanner<'a, R> { // compare to ref elem. Some(line_sizes) => line_sizes, None => client - .io_optimized_line_sizes_unchecked(ref_elem.0.size()) + .io_optimized_line_sizes(ref_elem.0.size()) .collect::>(), }; @@ -388,7 +388,7 @@ fn line_sizes_quants( | QuantValue::E5M2 | QuantValue::E2M1 => { let line_sizes = client - .io_optimized_line_sizes_unchecked(size_of::()) + .io_optimized_line_sizes(size_of::()) .collect::>(); match &quants_line_sizes { @@ -408,7 +408,7 @@ fn line_sizes_quants( }, QuantStore::PackedU32(_) => { let mut line_sizes = client - .io_optimized_line_sizes_unchecked(size_of::()) + .io_optimized_line_sizes(size_of::()) .collect::>(); for val in line_sizes.iter_mut() { *val *= scheme.num_quants(); diff --git a/crates/burn-cubecl/src/kernel/conv/direct.rs b/crates/burn-cubecl/src/kernel/conv/direct.rs index 57e4f83650..213c189fc2 100644 --- a/crates/burn-cubecl/src/kernel/conv/direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/direct.rs @@ -269,7 +269,7 @@ pub fn conv_direct( let mut grouped_out_shape = output.shape.clone(); grouped_out_shape[dim_c] = channels_per_group; let line_size_out = tensor_line_size_parallel( - input.client.io_optimized_line_sizes(&input.dtype.into()), + input.client.io_optimized_line_sizes(input.dtype.size()), &grouped_out_shape, &output.strides, dim_c, diff --git a/crates/burn-cubecl/src/kernel/index/slice_assign.rs b/crates/burn-cubecl/src/kernel/index/slice_assign.rs index 42eae11821..9cffbdc05e 100644 --- a/crates/burn-cubecl/src/kernel/index/slice_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/slice_assign.rs @@ -127,17 +127,16 @@ pub(crate) fn slice_assign( let end = last.end.unwrap_or(tensor.shape[ndims - 1] as isize); let shape = (end - last.start) as usize; let offset = last.start as usize; - *R::supported_line_sizes() - .iter() - .filter(|it| { - let it = **it; + client + .io_optimized_line_sizes(tensor.dtype.size()) + .filter(|&it| { shape.is_multiple_of(it) && strides_compatible(&tensor.strides, it) && strides_compatible(&value.strides, it) && offset.is_multiple_of(it) }) .max() - .unwrap_or(&1) + .unwrap_or(1) } else { 1 }; diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs index 213639c82a..fdfeb6be36 100644 --- a/crates/burn-cubecl/src/ops/base.rs +++ b/crates/burn-cubecl/src/ops/base.rs @@ -374,9 +374,7 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub pub(crate) fn max_line_size(tensor: &CubeTensor) -> LineSize { tensor_line_size_parallel( - tensor - .client - .io_optimized_line_sizes_unchecked(tensor.dtype.size()), + tensor.client.io_optimized_line_sizes(tensor.dtype.size()), &tensor.shape, &tensor.strides, tensor.shape.len() - 1, @@ -391,9 +389,7 @@ pub(crate) fn max_line_size_many( .iter() .map(|tensor| { tensor_line_size_parallel( - tensor - .client - .io_optimized_line_sizes_unchecked(tensor.dtype.size()), + tensor.client.io_optimized_line_sizes(tensor.dtype.size()), &tensor.shape, &tensor.strides, axis, From 16f8ede01caea5718f8fb4710d3ec97fa756f9f9 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:16:44 +0100 Subject: [PATCH 2/8] refactor: Metadata optimization --- Cargo.lock | 3 + Cargo.toml | 1 + crates/burn-cubecl-fusion/src/base.rs | 10 +- .../src/engine/codegen/ir.rs | 17 +- .../src/engine/launch/executor.rs | 10 +- .../src/engine/launch/input.rs | 6 +- .../src/engine/launch/output.rs | 11 +- .../src/engine/launch/plan.rs | 16 +- .../src/engine/trace/base.rs | 6 +- crates/burn-cubecl/src/fusion.rs | 6 +- .../burn-cubecl/src/kernel/attention/base.rs | 8 +- crates/burn-cubecl/src/kernel/binary.rs | 6 +- crates/burn-cubecl/src/kernel/binary_int.rs | 5 +- crates/burn-cubecl/src/kernel/cast/base.rs | 6 +- .../burn-cubecl/src/kernel/cast/bool_cast.rs | 10 +- crates/burn-cubecl/src/kernel/comparison.rs | 31 +- crates/burn-cubecl/src/kernel/contiguous.rs | 16 +- .../src/kernel/conv/backward_data/fallback.rs | 4 +- .../src/kernel/conv/backward_data/tune.rs | 16 +- .../kernel/conv/backward_weight/fallback.rs | 6 +- .../src/kernel/conv/backward_weight/tune.rs | 16 +- .../kernel/conv/conv_transpose2d/col2im.rs | 15 +- .../conv/conv_transpose2d/transpose_direct.rs | 6 +- .../src/kernel/conv/conv_transpose2d/tune.rs | 4 +- .../src/kernel/conv/conv_transpose3d.rs | 6 +- .../src/kernel/conv/deform_conv2d.rs | 8 +- .../kernel/conv/deform_conv_transpose2d.rs | 47 +- crates/burn-cubecl/src/kernel/conv/direct.rs | 27 +- .../conv/forward/implicit_gemm/launch.rs | 10 +- .../src/kernel/conv/forward/tune.rs | 20 +- crates/burn-cubecl/src/kernel/conv/im2col.rs | 44 +- crates/burn-cubecl/src/kernel/cross.rs | 8 +- .../src/kernel/grid_sample/bilinear.rs | 4 +- crates/burn-cubecl/src/kernel/index/flip.rs | 8 +- crates/burn-cubecl/src/kernel/index/gather.rs | 3 +- .../src/kernel/index/repeat_dim.rs | 12 +- .../burn-cubecl/src/kernel/index/scatter.rs | 2 +- crates/burn-cubecl/src/kernel/index/select.rs | 5 +- .../src/kernel/index/select_assign.rs | 2 +- crates/burn-cubecl/src/kernel/index/slice.rs | 26 +- .../src/kernel/index/slice_assign.rs | 27 +- .../src/kernel/interpolate/base.rs | 6 +- .../src/kernel/interpolate/bicubic.rs | 2 +- .../src/kernel/interpolate/bilinear.rs | 2 +- .../src/kernel/interpolate/nearest.rs | 2 +- .../kernel/interpolate/nearest_backward.rs | 2 +- .../burn-cubecl/src/kernel/mask/mask_fill.rs | 8 +- .../burn-cubecl/src/kernel/mask/mask_where.rs | 4 +- crates/burn-cubecl/src/kernel/matmul/base.rs | 4 +- .../src/kernel/matmul/tune/base.rs | 8 +- crates/burn-cubecl/src/kernel/matmul/utils.rs | 2 +- .../src/kernel/pool/adaptive_avg_pool2d.rs | 2 +- .../pool/adaptive_avg_pool2d_backward.rs | 4 +- .../burn-cubecl/src/kernel/pool/avg_pool2d.rs | 4 +- .../src/kernel/pool/avg_pool2d_backward.rs | 6 +- .../burn-cubecl/src/kernel/pool/max_pool2d.rs | 16 +- .../src/kernel/pool/max_pool2d_backward.rs | 6 +- crates/burn-cubecl/src/kernel/pool/pool2d.rs | 2 +- crates/burn-cubecl/src/kernel/prng/uniform.rs | 4 +- .../src/kernel/quantization/dequantize.rs | 4 +- .../src/kernel/quantization/quantize.rs | 4 +- crates/burn-cubecl/src/kernel/reduce/base.rs | 14 +- crates/burn-cubecl/src/kernel/reduce/tune.rs | 6 +- crates/burn-cubecl/src/kernel/unary_float.rs | 5 +- crates/burn-cubecl/src/kernel/unary_int.rs | 5 +- .../burn-cubecl/src/kernel/unary_numeric.rs | 5 +- crates/burn-cubecl/src/kernel/utils.rs | 56 +- crates/burn-cubecl/src/ops/base.rs | 108 +- crates/burn-cubecl/src/ops/bool_tensor.rs | 5 +- crates/burn-cubecl/src/ops/int_tensor.rs | 5 +- crates/burn-cubecl/src/ops/module.rs | 47 +- crates/burn-cubecl/src/ops/numeric.rs | 21 +- crates/burn-cubecl/src/ops/qtensor.rs | 9 +- crates/burn-cubecl/src/ops/tensor.rs | 6 +- crates/burn-cubecl/src/ops/transaction.rs | 17 +- crates/burn-cubecl/src/template/base.rs | 6 +- crates/burn-cubecl/src/tensor/base.rs | 64 +- crates/burn-cubecl/src/tensor/quantization.rs | 25 +- crates/burn-ir/src/builder.rs | 6 +- crates/burn-std/Cargo.toml | 1 + crates/burn-std/src/errors.rs | 189 --- crates/burn-std/src/lib.rs | 3 +- .../burn-std/src/tensor/index_conversion.rs | 159 --- crates/burn-std/src/tensor/indexing.rs | 321 ----- crates/burn-std/src/tensor/mod.rs | 30 +- crates/burn-std/src/tensor/quantization.rs | 8 +- crates/burn-std/src/tensor/shape.rs | 1137 +---------------- crates/burn-tensor/src/tensor/api/base.rs | 4 +- .../hardware_accelerated.rs | 11 +- .../backends/cube/connected_components/mod.rs | 2 +- .../cube/connected_components/prefix_sum.rs | 8 +- examples/custom-cubecl-kernel/src/forward.rs | 8 +- examples/custom-wgpu-kernel/src/forward.rs | 8 +- 93 files changed, 566 insertions(+), 2319 deletions(-) delete mode 100644 crates/burn-std/src/errors.rs delete mode 100644 crates/burn-std/src/tensor/index_conversion.rs delete mode 100644 crates/burn-std/src/tensor/indexing.rs diff --git a/Cargo.lock b/Cargo.lock index 9a79d7b547..4b085e4dd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1120,6 +1120,7 @@ dependencies = [ "bytes", "cubecl", "cubecl-common", + "cubecl-zspace", "dashmap", "half", "indicatif 0.18.3", @@ -2302,6 +2303,7 @@ dependencies = [ "cfg_aliases", "cubecl-common", "cubecl-ir", + "cubecl-zspace", "derive-new", "derive_more", "dirs", @@ -2380,6 +2382,7 @@ dependencies = [ name = "cubecl-zspace" version = "0.10.0-pre.1" dependencies = [ + "derive-new", "serde", "smallvec", ] diff --git a/Cargo.toml b/Cargo.toml index 2a2242da31..849ff92a1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -185,6 +185,7 @@ portable-atomic-util = { version = "0.2.5", features = ["alloc"] } ### For local development. ### cubecl = { path = "../cubecl/crates/cubecl", default-features = false } cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false } cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.10.0-pre.1", default-features = false } diff --git a/crates/burn-cubecl-fusion/src/base.rs b/crates/burn-cubecl-fusion/src/base.rs index e434c24a91..9dd8103ecb 100644 --- a/crates/burn-cubecl-fusion/src/base.rs +++ b/crates/burn-cubecl-fusion/src/base.rs @@ -1,5 +1,5 @@ use burn_fusion::stream::Context; -use burn_std::{DType, quantization::QParamTensor}; +use burn_std::{DType, Strides, quantization::QParamTensor, strides}; use cubecl::{ CubeElement, Runtime, client::ComputeClient, @@ -33,7 +33,7 @@ pub struct CubeFusionHandle { /// The element type of the tensor. pub dtype: DType, /// The strides of the tensor. - pub strides: Vec, + pub strides: Strides, /// Quantization runtime parameters, if applicable pub qparams: Option, } @@ -121,14 +121,14 @@ impl CubeFusionHandle { QuantParam::BF16 => DType::BF16, QuantParam::UE8M0 | QuantParam::UE4M3 => unimplemented!("Not yet supported"), }, - strides: qparams.scales.strides.clone(), + strides: qparams.scales.metadata.strides().clone(), qparams: None, }) } } -pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec { - let mut strides = vec![0; shape.len()]; +pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Strides { + let mut strides = strides![0; shape.len()]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { diff --git a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs index 7aaf4f41e3..240efac5ac 100644 --- a/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs +++ b/crates/burn-cubecl-fusion/src/engine/codegen/ir.rs @@ -1,8 +1,9 @@ use super::tensor::GlobalTensor; use crate::engine::codegen::DYN_ELEM_ID; use burn_std::{ - DType, bf16, f16, + DType, Shape, Strides, bf16, f16, quantization::{QuantScheme, QuantStore, QuantValue}, + strides, }; use core::fmt::Display; use cubecl::{ @@ -416,15 +417,15 @@ impl GlobalArgsLaunch<'_, R> { /// # Panics /// /// If the argument doesn't have an handle. - pub fn shape(&self, arg: &FuseArg) -> Vec { + pub fn shape(&self, arg: &FuseArg) -> Shape { match self.resolve_arg(arg) { - TensorArg::Handle { handle, .. } => handle.shape.to_vec(), + TensorArg::Handle { handle, .. } => handle.shape.into(), TensorArg::Alias { .. } => panic!("Unsupported yet"), } } /// Shape used by the reference tensor. - pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Vec { + pub fn shape_ref(&self, ref_layout: &RefLayout, rank: usize) -> Shape { match ref_layout { RefLayout::Concrete(arg) => self.shape(arg), RefLayout::Virtual(layout) => match layout { @@ -459,20 +460,20 @@ impl GlobalArgsLaunch<'_, R> { /// # Panics /// /// If the argument doesn't have an handle. - pub fn strides(&self, arg: &FuseArg) -> Vec { + pub fn strides(&self, arg: &FuseArg) -> Strides { match self.resolve_arg(arg) { - TensorArg::Handle { handle, .. } => handle.strides.to_vec(), + TensorArg::Handle { handle, .. } => handle.strides.into(), TensorArg::Alias { .. } => panic!("Unsupported yet"), } } - pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Vec { + pub fn strides_ref(&self, ref_layout: &RefLayout, rank: usize) -> Strides { match ref_layout { RefLayout::Concrete(arg) => self.strides(arg), // When not concrete, we operate on the contiguous layout. _ => { let shape = self.shape_ref(ref_layout, rank); - let mut strides = vec![0; shape.len()]; + let mut strides = strides![0; shape.len()]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { diff --git a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs index 47ec0c6391..f069ae5be1 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs @@ -71,6 +71,8 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { return Ok(tune_output); } + let mut configs = Vec::with_capacity(plan.blocks.len()); + let mut inputs = GlobalArgsLaunch::default(); let mut outputs = GlobalArgsLaunch::default(); @@ -87,13 +89,11 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { for s in layout.shape.iter() { inputs.runtime_layouts.push(ScalarArg::new(*s)); } - for s in layout.strides { - inputs.runtime_layouts.push(ScalarArg::new(s)); + for s in layout.strides.iter() { + inputs.runtime_layouts.push(ScalarArg::new(*s)); } } - let mut configs = Vec::with_capacity(plan.blocks.len()); - for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) { let reference = match block_plan.reference { ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout), @@ -113,6 +113,8 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { RefLayout::Virtual(VirtualLayout::Runtime { pos }) } ReferenceSelection::Searching => { + drop(inputs); + drop(outputs); return Err(ExecutionError::new( TraceError::ReferenceNotFound, plan.handle_inputs, diff --git a/crates/burn-cubecl-fusion/src/engine/launch/input.rs b/crates/burn-cubecl-fusion/src/engine/launch/input.rs index f022bfcae4..9792b87b07 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/input.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/input.rs @@ -225,7 +225,11 @@ impl<'a, R: Runtime> InputPlanner<'a, R> { } if original == &tensor_relative.id { - let shape = tensor_relative.shape.clone().swap(dims.0, dims.1).unwrap(); + let shape = tensor_relative + .shape + .clone() + .swapped(dims.0, dims.1) + .unwrap(); if block_plan.potential_reference_input.is_none() && shape == block.shape_ref { block_plan.potential_reference_input = Some(InputReference::SwapDims { diff --git a/crates/burn-cubecl-fusion/src/engine/launch/output.rs b/crates/burn-cubecl-fusion/src/engine/launch/output.rs index e8e5b0cc53..e425ff4474 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/output.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/output.rs @@ -14,8 +14,11 @@ use crate::{ }; use burn_fusion::stream::Context; use burn_ir::{TensorId, TensorIr}; -use burn_std::tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action}; use burn_std::{DType, Shape}; +use burn_std::{ + Strides, + tensor::{ReshapeAction, contiguous_strides, is_contiguous, reshape_action}, +}; use cubecl::{CubeElement, Runtime, client::ComputeClient, ir::StorageType}; /// Create or reuse handles for the outputs. @@ -340,7 +343,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { .find(|(_pos, pi)| { pi.tensor_relative.dtype == tensor_global.dtype && pi.tensor_relative.shape == output.tensor_relative.shape - && pi.strides == strides + && &*pi.strides == strides && block.reference.compatible_strides_for_inplace(strides) }) .map(|(pos, _)| OutputKind::Inplace { input_pos: pos }) @@ -440,7 +443,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, - strides: Vec, + strides: Strides, block_idx: usize, ) { let block = &mut plan.blocks[block_idx]; @@ -525,7 +528,7 @@ impl<'a, R: Runtime> OutputPlanner<'a, R> { plan: &mut LaunchPlan<'a, R>, output: OutputSorted, tensor_global: TensorIr, - strides: Vec, + strides: Strides, original: TensorId, block_idx: usize, ) { diff --git a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs index 4943e96662..1ba8c23277 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/plan.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/plan.rs @@ -7,7 +7,7 @@ use crate::{ }, }; use burn_ir::{TensorId, TensorIr}; -use burn_std::Shape; +use burn_std::{Shape, Strides}; use cubecl::{Runtime, ir::LineSize}; use std::collections::BTreeMap; @@ -81,7 +81,7 @@ pub enum ReferenceSelection { Concrete { layout: FuseArg, shape: Shape, - strides: Vec, + strides: Strides, }, /// Layout from a swapped dim tensor. SwapDims { @@ -94,7 +94,7 @@ pub enum ReferenceSelection { VirtualShape { original: FuseArg, shape: Shape, - strides: Vec, + strides: Strides, }, /// The layout is provided dynamically by the host at runtime. Runtime { pos: usize }, @@ -141,7 +141,7 @@ impl LaunchPlan<'_, R> { pub struct HandleOutputAliasDebugInfo { pub handle: CubeFusionHandle, pub relative_id: TensorId, - pub global_shape: Vec, + pub global_shape: Shape, } /// Represents the output of a fused kernel execution. @@ -178,7 +178,7 @@ pub struct NormalHandleInput { pub line_size: LineSize, pub broadcated: bool, /// Stores the original strides of the handle for restoration during plan rollback. - pub orig_strides: Vec, + pub orig_strides: Strides, } /// An input handle containing values for a quantized tensor. @@ -224,7 +224,7 @@ impl NormalHandleInput { tensor_relative: &TensorIr, precision: FuseType, mut handle: CubeFusionHandle, - mut strides: Vec, + mut strides: Strides, ) -> Self { // Swap current handle strides with provided strides to track the original state for rollback. core::mem::swap(&mut handle.strides, &mut strides); @@ -256,7 +256,7 @@ pub struct PotentialInplace<'a> { /// Reference to the IR of the relative tensor. pub tensor_relative: &'a TensorIr, /// Current strides of the potential in-place candidate. - pub strides: Vec, + pub strides: Strides, } impl ReferenceSelection { @@ -266,7 +266,7 @@ impl ReferenceSelection { pub fn compatible_strides_for_inplace(&self, strides_inplace: &[usize]) -> bool { match self { - ReferenceSelection::Concrete { strides, .. } => strides == strides_inplace, + ReferenceSelection::Concrete { strides, .. } => &**strides == strides_inplace, _ => false, } } diff --git a/crates/burn-cubecl-fusion/src/engine/trace/base.rs b/crates/burn-cubecl-fusion/src/engine/trace/base.rs index 8e4bccdb3f..7281a85355 100644 --- a/crates/burn-cubecl-fusion/src/engine/trace/base.rs +++ b/crates/burn-cubecl-fusion/src/engine/trace/base.rs @@ -3,7 +3,7 @@ use crate::engine::{ trace::block::FuseBlock, }; use burn_ir::{TensorId, TensorIr}; -use burn_std::Shape; +use burn_std::{Shape, Strides}; use cubecl::prelude::*; use serde::{Deserialize, Serialize}; use std::{ @@ -195,14 +195,14 @@ pub struct FuseResources { #[derive(Clone, Serialize, Deserialize, Debug)] pub struct RuntimeLayout { pub shape: Shape, - pub strides: Vec, + pub strides: Strides, } impl Default for RuntimeLayout { fn default() -> Self { Self { shape: Shape::new([]), - strides: Default::default(), + strides: Strides::new(&[]), } } } diff --git a/crates/burn-cubecl/src/fusion.rs b/crates/burn-cubecl/src/fusion.rs index 85293da81a..6404c7e0c2 100644 --- a/crates/burn-cubecl/src/fusion.rs +++ b/crates/burn-cubecl/src/fusion.rs @@ -19,6 +19,7 @@ use burn_fusion::{ stream::{Operation, OrderedExecution}, }; use burn_ir::{BackendIr, TensorHandle}; +use burn_std::Metadata; use core::marker::PhantomData; use std::sync::Arc; @@ -184,8 +185,7 @@ fn into_tensor(handle: CubeFusionHandle, shape: Shape) -> Cub client: handle.client, handle: handle.handle, device: handle.device, - shape, - strides: handle.strides, + meta: Box::new(Metadata::new(shape, handle.strides)), dtype: handle.dtype, qparams: handle.qparams, } @@ -197,7 +197,7 @@ impl From> for CubeFusionHandle { client: value.client, handle: value.handle, device: value.device, - strides: value.strides, + strides: value.meta.strides, dtype: value.dtype, qparams: value.qparams, } diff --git a/crates/burn-cubecl/src/kernel/attention/base.rs b/crates/burn-cubecl/src/kernel/attention/base.rs index b1aa0b9a9d..508fa5eba3 100644 --- a/crates/burn-cubecl/src/kernel/attention/base.rs +++ b/crates/burn-cubecl/src/kernel/attention/base.rs @@ -19,10 +19,10 @@ pub fn flash_attention( let client = &query.client; let device = &query.device; - let num_batches = query.shape[0]; - let num_heads = query.shape[1]; - let seq_q = query.shape[2]; - let val_dim = value.shape[3]; + let num_batches = query.meta.shape[0]; + let num_heads = query.meta.shape[1]; + let seq_q = query.meta.shape[2]; + let val_dim = value.meta.shape[3]; let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]); let out = empty_device_dtype::(client.clone(), device.clone(), out_shape, out_dtype); diff --git a/crates/burn-cubecl/src/kernel/binary.rs b/crates/burn-cubecl/src/kernel/binary.rs index edc47c7707..7b137a8eee 100644 --- a/crates/burn-cubecl/src/kernel/binary.rs +++ b/crates/burn-cubecl/src/kernel/binary.rs @@ -6,7 +6,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; -use burn_backend::{bf16, f16}; +use burn_backend::{TensorMetadata, bf16, f16}; use cubecl::{ calculate_cube_count_elemwise, intrinsic, prelude::*, std::tensor::layout::linear::LinearView, }; @@ -254,7 +254,7 @@ pub(crate) fn launch_scalar_binop( // Vectorization is only enabled when the last dimension is contiguous. let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let dtype = tensor.dtype; let working_units = num_elems / line_size as usize; @@ -280,7 +280,7 @@ pub(crate) fn launch_scalar_binop( let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), dtype, ); diff --git a/crates/burn-cubecl/src/kernel/binary_int.rs b/crates/burn-cubecl/src/kernel/binary_int.rs index 1f7272d378..89a50c0640 100644 --- a/crates/burn-cubecl/src/kernel/binary_int.rs +++ b/crates/burn-cubecl/src/kernel/binary_int.rs @@ -6,6 +6,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { @@ -181,7 +182,7 @@ pub(crate) fn launch_scalar_binop_int( ) -> CubeTensor { let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.shape.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); @@ -206,7 +207,7 @@ pub(crate) fn launch_scalar_binop_int( let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), tensor.dtype, ); diff --git a/crates/burn-cubecl/src/kernel/cast/base.rs b/crates/burn-cubecl/src/kernel/cast/base.rs index c189242b40..0c9931fd2c 100644 --- a/crates/burn-cubecl/src/kernel/cast/base.rs +++ b/crates/burn-cubecl/src/kernel/cast/base.rs @@ -4,7 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; -use burn_backend::DType; +use burn_backend::{DType, TensorMetadata}; use cubecl::std::tensor::layout::linear::LinearView; use cubecl::{calculate_cube_count_elemwise, prelude::*}; @@ -42,7 +42,7 @@ pub fn cast(input: CubeTensor, dtype: DType) -> CubeTensor let line_size = max_line_size(&input); - let num_elems: usize = input.shape.num_elements(); + let num_elems: usize = input.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&client, working_units); @@ -51,7 +51,7 @@ pub fn cast(input: CubeTensor, dtype: DType) -> CubeTensor let output = empty_device_dtype( client.clone(), input.device.clone(), - input.shape.clone(), + input.shape(), dtype, // We take the same dtype as passed as input (Flex32 not F32) ); diff --git a/crates/burn-cubecl/src/kernel/cast/bool_cast.rs b/crates/burn-cubecl/src/kernel/cast/bool_cast.rs index d8ee4fa484..f024f96a4f 100644 --- a/crates/burn-cubecl/src/kernel/cast/bool_cast.rs +++ b/crates/burn-cubecl/src/kernel/cast/bool_cast.rs @@ -4,6 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device}, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::{ CubeDim, calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView, }; @@ -28,14 +29,11 @@ fn bool_cast_kernel( /// it may hold an uncanny bit combination. Naively casting it would not /// necessarily yield 0 or 1. pub fn bool_cast(tensor: CubeTensor) -> CubeTensor { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); + let output = + empty_device::(tensor.client.clone(), tensor.device.clone(), tensor.shape()); let line_size = max_line_size(&tensor); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/comparison.rs b/crates/burn-cubecl/src/kernel/comparison.rs index d14dc415b7..de1851c3e6 100644 --- a/crates/burn-cubecl/src/kernel/comparison.rs +++ b/crates/burn-cubecl/src/kernel/comparison.rs @@ -6,7 +6,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; -use burn_backend::DType; +use burn_backend::{DType, TensorMetadata}; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; #[cube] @@ -150,14 +150,7 @@ pub(crate) fn launch_cmp( .expect("Kernel to never fail"); } - CubeTensor::new( - lhs.client, - lhs.handle, - lhs.shape, - lhs.device, - lhs.strides, - dtype_bool, - ) + CubeTensor::new(lhs.client, lhs.handle, *lhs.meta, lhs.device, dtype_bool) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { unsafe { kernel_cmp::launch_unchecked::( @@ -173,14 +166,7 @@ pub(crate) fn launch_cmp( .expect("Kernel to never fail"); }; - CubeTensor::new( - rhs.client, - rhs.handle, - rhs.shape, - rhs.device, - rhs.strides, - dtype_bool, - ) + CubeTensor::new(rhs.client, rhs.handle, *rhs.meta, rhs.device, dtype_bool) } else { let output = empty_device_dtype( lhs.client.clone(), @@ -214,7 +200,7 @@ pub(crate) fn launch_scalar_cmp( ) -> CubeTensor { let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); @@ -241,16 +227,15 @@ pub(crate) fn launch_scalar_cmp( CubeTensor::new( tensor.client, tensor.handle, - tensor.shape, + *tensor.meta, tensor.device, - tensor.strides, dtype_bool, ) } else { let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), dtype_bool, ); @@ -408,7 +393,7 @@ pub(crate) fn launch_predicate( let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let dtypes = [tensor.dtype.into(), dtype_bool.into()]; let working_units = num_elems / line_size as usize; @@ -418,7 +403,7 @@ pub(crate) fn launch_predicate( let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), dtype_bool, ); diff --git a/crates/burn-cubecl/src/kernel/contiguous.rs b/crates/burn-cubecl/src/kernel/contiguous.rs index 4d5ae016c7..c2cf6935c7 100644 --- a/crates/burn-cubecl/src/kernel/contiguous.rs +++ b/crates/burn-cubecl/src/kernel/contiguous.rs @@ -1,4 +1,4 @@ -use burn_backend::{DType, QTensorPrimitive}; +use burn_backend::{DType, QTensorPrimitive, TensorMetadata}; use cubecl::quant::scheme::{QuantStore, QuantValue}; use cubecl::server::AllocationKind; @@ -24,9 +24,8 @@ pub fn into_contiguous(tensor: CubeTensor) -> CubeTensor { CubeTensor::new( tensor.client, output.handle, - output.shape.into(), + *output.metadata, tensor.device, - output.strides, tensor.dtype, ) } @@ -38,7 +37,7 @@ pub fn into_contiguous(tensor: CubeTensor) -> CubeTensor { tracing::instrument(level = "trace", skip(tensor)) )] pub fn into_contiguous_aligned(tensor: CubeTensor) -> CubeTensor { - if R::can_read_tensor(&tensor.shape, &tensor.strides) { + if R::can_read_tensor(tensor.meta.shape(), tensor.meta.strides()) { return tensor; } @@ -56,9 +55,8 @@ pub fn into_contiguous_aligned(tensor: CubeTensor) -> CubeTen CubeTensor::new( tensor.client, output.handle, - output.shape.into(), + *output.metadata, tensor.device, - output.strides, tensor.dtype, ) } @@ -72,7 +70,7 @@ fn into_contiguous_quantized( kind: AllocationKind, ) -> CubeTensor { let scheme = tensor.scheme(); - let output = empty_qtensor(tensor.shape.clone(), *tensor.scheme(), &tensor.device, kind); + let output = empty_qtensor(tensor.shape(), *tensor.scheme(), &tensor.device, kind); let (values, scales) = tensor.quantized_handles().unwrap(); let (out_values, out_scales) = output.quantized_handles().unwrap(); @@ -83,7 +81,7 @@ fn into_contiguous_quantized( &values.as_handle_ref(), &out_values.as_handle_ref(), packed_dim, - &tensor.shape, + tensor.meta.shape(), scheme.num_quants(), DType::U32.into(), ) @@ -97,7 +95,7 @@ fn into_contiguous_quantized( &values.as_handle_ref(), &out_values.as_handle_ref(), packed_dim, - &tensor.shape, + tensor.meta.shape(), scheme.num_quants(), DType::U8.into(), ) diff --git a/crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs b/crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs index 708c087788..b289c1c3e8 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_data/fallback.rs @@ -20,9 +20,9 @@ pub(crate) fn conv_data_backward_fallback( ) -> Result, ConvSetupError> { let dim_c = out_grad.rank(); - let kernel_size = &weights.shape[1..dim_c]; + let kernel_size = &weights.meta.shape()[1..dim_c]; let in_shape = &in_shape[1..dim_c]; - let out_shape = &out_grad.shape[1..dim_c]; + let out_shape = &out_grad.meta.shape()[1..dim_c]; let mut padding_out = [0; N_DIM]; diff --git a/crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs b/crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs index 5688e333e2..7805bf6c0a 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_data/tune.rs @@ -97,14 +97,14 @@ fn create_key( options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = out_grad.dtype; - let rank = out_grad.shape.num_dims(); + let rank = out_grad.meta.num_dims(); let dim_c = rank - 1; - let batch_size = out_grad.shape[0]; + let batch_size = out_grad.meta.shape()[0]; let in_channels = input_shape[dim_c]; - let out_channels = out_grad.shape[dim_c]; + let out_channels = out_grad.meta.shape()[dim_c]; - let kernel_size = weights.shape[1..dim_c].to_vec(); + let kernel_size = weights.meta.shape()[1..dim_c].to_vec(); let in_shape = input_shape[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) @@ -117,14 +117,14 @@ fn create_key( groups, } = options.clone(); - let lhs_stride_align = if out_grad.strides[dim_c] == 1 { - stride_align(&out_grad.strides, out_grad.dtype.into()) + let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 { + stride_align(out_grad.meta.strides(), out_grad.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align); - let rhs_stride_align = if weights.strides[dim_c] == 1 { - stride_align(&weights.strides, weights.dtype.into()) + let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 { + stride_align(weights.meta.strides(), weights.dtype.into()) } else { 0 }; diff --git a/crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs b/crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs index 11bf2c9ba8..6ddd9f2ab3 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_weight/fallback.rs @@ -40,7 +40,7 @@ fn conv_weight_grad_no_groups( Default::default(), )?; let mut weight_grad = swap_dims(weight_grad_swapped, 0, dim_c); - if weight_grad.shape != weight_shape { + if weight_grad.shape() != weight_shape { let ranges = weight_shape.iter().map(|&s| 0..s).collect::>(); weight_grad = slice(weight_grad, &ranges); } @@ -72,7 +72,7 @@ fn conv_weight_grad_groups( let kernel_size = &weight_shape[1..dim_c]; let kernel_size_slice = kernel_size.iter().map(|&s| 0..s).collect::>(); - let increment_ci = weight_grad.shape[dim_c]; + let increment_ci = weight_grad.meta.shape()[dim_c]; for g in 0..options.groups { let start_idx_ci = g * increment_ci; @@ -91,7 +91,7 @@ fn conv_weight_grad_groups( Default::default(), )?; let mut weight_grad_tmp = swap_dims(weight_grad_tmp, 0, dim_c); - let kernel_size_tmp = &weight_grad_tmp.shape[1..dim_c]; + let kernel_size_tmp = &weight_grad_tmp.meta.shape()[1..dim_c]; if kernel_size != kernel_size_tmp { let mut slices = vec![0..increment_co]; diff --git a/crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs b/crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs index b9c8330bd6..f82f3088e2 100644 --- a/crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs +++ b/crates/burn-cubecl/src/kernel/conv/backward_weight/tune.rs @@ -100,15 +100,15 @@ fn create_key( options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = input.dtype; - let rank = input.shape.num_dims(); + let rank = input.meta.num_dims(); let dim_c = rank - 1; - let batch_size = input.shape[0]; - let in_channels = input.shape[dim_c]; + let batch_size = input.meta.shape()[0]; + let in_channels = input.meta.shape()[dim_c]; let out_channels = weight_shape[0]; let kernel_size = weight_shape[1..dim_c].to_vec(); - let in_shape = input.shape[1..dim_c] + let in_shape = input.meta.shape()[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) .collect(); @@ -120,14 +120,14 @@ fn create_key( groups, } = options.clone(); - let lhs_stride_align = if out_grad.strides[dim_c] == 1 { - stride_align(&out_grad.strides, out_grad.dtype.into()) + let lhs_stride_align = if out_grad.meta.strides()[dim_c] == 1 { + stride_align(out_grad.meta.strides(), out_grad.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(out_channels).min(lhs_stride_align); - let rhs_stride_align = if input.strides[dim_c] == 1 { - stride_align(&input.strides, input.dtype.into()) + let rhs_stride_align = if input.meta.strides()[dim_c] == 1 { + stride_align(input.meta.strides(), input.dtype.into()) } else { 0 }; diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs index c46521777d..df237069af 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/col2im.rs @@ -33,8 +33,8 @@ pub fn conv_transpose2d_col2im( bias: Option>, options: ConvTransposeOptions<2>, ) -> Result, ConvSetupError> { - let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); - let [batch_size, _, input_h, input_w] = input.shape.dims(); + let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.meta.shape().dims(); + let [batch_size, _, input_h, input_w] = input.meta.shape().dims(); let groups = options.groups; let input_ch_per_group = input_channels / groups; let ConvTransposeOptions { @@ -135,12 +135,11 @@ pub fn conv_transpose2d_col2im( pub(crate) fn index(tensor: CubeTensor, i: usize) -> CubeTensor { #[allow(clippy::single_range_in_vec_init)] let mut indices = vec![i..i + 1]; - for dim in tensor.shape[1..].iter() { + for dim in tensor.meta.shape()[1..].iter() { indices.push(0..*dim); } let mut tensor = slice(tensor, &indices); - tensor.shape.remove(0); - tensor.strides.remove(0); + tensor.meta.remove(0); tensor } @@ -154,8 +153,8 @@ fn execute( kernel_h: usize, kernel_w: usize, ) -> Result<(), ConvSetupError> { - let [batch_size, _, input_h, input_w] = input.shape.dims(); - let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims(); + let [batch_size, _, input_h, input_w] = input.meta.shape().dims(); + let [groups, col_shape_0, input_ch_per_group] = weight.meta.shape().dims(); let col_shape_1 = batch_size * input_h * input_w; @@ -190,7 +189,7 @@ fn col2im( let columns = into_contiguous_aligned(columns); let bias = bias.map(into_contiguous_aligned); - let num_elems = out.shape.num_elements(); + let num_elems = out.meta.num_elements(); let cube_dim = CubeDim::new(&columns.client, num_elems); let cube_count = calculate_cube_count_elemwise(&columns.client, num_elems, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs index aa8d579bee..44c9014ac2 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/transpose_direct.rs @@ -134,8 +134,8 @@ pub fn conv_transpose2d_direct( bias: Option>, options: ConvTransposeOptions<2>, ) -> Result, ConvSetupError> { - let [batch_size, _, in_height, in_width] = input.shape.dims(); - let [_, out_channels, kernel_0, kernel_1] = weight.shape.dims(); + let [batch_size, _, in_height, in_width] = input.meta.shape().dims(); + let [_, out_channels, kernel_0, kernel_1] = weight.meta.shape().dims(); let out_0 = (in_height - 1) * options.stride[0] + options.dilation[0] * (kernel_0 - 1) @@ -157,7 +157,7 @@ pub fn conv_transpose2d_direct( input.dtype, ); - let num_elems = output.shape.num_elements(); + let num_elems = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, num_elems); let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs index 8fc0736871..ef832ea3c1 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose2d/tune.rs @@ -64,8 +64,8 @@ fn create_key( bias: &Option>, options: &ConvTransposeOptions<2>, ) -> CubeAutotuneKey { - let [batch_size, in_channels, height, width] = input.shape.dims(); - let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims(); + let [batch_size, in_channels, height, width] = input.meta.shape().dims(); + let [out_channels, _, kernel_h, kernel_w] = weights.meta.shape().dims(); let ConvTransposeOptions { stride, padding, diff --git a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs index 9c105d7165..dfd2978c26 100644 --- a/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-cubecl/src/kernel/conv/conv_transpose3d.rs @@ -157,8 +157,8 @@ pub(crate) fn conv_transpose3d( bias: Option>, options: ConvTransposeOptions<3>, ) -> Result, LaunchError> { - let [batch_size, _, in_depth, in_height, in_width] = input.shape.dims(); - let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.shape.dims(); + let [batch_size, _, in_depth, in_height, in_width] = input.meta.shape().dims(); + let [_, out_channels, kernel_0, kernel_1, kernel_2] = weight.meta.shape().dims(); let out_0 = (in_depth - 1) * options.stride[0] + options.dilation[0] * (kernel_0 - 1) @@ -191,7 +191,7 @@ pub(crate) fn conv_transpose3d( input.dtype, ); - let num_elems = output.shape.num_elements(); + let num_elems = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, num_elems); let cube_count = calculate_cube_count_elemwise(&input.client, num_elems, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs index e8a9b157f9..7ac82daf29 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv2d.rs @@ -203,7 +203,7 @@ pub(crate) fn deform_im2col( let device = input.device.clone(); let dtype = input.dtype; - let [batch_size, in_channels, _, _] = input.shape.dims(); + let [batch_size, in_channels, _, _] = input.meta.shape().dims(); let (out_height, out_width) = out_dims; let (kernel_height, kernel_width) = kernel_dims; @@ -274,8 +274,8 @@ pub(crate) fn deform_conv2d( let mask = mask.map(|it| into_contiguous_aligned(it)); let bias = bias.map(|it| into_contiguous_aligned(it)); - let [batch_size, _, in_height, in_width] = input.shape.dims(); - let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); + let [batch_size, _, in_height, in_width] = input.meta.shape().dims(); + let [out_channels, _, kernel_h, kernel_w] = weight.meta.shape().dims(); let groups = options.weight_groups; let out_h = calculate_conv_output_size( @@ -296,7 +296,7 @@ pub(crate) fn deform_conv2d( let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w))?; - let [col_size_0, col_size_1] = columns.shape.dims(); + let [col_size_0, col_size_1] = columns.meta.shape().dims(); let col_size_0 = col_size_0 / groups; let out_c_per_group = out_channels / groups; diff --git a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs index fab94ffd3f..992dbb630a 100644 --- a/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs @@ -14,7 +14,7 @@ use crate::{ }, tensor::CubeTensor, }; -use burn_backend::{DType, Shape, ops::DeformConvOptions}; +use burn_backend::{DType, Shape, TensorMetadata, ops::DeformConvOptions}; use cubecl::{ CubeDim, CubeLaunch, calculate_cube_count_elemwise, cube, features::TypeUsage, @@ -55,8 +55,8 @@ pub(crate) fn deform_conv2d_backward( ), ConvSetupError, > { - let [_, _, out_h, out_w] = out_grad.shape.dims(); - let [_, _, kernel_h, kernel_w] = weight.shape.dims(); + let [_, _, out_h, out_w] = out_grad.meta.shape().dims(); + let [_, _, kernel_h, kernel_w] = weight.meta.shape().dims(); let gradient_bias = bias.map(|bias| { let grad = reduce_dim( @@ -84,7 +84,7 @@ pub(crate) fn deform_conv2d_backward( ) .unwrap(); - reshape(grad, bias.shape) + reshape(grad, bias.meta.shape) }); let input = into_contiguous_aligned(input); @@ -130,8 +130,8 @@ fn compute_weight_grad( kernel_dims: (usize, usize), out_dims: (usize, usize), ) -> Result, ConvSetupError> { - let [_, in_channels, _, _] = input.shape.dims(); - let [_, out_channels, _, _] = out_grad.shape.dims(); + let [_, in_channels, _, _] = input.meta.shape().dims(); + let [_, out_channels, _, _] = out_grad.meta.shape().dims(); let (kernel_h, kernel_w) = kernel_dims; let groups = options.weight_groups; let dtype = input.dtype; @@ -140,7 +140,7 @@ fn compute_weight_grad( let out_c_per_group = out_channels / groups; let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims)?; - let [col_size_0, col_size_1] = columns.shape.dims(); + let [col_size_0, col_size_1] = columns.meta.shape().dims(); let col_size_0 = col_size_0 / groups; let out_grad = swap_dims(out_grad, 0, 1); @@ -171,8 +171,8 @@ fn backward_gradient_inputs( let client = out_grad.client.clone(); let device = out_grad.device.clone(); - let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape.dims(); - let [batch_size, _, out_h, out_w] = out_grad.shape.dims(); + let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.meta.shape().dims(); + let [batch_size, _, out_h, out_w] = out_grad.meta.shape().dims(); let groups = options.weight_groups; let out_c_per_group = out_channels / groups; @@ -207,7 +207,7 @@ fn backward_gradient_inputs( let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); - let input_shape = image.shape.clone(); + let input_shape = image.shape(); let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( columns.clone(), image, @@ -235,7 +235,7 @@ fn compute_offset_and_mask_gradient( let device = offset.device.clone(); let (kernel_h, kernel_w) = kernel_dims; - let [batches, _, out_h, out_w] = offset.shape.dims(); + let [batches, _, out_h, out_w] = offset.meta.shape().dims(); let offset_groups = options.offset_groups; let pos_shape = [batches, offset_groups, kernel_h, kernel_w, 2, out_h, out_w]; @@ -244,22 +244,13 @@ fn compute_offset_and_mask_gradient( .map(|s| FastDivmodArgs::new(&client, s)) .collect(); - let grad_offset = empty_device_dtype( - client.clone(), - device.clone(), - offset.shape.clone(), - offset.dtype, - ); - let grad_mask = mask.as_ref().map(|mask| { - empty_device_dtype( - client.clone(), - device.clone(), - mask.shape.clone(), - mask.dtype, - ) - }); + let grad_offset = + empty_device_dtype(client.clone(), device.clone(), offset.shape(), offset.dtype); + let grad_mask = mask + .as_ref() + .map(|mask| empty_device_dtype(client.clone(), device.clone(), mask.shape(), mask.dtype)); - let num_elements_offset = offset.shape.num_elements(); + let num_elements_offset = offset.meta.num_elements(); let cube_dim = CubeDim::new(&image.client, num_elements_offset); let cube_count = calculate_cube_count_elemwise(&image.client, num_elements_offset, cube_dim); @@ -502,7 +493,7 @@ fn compute_input_grad( .contains(TypeUsage::AtomicAdd); let [batches, in_channels, height, width] = input_shape.dims(); - let [_, _, out_h, out_w] = offset.shape.dims(); + let [_, _, out_h, out_w] = offset.meta.shape().dims(); let (kernel_h, kernel_w) = kernel_dims; let pos_shape = [in_channels, kernel_h, kernel_w, batches, out_h, out_w]; @@ -520,7 +511,7 @@ fn compute_input_grad( }; let grad_arg = grad_in.as_tensor_arg(1); - let num_elements = columns.shape.num_elements(); + let num_elements = columns.meta.num_elements(); let cube_dim = CubeDim::new(&offset.client, num_elements); let cube_count = calculate_cube_count_elemwise(&offset.client, num_elements, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/conv/direct.rs b/crates/burn-cubecl/src/kernel/conv/direct.rs index 213c189fc2..ac4f48ad26 100644 --- a/crates/burn-cubecl/src/kernel/conv/direct.rs +++ b/crates/burn-cubecl/src/kernel/conv/direct.rs @@ -8,7 +8,10 @@ use crate::{ tensor::CubeTensor, }; use crate::{kernel::utils::decompose_linear, ops::numeric::empty_device_dtype}; -use burn_backend::ops::{ConvOptions, conv::calculate_conv_output_sizes}; +use burn_backend::{ + TensorMetadata, + ops::{ConvOptions, conv::calculate_conv_output_sizes}, +}; use cubecl::std::{CubeOption, CubeOptionExpand, FastDivmod, FastDivmodArgs}; use cubecl::{ calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView, @@ -227,21 +230,21 @@ pub fn conv_direct( ) -> Result, ConvSetupError> { let client = input.client.clone(); let out_dtype = input.dtype; - let rank = input.shape.num_dims(); + let rank = input.meta.shape().num_dims(); let dim_c = rank - 1; // We only care about the channels here, everything else can be permuted - if input.strides[dim_c] != 1 { + if input.meta.strides()[dim_c] != 1 { input = into_contiguous_aligned(input); } - if weight.strides[dim_c] != 1 { + if weight.meta.strides()[dim_c] != 1 { weight = into_contiguous_aligned(weight); } - let batch_size = input.shape[0]; - let in_shape = &input.shape[1..dim_c]; - let out_channels = weight.shape[0]; - let kernel_shape = &weight.shape[1..dim_c]; + let batch_size = input.meta.shape()[0]; + let in_shape = &input.meta.shape()[1..dim_c]; + let out_channels = weight.meta.shape()[0]; + let kernel_shape = &weight.meta.shape()[1..dim_c]; let channels_per_group = out_channels / options.groups; @@ -266,18 +269,18 @@ pub fn conv_direct( // Need custom line size calculation here to account for the groups division. Need to vectorize // over `channels_per_group` instead. - let mut grouped_out_shape = output.shape.clone(); + let mut grouped_out_shape = output.shape(); grouped_out_shape[dim_c] = channels_per_group; let line_size_out = tensor_line_size_parallel( input.client.io_optimized_line_sizes(input.dtype.size()), &grouped_out_shape, - &output.strides, + output.meta.strides(), dim_c, ); // Use channels_per_group instead of in_channels to avoid issues here let line_size_in = max_line_size(&weight); - let shape_out = output.shape[1..dim_c] + let shape_out = output.meta.shape()[1..dim_c] .iter() .map(|s| FastDivmodArgs::::new(&client, *s as u32)) .collect(); @@ -293,7 +296,7 @@ pub fn conv_direct( )); } - let working_units = output.shape.num_elements() / line_size_out; + let working_units = output.meta.num_elements() / line_size_out; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs b/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs index 2e2de4f686..4637fcd5fb 100644 --- a/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs +++ b/crates/burn-cubecl/src/kernel/conv/forward/implicit_gemm/launch.rs @@ -109,13 +109,13 @@ pub fn launch_convolution_forward( } let out_dtype = input.dtype; - let rank = input.shape.num_dims(); - let batch_size = input.shape[0]; + let rank = input.meta.shape().num_dims(); + let batch_size = input.meta.shape()[0]; let dim_c = rank - 1; - let shape = &input.shape[1..dim_c]; + let shape = &input.meta.shape()[1..dim_c]; - let out_channels = weight.shape[0]; - let weight_shape = &weight.shape[1..dim_c]; + let out_channels = weight.meta.shape()[0]; + let weight_shape = &weight.meta.shape()[1..dim_c]; let mut out_shape = calculate_conv_output_sizes( weight_shape, diff --git a/crates/burn-cubecl/src/kernel/conv/forward/tune.rs b/crates/burn-cubecl/src/kernel/conv/forward/tune.rs index c63d860211..e7d4e70071 100644 --- a/crates/burn-cubecl/src/kernel/conv/forward/tune.rs +++ b/crates/burn-cubecl/src/kernel/conv/forward/tune.rs @@ -99,15 +99,15 @@ fn create_key( options: &ConvOptions, ) -> CubeAutotuneKey { let dtype = input.dtype; - let rank = input.shape.num_dims(); + let rank = input.meta.shape().num_dims(); let dim_c = rank - 1; - let batch_size = input.shape[0]; - let in_channels = input.shape[dim_c]; - let out_channels = weights.shape[0]; + let batch_size = input.meta.shape()[0]; + let in_channels = input.meta.shape()[dim_c]; + let out_channels = weights.meta.shape()[0]; - let kernel_size = weights.shape[1..dim_c].to_vec(); - let in_shape = input.shape[1..dim_c] + let kernel_size = weights.meta.shape()[1..dim_c].to_vec(); + let in_shape = input.meta.shape()[1..dim_c] .iter() .map(|shape| anchor(*shape, None, None, None)) .collect(); @@ -119,14 +119,14 @@ fn create_key( groups, } = options.clone(); - let lhs_stride_align = if input.strides[dim_c] == 1 { - stride_align(&input.strides, input.dtype.into()) + let lhs_stride_align = if input.meta.strides()[dim_c] == 1 { + stride_align(input.meta.strides(), input.dtype.into()) } else { 0 }; let lhs_shape_align = pow2_factor(in_channels).min(lhs_stride_align); - let rhs_stride_align = if weights.strides[dim_c] == 1 { - stride_align(&weights.strides, weights.dtype.into()) + let rhs_stride_align = if weights.meta.strides()[dim_c] == 1 { + stride_align(weights.meta.strides(), weights.dtype.into()) } else { 0 }; diff --git a/crates/burn-cubecl/src/kernel/conv/im2col.rs b/crates/burn-cubecl/src/kernel/conv/im2col.rs index 2ea0a24a3e..f1128581b4 100644 --- a/crates/burn-cubecl/src/kernel/conv/im2col.rs +++ b/crates/burn-cubecl/src/kernel/conv/im2col.rs @@ -2,7 +2,7 @@ use burn_backend::{ DType, ops::{ConvOptions, conv::calculate_conv_output_sizes}, }; -use burn_std::Shape; +use burn_std::Metadata; use core::iter; use cubecl::{ prelude::*, @@ -66,14 +66,14 @@ pub fn conv_im2col_1x1( return Err(ConvSetupError::Groups(options.groups)); } - let rank = input.shape.num_dims(); + let rank = input.meta.num_dims(); let dim_c = rank - 1; - let batch_size = input.shape[0]; - let in_channels = input.shape[dim_c]; - let in_shape = &input.shape[1..dim_c]; - let out_channels = weight.shape[0]; - let kernel_shape = &weight.shape[1..dim_c]; + let batch_size = input.meta.shape()[0]; + let in_channels = input.meta.shape()[dim_c]; + let in_shape = &input.meta.shape()[1..dim_c]; + let out_channels = weight.meta.shape()[0]; + let kernel_shape = &weight.meta.shape()[1..dim_c]; if kernel_shape.iter().any(|s| *s != 1) { return Err(ConvSetupError::Unknown); @@ -98,16 +98,17 @@ pub fn conv_im2col_1x1( let dtype = input.dtype; // Efficient permutation that takes the stride required for TMA into account - let weight = if weight.strides[dim_c] != 1 { + let weight = if weight.meta.strides()[dim_c] != 1 { // Remove kernel dims so padded dim is channels - weight.shape = Shape::new([out_channels, in_channels]); // [N, K] - weight.strides = vec![weight.strides[0], weight.strides[dim_c]]; + *weight.meta = Metadata::new( + [out_channels, in_channels], // [N, K] + [weight.meta.strides()[0], weight.meta.strides()[dim_c]], + ); // Pitched contiguous to skip running another kernel for TMA into_contiguous_aligned(weight) } else { // Already compatible, skip initial reshape - weight.shape = Shape::new([out_channels, in_channels]); // [N, K] - weight.strides = vec![weight.strides[0], 1]; + *weight.meta = Metadata::new([out_channels, in_channels], [weight.meta.strides()[0], 1]); weight }; @@ -132,22 +133,24 @@ pub fn conv_im2col_1x1( /// Reshapes NHWC input to [(N, H, W), C] fn reshape_input(mut input: CubeTensor) -> CubeTensor { - let rank = input.shape.num_dims(); + let rank = input.meta.num_dims(); let dim_c = rank - 1; let dtype = input.dtype; - let batch_size = input.shape[0]; - let in_c: usize = input.shape[dim_c]; - let in_shape = input.shape[1..dim_c].to_vec(); + let batch_size = input.meta.shape()[0]; + let in_c: usize = input.meta.shape()[dim_c]; + let in_shape = input.meta.shape()[1..dim_c].to_vec(); - if !is_spatial_contiguous(&input.shape, &input.strides) { + if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) { let contiguous = into_contiguous_pitched_ref(&input.client, &input.as_handle_ref(), dtype.into()) .expect("Kernel to never fail"); input = from_handle(&input.client, &input.device, contiguous, dtype); } - input.shape = Shape::new([batch_size * in_shape.iter().product::(), in_c]); // [M, K] - input.strides = vec![input.strides[dim_c - 1], input.strides[dim_c]]; + *input.meta = Metadata::new( + [batch_size * in_shape.iter().product::(), in_c], // [M, K] + [input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]], + ); input } @@ -177,9 +180,8 @@ fn from_handle( CubeTensor::new( client.clone(), handle.handle, - handle.shape.into(), + *handle.metadata, device.clone(), - handle.strides, dtype, ) } diff --git a/crates/burn-cubecl/src/kernel/cross.rs b/crates/burn-cubecl/src/kernel/cross.rs index 6a685bd023..21071b6a21 100644 --- a/crates/burn-cubecl/src/kernel/cross.rs +++ b/crates/burn-cubecl/src/kernel/cross.rs @@ -46,13 +46,15 @@ pub(crate) fn cross( rhs: CubeTensor, dim: usize, ) -> CubeTensor { - let ndims = lhs.shape.num_dims(); + let ndims = lhs.meta.num_dims(); // Validate that the cross dimension has size 3 - if lhs.shape[dim] != 3 || rhs.shape[dim] != 3 { + if lhs.meta.shape()[dim] != 3 || rhs.meta.shape()[dim] != 3 { panic!( "Cross product requires dimension {} to have size 3, but got {} and {}", - dim, lhs.shape[dim], rhs.shape[dim] + dim, + lhs.meta.shape()[dim], + rhs.meta.shape()[dim] ); } diff --git a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs index c91d217828..fa7edc9556 100644 --- a/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs +++ b/crates/burn-cubecl/src/kernel/grid_sample/bilinear.rs @@ -131,8 +131,8 @@ pub(crate) fn grid_sample_bilinear_launch( grid: CubeTensor, options: GridSampleOptions, ) -> CubeTensor { - let [batch_size, channels, _h_in, _w_in] = input.shape.dims(); - let [_n, h_out, w_out, two] = grid.shape.dims(); + let [batch_size, channels, _h_in, _w_in] = input.meta.shape().dims(); + let [_n, h_out, w_out, two] = grid.meta.shape().dims(); assert_eq!(two, 2, "Grid last dimension must be 2"); // Create output tensor [N, C, H_out, W_out] diff --git a/crates/burn-cubecl/src/kernel/index/flip.rs b/crates/burn-cubecl/src/kernel/index/flip.rs index 5a1856252c..5bbfb0a97e 100644 --- a/crates/burn-cubecl/src/kernel/index/flip.rs +++ b/crates/burn-cubecl/src/kernel/index/flip.rs @@ -4,7 +4,7 @@ use crate::{ ops::numeric::empty_device_dtype, tensor::CubeTensor, }; -use burn_backend::DType; +use burn_backend::{DType, TensorMetadata}; use cubecl::{ calculate_cube_count_elemwise, prelude::*, @@ -53,7 +53,7 @@ pub(crate) fn flip( let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), tensor.dtype, ); flip_on_output(tensor, output, indices, dtype_bool) @@ -66,7 +66,7 @@ pub(crate) fn flip_on_output( dtype_bool: DType, ) -> CubeTensor { let dtype_input = tensor.dtype; - let ndims = tensor.shape.num_dims(); + let ndims = tensor.meta.num_dims(); let mut indices_sequence = SequenceArg::<'_, R, InputScalar>::new(); for i in 0..ndims { @@ -76,7 +76,7 @@ pub(crate) fn flip_on_output( }); } - let num_elements = output.shape.num_elements(); + let num_elements = output.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, num_elements); let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elements, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/index/gather.rs b/crates/burn-cubecl/src/kernel/index/gather.rs index 121862fc53..163f6e926f 100644 --- a/crates/burn-cubecl/src/kernel/index/gather.rs +++ b/crates/burn-cubecl/src/kernel/index/gather.rs @@ -4,6 +4,7 @@ use crate::{ ops::numeric::empty_device_dtype, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor}; use cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod}; use cubecl::{CubeDim, std::tensor::layout::linear::LinearView}; @@ -40,7 +41,7 @@ pub(crate) fn gather( tensor: CubeTensor, indices: CubeTensor, ) -> CubeTensor { - let shape_output = indices.shape.clone(); + let shape_output = indices.shape(); let total_elem = shape_output.num_elements(); let output = empty_device_dtype( tensor.client.clone(), diff --git a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs index 73dbd8d158..452b46d7bf 100644 --- a/crates/burn-cubecl/src/kernel/index/repeat_dim.rs +++ b/crates/burn-cubecl/src/kernel/index/repeat_dim.rs @@ -53,13 +53,13 @@ pub(crate) fn repeat_dim( dim: usize, times: usize, ) -> CubeTensor { - if input.shape[dim] == 1 { - input.strides[dim] = 0; - input.shape = input.shape.repeat(dim, times).unwrap(); + if input.meta.shape()[dim] == 1 { + input.meta.strides[dim] = 0; + input.meta.shape = input.meta.shape.repeat(dim, times).unwrap(); return input; } - let shape = input.shape.clone().repeat(dim, times).unwrap(); + let shape = input.meta.shape.clone().repeat(dim, times).unwrap(); // Create output handle let output = empty_device_dtype( @@ -69,7 +69,7 @@ pub(crate) fn repeat_dim( input.dtype, ); - let working_units = output.shape.num_elements(); + let working_units = output.meta.num_elements(); let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); @@ -82,7 +82,7 @@ pub(crate) fn repeat_dim( input.as_tensor_arg(1), output.as_tensor_arg(1), shape_divmod(&output), - FastDivmodArgs::new(&input.client, input.shape[dim]), + FastDivmodArgs::new(&input.client, input.meta.shape()[dim]), dim, output.dtype.into(), ) diff --git a/crates/burn-cubecl/src/kernel/index/scatter.rs b/crates/burn-cubecl/src/kernel/index/scatter.rs index ad4cf16230..e80dedc3bf 100644 --- a/crates/burn-cubecl/src/kernel/index/scatter.rs +++ b/crates/burn-cubecl/src/kernel/index/scatter.rs @@ -79,7 +79,7 @@ pub(crate) fn scatter( false => tensor.copy(), }; - let num_elems = tensor.shape.num_elements() / tensor.shape[dim]; + let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim]; let working_units = num_elems; let cube_dim = CubeDim::new(&indices.client, working_units); diff --git a/crates/burn-cubecl/src/kernel/index/select.rs b/crates/burn-cubecl/src/kernel/index/select.rs index 6356bcaf64..14b40a826c 100644 --- a/crates/burn-cubecl/src/kernel/index/select.rs +++ b/crates/burn-cubecl/src/kernel/index/select.rs @@ -3,6 +3,7 @@ use crate::{ kernel::utils::{linear_view, shape_divmod}, ops::numeric::empty_device_dtype, }; +use burn_backend::TensorMetadata; use cubecl::{CubeDim, calculate_cube_count_elemwise, std::tensor::layout::linear::LinearView}; use cubecl::{prelude::*, std::FastDivmod}; @@ -47,8 +48,8 @@ pub(crate) fn select( dim: usize, indices: CubeTensor, ) -> CubeTensor { - let mut shape_output = tensor.shape.clone(); - shape_output[dim] = indices.shape[0]; + let mut shape_output = tensor.shape(); + shape_output[dim] = indices.meta.shape()[0]; let total_elem = shape_output.num_elements(); let output = empty_device_dtype( diff --git a/crates/burn-cubecl/src/kernel/index/select_assign.rs b/crates/burn-cubecl/src/kernel/index/select_assign.rs index f6cde5b823..821bcb2908 100644 --- a/crates/burn-cubecl/src/kernel/index/select_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/select_assign.rs @@ -67,7 +67,7 @@ pub(crate) fn select_assign( false => tensor.copy(), }; - let num_elems = tensor.shape.num_elements() / tensor.shape[dim]; + let num_elems = tensor.meta.num_elements() / tensor.meta.shape()[dim]; let working_units = num_elems; let cube_dim = CubeDim::new(&indices.client, working_units); let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/index/slice.rs b/crates/burn-cubecl/src/kernel/index/slice.rs index 3965958c42..f371e39852 100644 --- a/crates/burn-cubecl/src/kernel/index/slice.rs +++ b/crates/burn-cubecl/src/kernel/index/slice.rs @@ -4,7 +4,8 @@ use crate::{ ops::numeric::empty_device_dtype, tensor::CubeTensor, }; -use burn_backend::Slice; +use burn_backend::{Slice, TensorMetadata}; +use burn_std::{Metadata, SliceOps}; use cubecl::{ calculate_cube_count_elemwise, intrinsic, prelude::*, @@ -14,13 +15,13 @@ use std::ops::Range; /// Slice a jit tensor with a set of ranges pub fn slice(tensor: CubeTensor, indices: &[Range]) -> CubeTensor { - let mut dims = tensor.shape.clone(); + let mut dims = tensor.shape(); let mut offset_start = 0u64; let mut offset_end = 0u64; for i in 0..indices.len() { - offset_start += (tensor.strides[i] * indices[i].start) as u64; - offset_end += (tensor.strides[i] * (dims[i] - indices[i].end)) as u64; + offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64; + offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64; dims[i] = indices[i].end - indices[i].start; } @@ -38,9 +39,8 @@ pub fn slice(tensor: CubeTensor, indices: &[Range]) -> .handle .offset_start(offset_start) .offset_end(offset_end), - dims, + Metadata::new(dims, tensor.meta.strides), tensor.device, - tensor.strides, tensor.dtype, ) } else { @@ -92,7 +92,7 @@ pub(crate) fn slice_on_output( output: CubeTensor, indices: &[Range], ) -> CubeTensor { - let ndims = tensor.shape.num_dims(); + let ndims = tensor.meta.num_dims(); let mut indices_sequence = SequenceArg::::new(); for i in 0..ndims { @@ -100,7 +100,7 @@ pub(crate) fn slice_on_output( indices_sequence.push(ScalarArg::new(start)); } - let working_units = output.shape.num_elements(); + let working_units = output.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); @@ -179,13 +179,13 @@ pub fn slice_with_steps(tensor: CubeTensor, slices: &[Slice]) let simple_ranges: Vec> = slices .iter() .enumerate() - .map(|(i, slice)| slice.to_range(tensor.shape[i])) + .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); return slice(tensor, &simple_ranges); } // Calculate output shape - let shape_output = tensor.shape.clone().slice(slices).unwrap(); + let shape_output = tensor.shape().slice(slices).unwrap(); // Create output tensor let output = empty_device_dtype( @@ -201,16 +201,16 @@ pub fn slice_with_steps(tensor: CubeTensor, slices: &[Slice]) let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { - let range = slice.to_range(tensor.shape[dim]); + let range = slice.to_range(tensor.meta.shape()[dim]); starts.push(ScalarArg::new(range.start)); ends.push(ScalarArg::new(range.end)); steps.push(ScalarArg::new(slice.step as i32)); } // Pad with default values if needed to match tensor dimensions - for dim in slices.len()..tensor.shape.num_dims() { + for dim in slices.len()..tensor.meta.num_dims() { starts.push(ScalarArg::new(0)); - ends.push(ScalarArg::new(tensor.shape[dim])); + ends.push(ScalarArg::new(tensor.meta.shape()[dim])); steps.push(ScalarArg::new(1)); } diff --git a/crates/burn-cubecl/src/kernel/index/slice_assign.rs b/crates/burn-cubecl/src/kernel/index/slice_assign.rs index 9cffbdc05e..e0851b33cc 100644 --- a/crates/burn-cubecl/src/kernel/index/slice_assign.rs +++ b/crates/burn-cubecl/src/kernel/index/slice_assign.rs @@ -113,26 +113,27 @@ pub(crate) fn slice_assign( true => tensor, false => tensor.copy(), }; - let ndims = tensor.shape.num_dims(); + let ndims = tensor.meta.num_dims(); - let line_size = if tensor.strides[ndims - 1] == 1 && value.strides[ndims - 1] == 1 { + let line_size = if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1 + { let last = indices .get(ndims - 1) .cloned() .unwrap_or(burn_backend::Slice { start: 0, - end: Some(tensor.shape[ndims - 1] as isize), + end: Some(tensor.meta.shape()[ndims - 1] as isize), step: 1, }); - let end = last.end.unwrap_or(tensor.shape[ndims - 1] as isize); + let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize); let shape = (end - last.start) as usize; let offset = last.start as usize; client .io_optimized_line_sizes(tensor.dtype.size()) .filter(|&it| { shape.is_multiple_of(it) - && strides_compatible(&tensor.strides, it) - && strides_compatible(&value.strides, it) + && strides_compatible(tensor.meta.strides(), it) + && strides_compatible(value.meta.strides(), it) && offset.is_multiple_of(it) }) .max() @@ -147,18 +148,18 @@ pub(crate) fn slice_assign( for i in 0..ndims { let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice { start: 0, - end: Some(tensor.shape[i] as isize), + end: Some(tensor.meta.shape()[i] as isize), step: 1, }); let start = slice.start as usize; - let end = slice.end.unwrap_or(tensor.shape[i] as isize); + let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize); let length = (end - slice.start) as usize; shape.push(FastDivmodArgs::::new(&client, length)); offsets.push(ScalarArg::new(start)); } - let working_units = value.shape.num_elements() / line_size; + let working_units = value.meta.num_elements() / line_size; let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); @@ -205,21 +206,21 @@ pub(crate) fn slice_assign_with_steps( let mut steps = SequenceArg::::new(); for (dim, slice) in slices.iter().enumerate() { - let range = slice.to_range(tensor.shape[dim]); + let range = slice.to_range(tensor.meta.shape()[dim]); starts.push(ScalarArg::new(range.start)); ends.push(ScalarArg::new(range.end)); steps.push(ScalarArg::new(slice.step as i32)); } // Pad with default values if needed to match tensor dimensions - for dim in slices.len()..tensor.shape.num_dims() { + for dim in slices.len()..tensor.meta.num_dims() { starts.push(ScalarArg::new(0)); - ends.push(ScalarArg::new(tensor.shape[dim])); + ends.push(ScalarArg::new(tensor.meta.shape()[dim])); steps.push(ScalarArg::new(1)); } // Launch kernel - let working_units = value.shape.num_elements(); + let working_units = value.meta.num_elements(); let cube_dim = CubeDim::new(&tensor.client, working_units); let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/base.rs b/crates/burn-cubecl/src/kernel/interpolate/base.rs index 051d1b3da0..b89feec364 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/base.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/base.rs @@ -5,7 +5,7 @@ use crate::{ tensor::CubeTensor, }; use burn_backend::{ - Shape, + Shape, TensorMetadata, ops::{InterpolateMode, InterpolateOptions}, }; @@ -22,7 +22,7 @@ pub fn interpolate( output_size: [usize; 2], options: InterpolateOptions, ) -> CubeTensor { - let [batch_size, channels, _, _] = input.shape.dims(); + let [batch_size, channels, _, _] = input.meta.shape().dims(); let [out_height, out_width] = output_size; let input = into_contiguous(permute_nchw_to_nhwc(input)); @@ -57,7 +57,7 @@ pub fn interpolate_backward( let input = permute_nchw_to_nhwc(input); let out_grad = permute_nchw_to_nhwc(out_grad); - let output_shape = input.shape.clone(); + let output_shape = input.shape(); let output = empty_device_dtype( input.client.clone(), input.device.clone(), diff --git a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs index 6686cc80cf..3373de33f7 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/bicubic.rs @@ -172,7 +172,7 @@ pub(crate) fn interpolate_bicubic_launch( let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, line_size); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs index 15aa1e19bb..ca7c65f99e 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/bilinear.rs @@ -127,7 +127,7 @@ pub(crate) fn interpolate_bilinear_launch( let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, line_size); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs index fc63650d87..50398f9856 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/nearest.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/nearest.rs @@ -54,7 +54,7 @@ pub(crate) fn interpolate_nearest_launch( let line_size = max_line_size(&input); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs index bebea358de..b7b8cdc76a 100644 --- a/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-cubecl/src/kernel/interpolate/nearest_backward.rs @@ -80,7 +80,7 @@ pub(crate) fn interpolate_nearest_backward_launch( let out_shape = shape_divmod(&output); let out_layout = linear_layout(&output, line_size); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&out_grad.client, working_units); let cube_count = calculate_cube_count_elemwise(&out_grad.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs index 0b38bc6c7d..b65df81ede 100644 --- a/crates/burn-cubecl/src/kernel/mask/mask_fill.rs +++ b/crates/burn-cubecl/src/kernel/mask/mask_fill.rs @@ -1,4 +1,4 @@ -use burn_backend::DType; +use burn_backend::{DType, TensorMetadata}; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; use crate::{ @@ -48,19 +48,19 @@ pub fn mask_fill( strategy: MaskFillStrategy, dtype_bool: DType, ) -> CubeTensor { - let ndims = input.shape.num_dims(); + let ndims = input.meta.num_dims(); let output = match strategy { MaskFillStrategy::Readonly => empty_device_dtype( input.client.clone(), input.device.clone(), - input.shape.clone(), + input.shape(), input.dtype, ), MaskFillStrategy::Inplace => input.clone(), }; let line_size = max_line_size_many(&[&input, &mask], ndims - 1); - let working_units = input.shape.num_elements() / line_size as usize; + let working_units = input.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/mask/mask_where.rs b/crates/burn-cubecl/src/kernel/mask/mask_where.rs index 85a7f00a14..1a6cb5cc50 100644 --- a/crates/burn-cubecl/src/kernel/mask/mask_where.rs +++ b/crates/burn-cubecl/src/kernel/mask/mask_where.rs @@ -49,9 +49,9 @@ pub fn mask_where( strategy: MaskWhereStrategy, dtype_bool: DType, ) -> CubeTensor { - let line_size = max_line_size_many(&[&input, &mask, &value], input.shape.num_dims() - 1); + let line_size = max_line_size_many(&[&input, &mask, &value], input.meta.num_dims() - 1); - let working_units = input.shape.num_elements() / line_size as usize; + let working_units = input.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&input.client, working_units); let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/matmul/base.rs b/crates/burn-cubecl/src/kernel/matmul/base.rs index 0524a614ac..2f9759cef2 100644 --- a/crates/burn-cubecl/src/kernel/matmul/base.rs +++ b/crates/burn-cubecl/src/kernel/matmul/base.rs @@ -96,7 +96,7 @@ pub(crate) fn launch_matmul( MatmulInputHandleRef::quantized( data.as_handle_ref(), scale.as_handle_ref(), - &lhs.shape, + lhs.meta.shape(), lhs.scheme(), data.dtype.into(), scale.dtype.into(), @@ -127,7 +127,7 @@ pub(crate) fn launch_matmul( MatmulInputHandleRef::quantized( data.as_handle_ref(), scale.as_handle_ref(), - &rhs.shape, + rhs.meta.shape(), rhs.scheme(), data.dtype.into(), scale.dtype.into(), diff --git a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs index 072db39bbe..ae58638983 100644 --- a/crates/burn-cubecl/src/kernel/matmul/tune/base.rs +++ b/crates/burn-cubecl/src/kernel/matmul/tune/base.rs @@ -396,10 +396,10 @@ fn create_key( ) -> MatmulAutotuneKey { MatmulAutotuneKey::generate( &lhs.client, - &lhs.shape, - &rhs.shape, - &lhs.strides, - &rhs.strides, + lhs.meta.shape(), + rhs.meta.shape(), + lhs.meta.strides(), + rhs.meta.strides(), lhs.dtype.into(), rhs.dtype.into(), out.dtype.into(), diff --git a/crates/burn-cubecl/src/kernel/matmul/utils.rs b/crates/burn-cubecl/src/kernel/matmul/utils.rs index 8ae21f9a2b..fbbf69ff77 100644 --- a/crates/burn-cubecl/src/kernel/matmul/utils.rs +++ b/crates/burn-cubecl/src/kernel/matmul/utils.rs @@ -10,7 +10,7 @@ pub fn init_matmul_output( empty_device_dtype( lhs.client.clone(), lhs.device.clone(), - calculate_matmul_output(&lhs.shape, &rhs.shape).unwrap(), + calculate_matmul_output(lhs.meta.shape(), rhs.meta.shape()).unwrap(), dtype, ) } diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs index 9327288f92..1a182cf79e 100644 --- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d.rs @@ -82,7 +82,7 @@ pub(crate) fn adaptive_avg_pool2d( input: CubeTensor, output_size: [usize; 2], ) -> CubeTensor { - let [batch_size, channels, _, _] = input.shape.dims(); + let [batch_size, channels, _, _] = input.meta.shape().dims(); let input = into_contiguous_aligned(permute_nchw_to_nhwc(input)); let line_size = max_line_size(&input); diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs index af42f0debb..da309dd0b3 100644 --- a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -88,7 +88,7 @@ pub(crate) fn adaptive_avg_pool2d_backward( x: CubeTensor, out_grad: CubeTensor, ) -> CubeTensor { - let [batches, channels, height, width] = x.shape.dims(); + let [batches, channels, height, width] = x.meta.shape().dims(); let out_grad = into_contiguous_aligned(permute_nchw_to_nhwc(out_grad)); let line_size = max_line_size(&out_grad); @@ -96,7 +96,7 @@ pub(crate) fn adaptive_avg_pool2d_backward( let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); - let num_elems = output.shape.num_elements(); + let num_elems = output.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs index 76c25418c5..3146ffffed 100644 --- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d.rs @@ -100,7 +100,7 @@ pub(crate) fn avg_pool2d( count_include_pad: bool, ceil_mode: bool, ) -> CubeTensor { - let [batch_size, channels, in_h, in_w] = x.shape.dims(); + let [batch_size, channels, in_h, in_w] = x.meta.shape().dims(); let dilation = 1; let size_0 = calculate_pool_output_size( @@ -130,7 +130,7 @@ pub(crate) fn avg_pool2d( let shape_out = Shape::new([batch_size, size_0, size_1, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs index f6559fc2f9..59b462c963 100644 --- a/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool2d_backward.rs @@ -134,11 +134,11 @@ pub(crate) fn avg_pool2d_backward( count_include_pad: bool, _ceil_mode: bool, ) -> CubeTensor { - let [batches, channels, height, width] = x.shape.dims(); + let [batches, channels, height, width] = x.meta.shape().dims(); let grad = permute_nchw_to_nhwc(grad); - let line_size = if x.strides[3] == grad.strides[3] { + let line_size = if x.meta.strides()[3] == grad.meta.strides()[3] { max_line_size(&x) } else { 1 @@ -149,7 +149,7 @@ pub(crate) fn avg_pool2d_backward( let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs index 0e4a6b7eb6..582672702d 100644 --- a/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d.rs @@ -124,14 +124,14 @@ pub(crate) fn max_pool2d( dilation: [usize; 2], ceil_mode: bool, ) -> CubeTensor { - let [batch_size, channels, _, _] = x.shape.dims(); + let [batch_size, channels, height, width] = x.meta.shape().dims(); let size_0 = calculate_pool_output_size( kernel_size[0], stride[0], padding[0], dilation[0], - x.shape[2], + height, ceil_mode, ); let size_1 = calculate_pool_output_size( @@ -139,7 +139,7 @@ pub(crate) fn max_pool2d( stride[1], padding[1], dilation[1], - x.shape[3], + width, ceil_mode, ); @@ -150,7 +150,7 @@ pub(crate) fn max_pool2d( let shape_out = Shape::new([batch_size, size_0, size_1, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); @@ -190,14 +190,14 @@ pub(crate) fn max_pool2d_with_indices( ceil_mode: bool, dtype_indices: DType, ) -> (CubeTensor, CubeTensor) { - let [batch_size, channels, _, _] = x.shape.dims(); + let [batch_size, channels, size_0, size_1] = x.meta.shape().dims(); let size_0 = calculate_pool_output_size( kernel_size[0], stride[0], padding[0], dilation[0], - x.shape[2], + size_0, ceil_mode, ); let size_1 = calculate_pool_output_size( @@ -205,7 +205,7 @@ pub(crate) fn max_pool2d_with_indices( stride[1], padding[1], dilation[1], - x.shape[3], + size_1, ceil_mode, ); @@ -221,7 +221,7 @@ pub(crate) fn max_pool2d_with_indices( ); let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs index ce3e65c6b5..22c223951c 100644 --- a/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-cubecl/src/kernel/pool/max_pool2d_backward.rs @@ -108,12 +108,12 @@ pub(crate) fn max_pool2d_with_indices_backward( dilation: [usize; 2], _ceil_mode: bool, ) -> CubeTensor { - let [batches, channels, height, width] = x.shape.dims(); + let [batches, channels, height, width] = x.meta.shape().dims(); let grad = into_contiguous_aligned(permute_nchw_to_nhwc(grad)); let indices = into_contiguous_aligned(permute_nchw_to_nhwc(indices)); - let line_size = if grad.strides[3] == indices.strides[3] { + let line_size = if grad.meta.strides()[3] == indices.meta.strides()[3] { max_line_size(&grad) } else { 1 @@ -122,7 +122,7 @@ pub(crate) fn max_pool2d_with_indices_backward( let out_shape = Shape::new([batches, height, width, channels]); let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); - let working_units = output.shape.num_elements() / line_size as usize; + let working_units = output.meta.num_elements() / line_size as usize; let cube_dim = CubeDim::new(&x.client, working_units); let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/kernel/pool/pool2d.rs b/crates/burn-cubecl/src/kernel/pool/pool2d.rs index ead3c4f61c..7dde8fc9e4 100644 --- a/crates/burn-cubecl/src/kernel/pool/pool2d.rs +++ b/crates/burn-cubecl/src/kernel/pool/pool2d.rs @@ -137,7 +137,7 @@ pub(super) fn view4d( tensor: &CubeTensor, line_size: LineSize, ) -> ViewArg<'_, Position, R> { - let shape = &tensor.shape; + let shape = tensor.meta.shape(); let shape = ( ScalarArg::new(shape[0]), ScalarArg::new(shape[1]), diff --git a/crates/burn-cubecl/src/kernel/prng/uniform.rs b/crates/burn-cubecl/src/kernel/prng/uniform.rs index 9152f45a03..06cf097ab1 100644 --- a/crates/burn-cubecl/src/kernel/prng/uniform.rs +++ b/crates/burn-cubecl/src/kernel/prng/uniform.rs @@ -1,5 +1,5 @@ use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor}; -use burn_backend::{DType, Shape}; +use burn_backend::{DType, Shape, TensorMetadata}; /// Pseudo-random generator with uniform distribution pub fn random_uniform( @@ -34,7 +34,7 @@ pub fn random_like_uniform( dtype: DType, ) -> CubeTensor { random_uniform( - tensor.shape.clone(), + tensor.shape(), &tensor.device, lower_bound, upper_bound, diff --git a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs index d874ebf3d4..747d3a2985 100644 --- a/crates/burn-cubecl/src/kernel/quantization/dequantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/dequantize.rs @@ -1,6 +1,6 @@ use crate::tensor::CubeTensor; use crate::{CubeRuntime, ops::numeric::empty_device_dtype}; -use burn_backend::DType; +use burn_backend::{DType, TensorMetadata}; /// Convert the tensor back to a higher precision data type. pub fn dequantize(tensor: CubeTensor, dtype: DType) -> CubeTensor @@ -15,7 +15,7 @@ where let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), dtype, ); let (values, params) = tensor.quantized_handles().unwrap(); diff --git a/crates/burn-cubecl/src/kernel/quantization/quantize.rs b/crates/burn-cubecl/src/kernel/quantization/quantize.rs index 37de6edffc..1c2d8bcc86 100644 --- a/crates/burn-cubecl/src/kernel/quantization/quantize.rs +++ b/crates/burn-cubecl/src/kernel/quantization/quantize.rs @@ -1,6 +1,6 @@ use crate::CubeRuntime; use crate::{ops::empty_qtensor_optimized, tensor::CubeTensor}; -use burn_backend::quantization::QuantScheme; +use burn_backend::{TensorMetadata, quantization::QuantScheme}; /// Convert the tensor to a lower precision data type based on the quantization scheme and parameters. pub fn quantize( @@ -11,7 +11,7 @@ pub fn quantize( where R: CubeRuntime, { - let output = empty_qtensor_optimized(tensor.shape.clone(), *scheme, &tensor.device); + let output = empty_qtensor_optimized(tensor.shape(), *scheme, &tensor.device); let (out_values, out_params) = output.clone().quantized_handles().unwrap(); cubek::quantization::quantize::launch_ref( diff --git a/crates/burn-cubecl/src/kernel/reduce/base.rs b/crates/burn-cubecl/src/kernel/reduce/base.rs index 15e997b06c..7d7b283f37 100644 --- a/crates/burn-cubecl/src/kernel/reduce/base.rs +++ b/crates/burn-cubecl/src/kernel/reduce/base.rs @@ -5,7 +5,8 @@ use crate::{ ops::numeric::{empty_device_contiguous_dtype, zeros_client}, tensor::CubeTensor, }; -use burn_backend::{DType, Shape}; +use burn_backend::{DType, TensorMetadata}; +use burn_std::Metadata; use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType}; use cubek::reduce::{ ReduceDtypes, ReduceError, ReduceStrategy, @@ -118,13 +119,12 @@ pub fn reduce( ) -> Result, cubek::reduce::ReduceError> { // In practice, it looks like starting by the axis with the smallest shape // and going in increasing order lead to the fastest calculation. - let sorted_axis = argsort(&tensor.shape); + let sorted_axis = argsort(tensor.meta.shape()); for axis in sorted_axis { tensor = reduce_dim::(tensor, output_dtype, axis, strategy.clone(), config)?; } // reshape to scalar tensor - tensor.shape = Shape::new([1]); - tensor.strides = vec![1]; + *tensor.meta = Metadata::new([1], [1]); Ok(tensor) } @@ -162,7 +162,7 @@ pub fn reduce_dim( let output = init_reduce_output::(&input, dim, &dtypes).ok_or( cubek::reduce::ReduceError::InvalidAxis { axis: dim, - rank: input.shape.num_dims(), + rank: input.meta.num_dims(), }, )?; @@ -206,8 +206,8 @@ pub fn init_reduce_output( dim: usize, dtypes: &ReduceDtypes, ) -> Option> { - (dim < input.shape.num_dims()).then(|| { - let mut shape_out = input.shape.clone(); + (dim < input.meta.num_dims()).then(|| { + let mut shape_out = input.shape(); shape_out[dim] = 1; empty_device_contiguous_dtype( input.client.clone(), diff --git a/crates/burn-cubecl/src/kernel/reduce/tune.rs b/crates/burn-cubecl/src/kernel/reduce/tune.rs index 4ab62930b5..25c3c2364c 100644 --- a/crates/burn-cubecl/src/kernel/reduce/tune.rs +++ b/crates/burn-cubecl/src/kernel/reduce/tune.rs @@ -167,8 +167,8 @@ pub(crate) fn create_key( elem_input, elem_output, elem_acc, - &input.shape, - input.strides[*axis] == 1, + input.meta.shape(), + input.meta.strides()[*axis] == 1, *axis, ) } @@ -236,7 +236,7 @@ impl SumAutotuneKey { #[allow(unused)] pub(crate) fn generate(input: &CubeTensor) -> Self { let dtype = input.dtype; - let length = input.shape.num_elements(); + let length = input.meta.num_elements(); Self::new(dtype, length) } } diff --git a/crates/burn-cubecl/src/kernel/unary_float.rs b/crates/burn-cubecl/src/kernel/unary_float.rs index b7d5437a16..91bc2779db 100644 --- a/crates/burn-cubecl/src/kernel/unary_float.rs +++ b/crates/burn-cubecl/src/kernel/unary_float.rs @@ -4,6 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync { @@ -43,7 +44,7 @@ where let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); @@ -68,7 +69,7 @@ where let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), tensor.dtype, ); diff --git a/crates/burn-cubecl/src/kernel/unary_int.rs b/crates/burn-cubecl/src/kernel/unary_int.rs index 322dd45d16..0f7e36b498 100644 --- a/crates/burn-cubecl/src/kernel/unary_int.rs +++ b/crates/burn-cubecl/src/kernel/unary_int.rs @@ -4,6 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { @@ -40,7 +41,7 @@ where { let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); @@ -65,7 +66,7 @@ where let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), tensor.dtype, ); diff --git a/crates/burn-cubecl/src/kernel/unary_numeric.rs b/crates/burn-cubecl/src/kernel/unary_numeric.rs index 5c46713139..af58fa7bbc 100644 --- a/crates/burn-cubecl/src/kernel/unary_numeric.rs +++ b/crates/burn-cubecl/src/kernel/unary_numeric.rs @@ -4,6 +4,7 @@ use crate::{ ops::{max_line_size, numeric::empty_device_dtype}, tensor::CubeTensor, }; +use burn_backend::TensorMetadata; use cubecl::{calculate_cube_count_elemwise, prelude::*, std::tensor::layout::linear::LinearView}; pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync { @@ -42,7 +43,7 @@ where { let line_size = max_line_size(&tensor); let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); + let num_elems = tensor.meta.num_elements(); let working_units = num_elems / line_size as usize; let cube_dim = CubeDim::new(&tensor.client, working_units); @@ -67,7 +68,7 @@ where let output = empty_device_dtype( tensor.client.clone(), tensor.device.clone(), - tensor.shape.clone(), + tensor.shape(), tensor.dtype, ); diff --git a/crates/burn-cubecl/src/kernel/utils.rs b/crates/burn-cubecl/src/kernel/utils.rs index c9f1a41bac..fbe4c4e262 100644 --- a/crates/burn-cubecl/src/kernel/utils.rs +++ b/crates/burn-cubecl/src/kernel/utils.rs @@ -15,7 +15,7 @@ pub fn shape_divmod<'a, R: CubeRuntime>( tensor: &CubeTensor, ) -> SequenceArg<'a, R, FastDivmod> { let mut arg = SequenceArg::new(); - for dim in tensor.shape.iter() { + for dim in tensor.meta.shape().iter() { arg.push(FastDivmodArgs::::new(&tensor.client, *dim)); } arg @@ -25,7 +25,12 @@ pub fn linear_layout<'a, R: CubeRuntime>( tensor: &'a CubeTensor, line_size: LineSize, ) -> LinearLayoutArgs<'a, R> { - LinearLayoutArgs::from_shape_strides(&tensor.client, &tensor.shape, &tensor.strides, line_size) + LinearLayoutArgs::from_shape_strides( + &tensor.client, + tensor.meta.shape(), + tensor.meta.strides(), + line_size, + ) } pub fn linear_layout_ref<'a, R: CubeRuntime>( @@ -35,9 +40,9 @@ pub fn linear_layout_ref<'a, R: CubeRuntime>( ) -> LinearLayoutArgs<'a, R> { LinearLayoutArgs::from_shape_strides_with_reference( &tensor.client, - &tensor.shape, - &reference.shape, - &tensor.strides, + tensor.meta.shape(), + reference.meta.shape(), + tensor.meta.strides(), line_size, ) } @@ -46,7 +51,7 @@ pub fn linear_view<'a, R: CubeRuntime>( tensor: &'a CubeTensor, line_size: LineSize, ) -> LinearViewLaunch<'a, R> { - let len = tensor.shape.iter().product::(); + let len = tensor.meta.num_elements(); let layout = linear_layout(tensor, line_size); let buffer = unsafe { ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size()) @@ -59,7 +64,7 @@ pub fn linear_view_ref<'a, R: CubeRuntime>( reference: &'a CubeTensor, line_size: LineSize, ) -> LinearViewLaunch<'a, R> { - let len = tensor.shape.iter().product::(); + let len = tensor.meta.num_elements(); let layout = linear_layout_ref(tensor, reference, line_size); let buffer = unsafe { ArrayArg::from_raw_parts_and_size(&tensor.handle, len, line_size, tensor.elem_size()) @@ -82,13 +87,11 @@ pub fn split_dim( dim: usize, shape: &[usize], ) -> CubeTensor { - let mut stride = tensor.strides[dim]; - tensor.shape.remove(dim); - tensor.strides.remove(dim); + let mut stride = tensor.meta.strides()[dim]; + tensor.meta.remove(dim); for size in shape.iter().rev() { - tensor.shape.insert(dim, *size); - tensor.strides.insert(dim, stride); + tensor.meta.insert(dim, *size, stride); stride *= size; } @@ -96,19 +99,19 @@ pub fn split_dim( } pub fn broadcast_shape(tensors: &[&CubeTensor]) -> Shape { - let rank = tensors[0].shape.num_dims(); + let rank = tensors[0].meta.num_dims(); debug_assert!( - tensors.iter().all(|it| it.shape.num_dims() == rank), + tensors.iter().all(|it| it.meta.num_dims() == rank), "Broadcast tensors must have the same rank" ); let dims = (0..rank).map(|dim| { - let max = tensors.iter().map(|it| it.shape[dim]).max(); + let max = tensors.iter().map(|it| it.meta.shape()[dim]).max(); let max = max.unwrap_or(1); debug_assert!( tensors .iter() - .all(|it| it.shape[dim] == max || it.shape[dim] == 1), + .all(|it| it.meta.shape()[dim] == max || it.meta.shape()[dim] == 1), "Broadcast dims must be size 1" ); max @@ -121,16 +124,29 @@ pub fn broadcast_strides<'a, R: CubeRuntime>( reference: &CubeTensor, tensor: &'a CubeTensor, ) -> SequenceArg<'a, R, usize> { - if reference.shape != tensor.shape { + if reference.meta.shape() != tensor.meta.shape() { tensor - .strides + .meta + .strides() .iter() - .zip(tensor.shape.iter().zip(reference.shape.iter())) + .zip( + tensor + .meta + .shape() + .iter() + .zip(reference.meta.shape().iter()), + ) .map(|(stride, (shape, ref_shape))| if *shape == *ref_shape { *stride } else { 0 }) .map(ScalarArg::new) .collect() } else { - tensor.strides.iter().copied().map(ScalarArg::new).collect() + tensor + .meta + .strides() + .iter() + .copied() + .map(ScalarArg::new) + .collect() } } diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs index fdfeb6be36..c99fb2de98 100644 --- a/crates/burn-cubecl/src/ops/base.rs +++ b/crates/burn-cubecl/src/ops/base.rs @@ -4,7 +4,10 @@ use burn_backend::{ quantization::{QuantLevel, QuantStore, params_shape}, }; use burn_backend::{TensorMetadata, ops::unfold::calculate_unfold_shape}; -use burn_std::tensor::{ReshapeAction, contiguous_strides, reshape_action}; +use burn_std::{ + Metadata, strides, + tensor::{ReshapeAction, contiguous_strides, reshape_action}, +}; use cubecl::{ir::LineSize, server::CopyDescriptor}; use cubecl::{quant::scheme::BlockSize, tensor_line_size_parallel}; @@ -15,9 +18,8 @@ pub(crate) fn from_data(data: TensorData, device: &R::Device) -> CubeTensor::new( client, alloc.handle, - shape, + Metadata::new(shape, alloc.strides), device.clone(), - alloc.strides, data.dtype, ) } @@ -28,8 +30,9 @@ pub(crate) async fn into_data( let tensor = kernel::into_contiguous_aligned(tensor); let elem_size = tensor.elem_size(); - let shape = &tensor.shape; - let binding = CopyDescriptor::new(tensor.handle.binding(), shape, &tensor.strides, elem_size); + let shape = tensor.meta.shape(); + let strides = tensor.meta.strides(); + let binding = CopyDescriptor::new(tensor.handle.binding(), shape, strides, elem_size); let bytes = tensor .client .read_one_tensor_async(binding) @@ -38,7 +41,11 @@ pub(crate) async fn into_data( reason: format!("{err}"), })?; - Ok(TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)) + Ok(TensorData::from_bytes( + bytes, + tensor.meta.shape, + tensor.dtype, + )) } /// Read data from a `CubeTensor` synchronously @@ -75,9 +82,8 @@ pub(crate) fn empty( CubeTensor::new( client, alloc.handle, - shape, + Metadata::new(shape, alloc.strides), device.clone(), - alloc.strides, dtype, ) } @@ -87,8 +93,7 @@ pub(crate) fn swap_dims( dim1: usize, dim2: usize, ) -> CubeTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape = tensor.shape.swap(dim1, dim2).unwrap(); + tensor.meta.swap(dim1, dim2); if let DType::QFloat(scheme) = tensor.dtype && let QuantLevel::Block(block_size) = scheme.level @@ -104,8 +109,7 @@ pub(crate) fn swap_dims( panic!("Swapped block size would exceed max dims"); } - qparams.scales.shape.inner_mut().swap(dim1, dim2); - qparams.scales.strides.swap(dim1, dim2); + qparams.scales.metadata.swap(dim1, dim2); tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::Block(block_size))) } @@ -114,7 +118,7 @@ pub(crate) fn swap_dims( && let QuantStore::PackedU32(packed_dim) | QuantStore::PackedNative(packed_dim) = &mut scheme.store { - let rank = tensor.shape.len(); + let rank = tensor.meta.num_dims(); if *packed_dim == rank - dim1 - 1 { *packed_dim = rank - dim2 - 1; @@ -128,11 +132,7 @@ pub(crate) fn swap_dims( /// Permute a tensor's dimensions pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> CubeTensor { - // remap strides - tensor.strides = axes.iter().map(|i| tensor.strides[*i]).collect(); - - // remap shape - tensor.shape = tensor.shape.permute(axes).unwrap(); + tensor.meta.permute(axes).unwrap(); if let DType::QFloat(scheme) = tensor.dtype && let QuantLevel::Block(block_size) = scheme.level @@ -152,8 +152,7 @@ pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> Cub panic!("Swapped block size would exceed max dims"); } - qparams.scales.strides = axes.iter().map(|i| qparams.scales.strides[*i]).collect(); - qparams.scales.shape = qparams.scales.shape.clone().permute(axes).unwrap(); + qparams.scales.metadata.permute(axes).unwrap(); tensor.dtype = DType::QFloat(scheme.with_level(QuantLevel::block(&block_size))) } @@ -161,7 +160,7 @@ pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> Cub if let DType::QFloat(scheme) = &mut tensor.dtype && let QuantStore::PackedU32(packed_dim) = &mut scheme.store { - let rank = tensor.shape.len(); + let rank = tensor.meta.num_dims(); let new_pos = axes .iter() .position(|axis| *axis == rank - *packed_dim - 1) @@ -174,7 +173,7 @@ pub fn permute(mut tensor: CubeTensor, axes: &[usize]) -> Cub /// Permute a tensor's dimensions from NCHW to NHWC, or the N-dimensional equivalent pub fn permute_nchw_to_nhwc(tensor: CubeTensor) -> CubeTensor { - let rank = tensor.shape.num_dims(); + let rank = tensor.meta.num_dims(); let c_dim = 1; let mut dims = vec![0]; @@ -193,12 +192,12 @@ pub fn permute_nchw_to_nhwc_shape(shape: Shape) -> Shape { dims.extend(2..rank); dims.push(c_dim); - shape.permute(&dims).expect("Shape permute should succeed") + shape.permuted(&dims).expect("Shape permute should succeed") } /// Permute a tensor's dimensions from NHWC to NCHW, or the N-dimensional equivalent pub fn permute_nhwc_to_nchw(tensor: CubeTensor) -> CubeTensor { - let rank = tensor.shape.num_dims(); + let rank = tensor.meta.num_dims(); let c_dim = rank - 1; let mut dims = vec![0]; @@ -217,28 +216,28 @@ pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape { dims.push(c_dim); dims.extend(1..c_dim); - shape.permute(&dims).expect("Shape permute should succeed") + shape.permuted(&dims).expect("Shape permute should succeed") } pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) -> CubeTensor { - let ndims_in = tensor.shape.num_dims(); + let ndims_in = tensor.meta.shape().num_dims(); let ndims_out = target_shape.num_dims(); // Initialize new strides with zeros - let mut new_strides = vec![0usize; ndims_out]; + let mut new_strides = strides![0usize; ndims_out]; // Calculate the difference in dimensions let dim_diff = ndims_out.saturating_sub(ndims_in); // Compare dimensions from the end, setting strides for matching dimensions or broadcasted ones - let mut tensor_dim_iter = tensor.shape.iter().rev(); + let mut tensor_dim_iter = tensor.meta.shape().iter().rev(); for i in (0..ndims_out).rev() { if i >= dim_diff { if let Some(&tensor_dim) = tensor_dim_iter.next() { if tensor_dim == target_shape[i] || tensor_dim == 1 { // Copy stride for non-broadcast dimensions or set to 0 for broadcast ones new_strides[i] = if tensor_dim == target_shape[i] { - tensor.strides[i - dim_diff] + tensor.meta.strides()[i - dim_diff] } else { 0 }; @@ -270,8 +269,7 @@ pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) CubeTensor { client: tensor.client, device: tensor.device, - shape: target_shape, - strides: new_strides, + meta: Box::new(Metadata::new(target_shape, new_strides)), handle: tensor.handle, dtype: tensor.dtype, qparams: tensor.qparams, @@ -280,12 +278,11 @@ pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) /// Reshape a jit tensor to a new shape pub fn reshape(mut tensor: CubeTensor, shape: Shape) -> CubeTensor { - let analysis = reshape_action(&tensor.shape, &tensor.strides, &shape); + let analysis = reshape_action(tensor.meta.shape(), tensor.meta.strides(), &shape); match analysis { ReshapeAction::UpdateStrides { strides } => { - tensor.shape = shape; - tensor.strides = strides; + *tensor.meta = Metadata::new(shape, strides); return tensor; } ReshapeAction::NoChange => return tensor, @@ -323,8 +320,8 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub let shape_scales = params_shape(&shape, scheme.level); let (values, scales) = tensor.quantized_handles().unwrap(); - let analysis_values = reshape_action(&values.shape, &values.strides, &shape_values); - let analysis_scales = reshape_action(&scales.shape, &scales.strides, &shape_scales); + let analysis_values = reshape_action(values.meta.shape(), values.meta.strides(), &shape_values); + let analysis_scales = reshape_action(scales.meta.shape(), scales.meta.strides(), &shape_scales); match (analysis_values, analysis_scales) { ( @@ -335,15 +332,11 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub ) => { let qparams = tensor.qparams.as_mut().unwrap(); - tensor.shape = shape; - tensor.strides = strides; - - qparams.scales.shape = shape_scales; - qparams.scales.strides = scales_strides; + *tensor.meta = Metadata::new(shape, strides); + qparams.scales.metadata = Metadata::new(shape_scales, scales_strides); } (ReshapeAction::UpdateStrides { strides }, ReshapeAction::NoChange) => { - tensor.shape = shape; - tensor.strides = strides; + *tensor.meta = Metadata::new(shape, strides); } ( ReshapeAction::NoChange, @@ -353,19 +346,17 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub ) => { let qparams = tensor.qparams.as_mut().unwrap(); - qparams.scales.shape = shape_scales; - qparams.scales.strides = scales_strides; + qparams.scales.metadata = Metadata::new(shape_scales, scales_strides); } (ReshapeAction::NoChange, ReshapeAction::NoChange) => {} _ => { tensor = kernel::into_contiguous(tensor); - tensor.shape = shape; - tensor.strides = contiguous_strides(&shape_values); + *tensor.meta = Metadata::new(shape, contiguous_strides(&shape_values)); let qparams = tensor.qparams.as_mut().unwrap(); - qparams.scales.strides = contiguous_strides(&shape_scales); - qparams.scales.shape = shape_scales; + let strides = contiguous_strides(&shape_scales); + qparams.scales.metadata = Metadata::new(shape_scales, strides); } } @@ -375,9 +366,9 @@ pub fn q_reshape(mut tensor: CubeTensor, shape: Shape) -> Cub pub(crate) fn max_line_size(tensor: &CubeTensor) -> LineSize { tensor_line_size_parallel( tensor.client.io_optimized_line_sizes(tensor.dtype.size()), - &tensor.shape, - &tensor.strides, - tensor.shape.len() - 1, + tensor.meta.shape(), + tensor.meta.strides(), + tensor.meta.num_dims() - 1, ) } @@ -390,8 +381,8 @@ pub(crate) fn max_line_size_many( .map(|tensor| { tensor_line_size_parallel( tensor.client.io_optimized_line_sizes(tensor.dtype.size()), - &tensor.shape, - &tensor.strides, + tensor.meta.shape(), + tensor.meta.strides(), axis, ) }) @@ -427,16 +418,15 @@ pub fn unfold( size: usize, step: usize, ) -> CubeTensor { - let shape = calculate_unfold_shape(tensor.shape, dim, size, step); + let shape = calculate_unfold_shape(tensor.shape(), dim, size, step); - let d_stride = tensor.strides[dim]; - let mut strides = tensor.strides.clone(); + let d_stride = tensor.meta.strides()[dim]; + let mut strides = tensor.meta.strides.clone(); strides[dim] = step * d_stride; strides.push(d_stride); CubeTensor { - shape, - strides, + meta: Box::new(Metadata::new(shape, strides)), ..tensor } } diff --git a/crates/burn-cubecl/src/ops/bool_tensor.rs b/crates/burn-cubecl/src/ops/bool_tensor.rs index 319948dfbb..fcadd63083 100644 --- a/crates/burn-cubecl/src/ops/bool_tensor.rs +++ b/crates/burn-cubecl/src/ops/bool_tensor.rs @@ -69,7 +69,7 @@ where let simple_ranges: Vec> = slices .iter() .enumerate() - .map(|(i, slice)| slice.to_range(tensor.shape[i])) + .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) @@ -112,8 +112,7 @@ where } fn bool_swap_dims(mut tensor: BoolTensor, dim1: usize, dim2: usize) -> BoolTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape = tensor.shape.swap(dim1, dim2).unwrap(); + tensor.meta.swap(dim1, dim2); tensor } diff --git a/crates/burn-cubecl/src/ops/int_tensor.rs b/crates/burn-cubecl/src/ops/int_tensor.rs index ef16fa5bdc..0a04114f61 100644 --- a/crates/burn-cubecl/src/ops/int_tensor.rs +++ b/crates/burn-cubecl/src/ops/int_tensor.rs @@ -76,7 +76,7 @@ where let simple_ranges: Vec> = slices .iter() .enumerate() - .map(|(i, slice)| slice.to_range(tensor.shape[i])) + .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) @@ -439,8 +439,7 @@ where } fn int_swap_dims(mut tensor: IntTensor, dim1: usize, dim2: usize) -> IntTensor { - tensor.strides.swap(dim1, dim2); - tensor.shape = tensor.shape.swap(dim1, dim2).unwrap(); + tensor.meta.swap(dim1, dim2); tensor } diff --git a/crates/burn-cubecl/src/ops/module.rs b/crates/burn-cubecl/src/ops/module.rs index abc92c90cb..f6915d914d 100644 --- a/crates/burn-cubecl/src/ops/module.rs +++ b/crates/burn-cubecl/src/ops/module.rs @@ -3,11 +3,14 @@ use crate::{ element::BoolElement, kernel::{self, conv::ConvTranspose2dStrategy}, }; -use burn_backend::ops::{ - AttentionOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, - InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, -}; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor}; +use burn_backend::{ + TensorMetadata, + ops::{ + AttentionOptions, ConvOptions, ConvTransposeOptions, DeformConv2dBackward, + DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + }, +}; impl ModuleOps for CubeBackend where @@ -31,8 +34,14 @@ where output_grad: FloatTensor, options: ConvOptions<1>, ) -> FloatTensor { - kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default()) - .unwrap() + kernel::conv::conv_data_backward( + output_grad, + weight, + x.shape(), + options, + Default::default(), + ) + .unwrap() } fn conv1d_weight_backward( @@ -44,7 +53,7 @@ where kernel::conv::conv_weight_backward::( x, output_grad, - weight.shape.clone(), + weight.shape(), options, Default::default(), ) @@ -66,8 +75,14 @@ where output_grad: FloatTensor, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default()) - .unwrap() + kernel::conv::conv_data_backward( + output_grad, + weight, + x.shape(), + options, + Default::default(), + ) + .unwrap() } fn conv2d_weight_backward( @@ -79,7 +94,7 @@ where kernel::conv::conv_weight_backward::( x, output_grad, - weight.shape.clone(), + weight.shape(), options, Default::default(), ) @@ -134,8 +149,14 @@ where output_grad: FloatTensor, options: ConvOptions<3>, ) -> FloatTensor { - kernel::conv::conv_data_backward(output_grad, weight, x.shape, options, Default::default()) - .unwrap() + kernel::conv::conv_data_backward( + output_grad, + weight, + x.shape(), + options, + Default::default(), + ) + .unwrap() } fn conv3d_weight_backward( @@ -147,7 +168,7 @@ where kernel::conv::conv_weight_backward::( x, output_grad, - weight.shape.clone(), + weight.shape(), options, Default::default(), ) diff --git a/crates/burn-cubecl/src/ops/numeric.rs b/crates/burn-cubecl/src/ops/numeric.rs index 6f0693ac8b..b10a32de28 100644 --- a/crates/burn-cubecl/src/ops/numeric.rs +++ b/crates/burn-cubecl/src/ops/numeric.rs @@ -10,7 +10,8 @@ use crate::{ }, ops::max_line_size, }; -use burn_backend::{DType, Shape}; +use burn_backend::{DType, Shape, TensorMetadata}; +use burn_std::Metadata; use cubecl::{calculate_cube_count_elemwise, prelude::*}; use cubecl::{client::ComputeClient, server::Allocation}; use cubecl::{ @@ -63,7 +64,7 @@ pub fn full_device_dtype( tensor[ABSOLUTE_POS] = value.get::(); } - let num_elems = empty.shape.num_elements(); + let num_elems = empty.meta.num_elements(); let line_size = max_line_size(&empty); let working_units = num_elems / line_size as usize; @@ -126,7 +127,13 @@ pub fn empty_device( ) -> CubeTensor { let Allocation { handle, strides } = client.empty_tensor(&shape, size_of::()); - CubeTensor::new(client, handle, shape, device, strides, E::dtype()) + CubeTensor::new( + client, + handle, + Metadata::new(shape, strides), + device, + E::dtype(), + ) } /// Create a tensor with uninitialized memory @@ -138,7 +145,7 @@ pub fn empty_device_dtype( ) -> CubeTensor { let Allocation { handle, strides } = client.empty_tensor(&shape, dtype.size()); - CubeTensor::new(client, handle, shape, device, strides, dtype) + CubeTensor::new(client, handle, Metadata::new(shape, strides), device, dtype) } /// Create a contiguous tensor with uninitialized memory @@ -151,7 +158,7 @@ pub fn empty_device_contiguous_dtype( let descriptor = AllocationDescriptor::contiguous(&shape, dtype.size()); let Allocation { handle, strides } = client.empty_tensors(vec![descriptor]).remove(0); - CubeTensor::new(client, handle, shape, device, strides, dtype) + CubeTensor::new(client, handle, Metadata::new(shape, strides), device, dtype) } /// Add two tensors @@ -409,9 +416,9 @@ fn cumulative_op( let client = input.client.clone(); let device = input.device.clone(); - let output = empty_device_dtype(client.clone(), device, input.shape.clone(), input.dtype); + let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype); - let num_elems = output.shape.num_elements(); + let num_elems = output.meta.num_elements(); let working_units = num_elems; let cube_dim = CubeDim::new(&client, working_units); let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim); diff --git a/crates/burn-cubecl/src/ops/qtensor.rs b/crates/burn-cubecl/src/ops/qtensor.rs index 713f9bae9f..70e5a7c266 100644 --- a/crates/burn-cubecl/src/ops/qtensor.rs +++ b/crates/burn-cubecl/src/ops/qtensor.rs @@ -1,5 +1,6 @@ use burn_backend::{ - Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive, + Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorMetadata, + TensorPrimitive, ops::QTensorOps, quantization::{ QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue, @@ -7,6 +8,7 @@ use burn_backend::{ }, tensor::{Device, FloatElem, FloatTensor, IntTensor, QuantizedTensor}, }; +use burn_std::Metadata; use cubecl::server::{Allocation, AllocationDescriptor, AllocationKind}; use cubecl::{e2m1x2, quant::scheme::QuantStore}; @@ -136,8 +138,7 @@ fn new_quantized( let scales = QParamTensor { offset_start: scales_handle.offset_start.unwrap_or(0) as usize, offset_end: scales_handle.offset_end.unwrap_or(0) as usize, - shape: scales_shape, - strides: scales_strides, + metadata: Metadata::new(scales_shape, scales_strides), dtype: scales_dtype, }; let qparams = QParams { scales }; @@ -221,7 +222,7 @@ where return into_data(tensor).await; } - let (shape, dtype) = (tensor.shape.clone(), tensor.dtype); + let (shape, dtype) = (tensor.shape(), tensor.dtype); let (values, params) = tensor.quantized_handles().unwrap(); let mut data_values = into_data(values).await?; diff --git a/crates/burn-cubecl/src/ops/tensor.rs b/crates/burn-cubecl/src/ops/tensor.rs index b9dccbb44d..3307d36056 100644 --- a/crates/burn-cubecl/src/ops/tensor.rs +++ b/crates/burn-cubecl/src/ops/tensor.rs @@ -59,7 +59,7 @@ where #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), - fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype) + fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype) ))] async fn float_into_data(tensor: FloatTensor) -> Result { super::into_data(tensor).await @@ -72,7 +72,7 @@ where #[cfg_attr(feature = "tracing", tracing::instrument( level="trace", skip(tensor), - fields(from = ?tensor.device, shape = ?tensor.shape, dtype = ?tensor.dtype) + fields(from = ?tensor.device, meta = ?tensor.meta, dtype = ?tensor.dtype) ))] fn float_to_device(tensor: FloatTensor, device: &Device) -> FloatTensor { super::to_device(tensor, device) @@ -219,7 +219,7 @@ where let simple_ranges: Vec> = slices .iter() .enumerate() - .map(|(i, slice)| slice.to_range(tensor.shape[i])) + .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i])) .collect(); kernel::slice(tensor, &simple_ranges) diff --git a/crates/burn-cubecl/src/ops/transaction.rs b/crates/burn-cubecl/src/ops/transaction.rs index e1a356c46e..6326802f49 100644 --- a/crates/burn-cubecl/src/ops/transaction.rs +++ b/crates/burn-cubecl/src/ops/transaction.rs @@ -3,6 +3,7 @@ use burn_backend::{ backend::ExecutionError, ops::{TransactionOps, TransactionPrimitive, TransactionPrimitiveData}, }; +use burn_std::{Shape, Strides}; use cubecl::server::{Binding, CopyDescriptor}; use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement}; @@ -30,8 +31,8 @@ where index: usize, kind: Kind, handle: Option, - shape: Vec, - strides: Vec, + shape: Shape, + strides: Strides, dtype: DType, } @@ -49,8 +50,8 @@ where num_bindings, Kind::Float, Some(t.handle.binding()), - t.shape.into(), - t.strides, + t.meta.shape, + t.meta.strides, t.dtype, ); @@ -67,8 +68,8 @@ where num_bindings, Kind::Int, Some(t.handle.binding()), - t.shape.into(), - t.strides, + t.meta.shape, + t.meta.strides, t.dtype, ); @@ -85,8 +86,8 @@ where num_bindings, Kind::Bool, Some(t.handle.binding()), - t.shape.into(), - t.strides, + t.meta.shape, + t.meta.strides, t.dtype, ); diff --git a/crates/burn-cubecl/src/template/base.rs b/crates/burn-cubecl/src/template/base.rs index bedb44c643..1ee2f9b64e 100644 --- a/crates/burn-cubecl/src/template/base.rs +++ b/crates/burn-cubecl/src/template/base.rs @@ -82,20 +82,20 @@ macro_rules! kernel_source { /// | (2 * D + 1)..(3 * D + 1) | lhs shape | /// | (3 * D + 1)..(4 * D + 1) | rhs shape | pub fn build_info(tensors: &[&CubeTensor]) -> Vec { - let ndims = tensors[0].shape.num_dims(); + let ndims = tensors[0].meta.num_dims(); let mut info: Vec = vec![0; tensors.len() * 2 * ndims + 1]; info[0] = ndims as u32; let mut current = 1; for tensor in tensors.iter() { for d in 0..ndims { - info[current] = tensor.strides[d] as u32; + info[current] = tensor.meta.strides()[d] as u32; current += 1; } } for tensor in tensors.iter() { for d in 0..ndims { - info[current] = tensor.shape[d] as u32; + info[current] = tensor.meta.shape()[d] as u32; current += 1; } } diff --git a/crates/burn-cubecl/src/tensor/base.rs b/crates/burn-cubecl/src/tensor/base.rs index cf584bcf1f..118b5de2bb 100644 --- a/crates/burn-cubecl/src/tensor/base.rs +++ b/crates/burn-cubecl/src/tensor/base.rs @@ -3,7 +3,7 @@ use crate::element::CubeElement; use crate::kernel::{NumericUnaryOp, NumericUnaryOpFamily, launch_unary_numeric}; use burn_backend::quantization::QuantScheme; use burn_backend::{DType, QTensorPrimitive, Shape, TensorMetadata}; -use burn_std::tensor::is_contiguous; +use burn_std::{Metadata, strides, tensor::is_contiguous}; use cubecl::client::ComputeClient; use cubecl::frontend::Numeric; use cubecl::prelude::{TensorHandleRef, *}; @@ -19,12 +19,10 @@ pub struct CubeTensor { pub client: ComputeClient, /// The buffer where the data are stored. pub handle: Handle, - /// The shape of the tensor. - pub shape: Shape, + /// The metadata of the tensor. + pub meta: Box, /// The device of the tensor. pub device: R::Device, - /// The strides of the tensor. - pub strides: Vec, /// The datatype of the tensor. pub dtype: DType, /// Runtime quantization parameters, if applicable @@ -35,8 +33,8 @@ impl From> for TensorHandle { fn from(val: CubeTensor) -> Self { TensorHandle::new( val.handle, - val.shape.to_vec(), - val.strides.to_vec(), + val.meta.shape().clone(), + val.meta.strides().clone(), val.dtype.into(), ) } @@ -61,9 +59,9 @@ where fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!( "CubeTensor {{ shape: {:?}, device: {:?}, strides: {:?}, elem: {}, runtime: {}}}", - self.shape, + self.meta.shape(), self.device, - self.strides, + self.meta.strides(), self.dtype.name(), R::name(&self.client), )) @@ -78,9 +76,8 @@ where Self { client: self.client.clone(), handle: self.handle.clone(), - shape: self.shape.clone(), + meta: self.meta.clone(), device: self.device.clone(), - strides: self.strides.clone(), dtype: self.dtype, qparams: self.qparams.clone(), } @@ -93,11 +90,11 @@ impl TensorMetadata for CubeTensor { } fn shape(&self) -> Shape { - self.shape.clone() + self.meta.shape().clone() } fn rank(&self) -> usize { - self.shape.num_dims() + self.meta.rank() } } @@ -122,17 +119,15 @@ where pub fn new( client: ComputeClient, handle: Handle, - shape: Shape, + metadata: Metadata, device: R::Device, - strides: Vec, dtype: DType, ) -> Self { CubeTensor { client, handle, - shape, + meta: Box::new(metadata), device, - strides, dtype, qparams: None, } @@ -147,7 +142,7 @@ where dtype: DType, ) -> Self { let ndims = shape.num_dims(); - let mut strides = vec![0; ndims]; + let mut strides = strides![0; ndims]; let mut current = 1; shape.iter().enumerate().rev().for_each(|(index, val)| { @@ -158,8 +153,7 @@ where Self { client, handle, - shape, - strides, + meta: Box::new(Metadata::new(shape, strides)), device, dtype, qparams: None, @@ -168,17 +162,16 @@ where /// Change the context of the current tensor and return the newly transferred tensor. pub fn to_client(&self, client: ComputeClient, device: R::Device) -> Self { - let desc = self - .handle - .copy_descriptor(&self.shape, &self.strides, self.elem_size()); + let desc = + self.handle + .copy_descriptor(self.meta.shape(), self.meta.strides(), self.elem_size()); let alloc = self.client.to_client_tensor(desc, &client); Self { client, handle: alloc.handle, - shape: self.shape.clone(), + meta: Box::new(Metadata::new(self.shape(), alloc.strides)), device, - strides: alloc.strides, dtype: self.dtype, qparams: self.qparams.clone(), } @@ -188,8 +181,8 @@ where pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> { TensorHandleRef { handle: &self.handle, - strides: &self.strides, - shape: &self.shape, + strides: self.meta.strides(), + shape: self.meta.shape(), runtime: PhantomData, elem_size: self.elem_size(), } @@ -250,11 +243,11 @@ where if !self.handle.can_mut() || !self.is_nonoverlapping() { return false; } - let ndims = self.shape.num_dims(); + let ndims = self.meta.num_dims(); for i in 0..ndims { - let shape_lhs = self.shape[i]; - let shape_rhs = rhs.shape[i]; + let shape_lhs = self.meta.shape()[i]; + let shape_rhs = rhs.meta.shape()[i]; // Output tensor will be different from the mutable tensor. if shape_lhs < shape_rhs { @@ -309,23 +302,26 @@ where /// strides at position k is equal to the product of the shapes /// at all positions greater than k. However, all axes with a shape of 1 are ignored. pub fn is_contiguous(&self) -> bool { - is_contiguous(&self.shape, &self.strides) + is_contiguous(self.meta.shape(), self.meta.strides()) } /// Check if the current tensor has a contiguous backing buffer (no overlap and no empty memory /// regions within the shape). pub fn is_contiguous_buffer(&self) -> bool { - self.shape.num_elements() * self.dtype.size() == self.handle.size() as usize + self.meta.shape().num_elements() * self.dtype.size() == self.handle.size() as usize } /// Checks if the tensor is non-overlapping (can be safely written to). pub fn is_nonoverlapping(&self) -> bool { - if self.strides.contains(&0) { + let shape = self.meta.shape(); + let strides = self.meta.strides(); + + if strides.contains(&0) { return false; } let rank = self.rank(); if rank > 1 { - let mut dims = self.shape.iter().zip(&self.strides).collect::>(); + let mut dims = shape.iter().zip(strides.iter()).collect::>(); dims.sort_by_key(|(_, stride)| **stride); let mut max_offset = 0; diff --git a/crates/burn-cubecl/src/tensor/quantization.rs b/crates/burn-cubecl/src/tensor/quantization.rs index 5f0c586e3e..c77b3ba6b8 100644 --- a/crates/burn-cubecl/src/tensor/quantization.rs +++ b/crates/burn-cubecl/src/tensor/quantization.rs @@ -1,4 +1,5 @@ use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor}; +use burn_std::{Metadata, Strides}; use cubecl::quant::scheme::{QuantStore, QuantValue}; use cubecl::{client::ComputeClient, server::Handle}; @@ -17,16 +18,15 @@ impl CubeTensor { handle: Handle, shape: Shape, device: R::Device, - strides: Vec, + strides: Strides, dtype: DType, qparams: QParams, ) -> Self { CubeTensor { client, handle, - shape, + meta: Box::new(Metadata::new(shape, strides)), device, - strides, dtype, qparams: Some(qparams), } @@ -47,18 +47,16 @@ impl CubeTensor { QuantValue::Q8F | QuantValue::Q8S => CubeTensor { client: self.client.clone(), handle: self.handle.clone(), - shape: self.shape.clone(), + meta: self.meta.clone(), device: self.device.clone(), - strides: self.strides.clone(), dtype: DType::I8, qparams: None, }, QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor { client: self.client.clone(), handle: self.handle.clone(), - shape: self.shape.clone(), + meta: self.meta.clone(), device: self.device.clone(), - strides: self.strides.clone(), dtype: DType::U8, qparams: None, }, @@ -72,15 +70,14 @@ impl CubeTensor { }, QuantStore::PackedU32(packed_dim) => { let packed_dim = self.rank() - packed_dim - 1; - let mut shape = self.shape.clone(); + let mut shape = self.shape(); shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants()); CubeTensor { client: self.client.clone(), handle: self.handle.clone(), - shape, + meta: Box::new(Metadata::new(shape, self.meta.strides.clone())), device: self.device.clone(), - strides: self.strides.clone(), dtype: DType::U32, qparams: None, } @@ -88,15 +85,14 @@ impl CubeTensor { QuantStore::PackedNative(packed_dim) => match scheme.value { QuantValue::E2M1 => { let packed_dim = self.rank() - packed_dim - 1; - let mut shape = self.shape.clone(); + let mut shape = self.shape(); shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants()); CubeTensor { client: self.client.clone(), handle: self.handle.clone(), - shape, + meta: Box::new(Metadata::new(shape, self.meta.strides.clone())), device: self.device.clone(), - strides: self.strides.clone(), dtype: DType::U8, qparams: None, } @@ -118,9 +114,8 @@ impl CubeTensor { Some(CubeTensor::new( self.client.clone(), handle, - qparams.scales.shape.clone(), + qparams.scales.metadata.clone(), self.device.clone(), - qparams.scales.strides.clone(), qparams.scales.dtype, )) } diff --git a/crates/burn-ir/src/builder.rs b/crates/burn-ir/src/builder.rs index 4f9f09eebd..1bdd1a3673 100644 --- a/crates/burn-ir/src/builder.rs +++ b/crates/burn-ir/src/builder.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; use burn_backend::{ - DType, Distribution, Shape, Slice, calculate_matmul_output, + DType, Distribution, Shape, Slice, SliceOps, calculate_matmul_output, ops::{ conv::{ calculate_conv_output_shape, calculate_conv_transpose_output_shape, @@ -242,13 +242,13 @@ impl_ir_create!( dim1: usize, dim2: usize }, - shape = input.shape.clone().swap(dim1, dim2).unwrap(), + shape = input.shape.clone().swapped(dim1, dim2).unwrap(), dtype = input.dtype ); impl_ir_create!( PermuteOpIr { input: TensorIr, axes: Vec }, - shape = input.shape.clone().permute(&axes).unwrap(), + shape = input.shape.clone().permuted(&axes).unwrap(), dtype = input.dtype ); diff --git a/crates/burn-std/Cargo.toml b/crates/burn-std/Cargo.toml index 165558616e..ba5ff9a692 100644 --- a/crates/burn-std/Cargo.toml +++ b/crates/burn-std/Cargo.toml @@ -35,6 +35,7 @@ cubecl-common = { workspace = true, default-features = false, features = [ "serde", "shared-bytes", ] } +cubecl-zspace = { workspace = true, default-features = false } # Enable extra-platforms for portable-atomic support on targets without native atomics (e.g., thumbv6m) # This is needed because cubecl-common's shared-bytes feature pulls in bytes bytes = { workspace = true } diff --git a/crates/burn-std/src/errors.rs b/crates/burn-std/src/errors.rs deleted file mode 100644 index 1aaa460b8f..0000000000 --- a/crates/burn-std/src/errors.rs +++ /dev/null @@ -1,189 +0,0 @@ -//! # Common Burn Errors - -use alloc::string::String; -use core::ops::Range; - -/// Describes the kind of an index. -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub enum IndexKind { - /// The index of an element in a dimension. - Element, - - /// The index of a dimension. - Dimension, -} - -impl IndexKind { - /// Get the display name of the kind. - pub fn name(&self) -> &'static str { - match self { - IndexKind::Element => "element", - IndexKind::Dimension => "dimension", - } - } -} - -/// Access Bounds Error. -#[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub enum BoundsError { - /// Generic bounds error. - Generic(String), - - /// Index out of bounds. - Index { - /// The kind of index that was out of bounds. - kind: IndexKind, - - /// The index that was out of bounds. - index: isize, - - /// The range of valid indices. - bounds: Range, - }, -} - -impl BoundsError { - /// Create a new index error. - pub fn index(kind: IndexKind, index: isize, bounds: Range) -> Self { - Self::Index { - kind, - index, - bounds, - } - } -} - -impl core::fmt::Display for BoundsError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::Generic(msg) => write!(f, "BoundsError: {}", msg), - Self::Index { - kind, - index, - bounds: range, - } => write!( - f, - "BoundsError: {} {} out of bounds: {:?}", - kind.name(), - index, - range - ), - } - } -} - -impl core::error::Error for BoundsError {} - -/// Common Expression Error. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExpressionError { - /// Parse Error. - ParseError { - /// The error message. - message: String, - /// The source expression. - source: String, - }, - - /// Invalid Expression. - InvalidExpression { - /// The error message. - message: String, - /// The source expression. - source: String, - }, -} - -impl core::fmt::Display for ExpressionError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - Self::ParseError { message, source } => { - write!(f, "ExpressionError: ParseError: {} ({})", message, source) - } - Self::InvalidExpression { message, source } => write!( - f, - "ExpressionError: InvalidExpression: {} ({})", - message, source - ), - } - } -} - -impl core::error::Error for ExpressionError {} - -impl ExpressionError { - /// Constructs a new [`ExpressionError::ParseError`]. - /// - /// This function is a utility for creating instances where a parsing error needs to be represented, - /// encapsulating a descriptive error message and the source of the error. - /// - /// # Parameters - /// - /// - `message`: A value that can be converted into a `String`, representing a human-readable description - /// of the parsing error. - /// - `source`: A value that can be converted into a `String`, typically identifying the origin or - /// input that caused the parsing error. - pub fn parse_error(message: impl Into, source: impl Into) -> Self { - Self::ParseError { - message: message.into(), - source: source.into(), - } - } - - /// Creates a new [`ExpressionError::InvalidExpression`]. - /// - /// # Parameters - /// - `message`: A detailed message describing the nature of the invalid expression. - /// Accepts any type that can be converted into a `String`. - /// - `source`: The source or context in which the invalid expression occurred. - /// Accepts any type that can be converted into a `String`. - pub fn invalid_expression(message: impl Into, source: impl Into) -> Self { - Self::InvalidExpression { - message: message.into(), - source: source.into(), - } - } -} -#[cfg(test)] -mod tests { - use super::*; - use alloc::format; - use alloc::string::ToString; - - #[test] - fn test_bounds_error_display() { - assert_eq!( - format!("{}", BoundsError::Generic("test".to_string())), - "BoundsError: test" - ); - assert_eq!( - format!( - "{}", - BoundsError::Index { - kind: IndexKind::Element, - index: 1, - bounds: 0..2 - } - ), - "BoundsError: element 1 out of bounds: 0..2" - ); - } - - #[test] - fn test_parse_error() { - let err = ExpressionError::parse_error("test", "source"); - assert_eq!( - format!("{:?}", err), - "ParseError { message: \"test\", source: \"source\" }" - ); - } - - #[test] - fn test_invalid_expression() { - let err = ExpressionError::invalid_expression("test", "source"); - assert_eq!( - format!("{:?}", err), - "InvalidExpression { message: \"test\", source: \"source\" }" - ); - } -} diff --git a/crates/burn-std/src/lib.rs b/crates/burn-std/src/lib.rs index fda2d1484b..e16bce84fa 100644 --- a/crates/burn-std/src/lib.rs +++ b/crates/burn-std/src/lib.rs @@ -17,8 +17,7 @@ pub mod tensor; pub use tensor::*; /// Common Errors. -pub mod errors; -pub use errors::*; +pub use cubecl_zspace::errors::{self, *}; /// Network utilities. #[cfg(feature = "network")] diff --git a/crates/burn-std/src/tensor/index_conversion.rs b/crates/burn-std/src/tensor/index_conversion.rs deleted file mode 100644 index a4e7c15b3b..0000000000 --- a/crates/burn-std/src/tensor/index_conversion.rs +++ /dev/null @@ -1,159 +0,0 @@ -//! # Common Index Coercions -//! -//! This module contains common index coercions that can be used to implement -//! various indexing operations. - -use super::indexing::IndexWrap; -use core::fmt::Debug; - -/// Types which can be converted to a `usize` Size. -pub trait AsSize: Debug + Copy + Sized { - /// Convert to a `usize` Size. - fn as_size(self) -> usize; -} - -impl AsSize for &T -where - T: AsSize, -{ - fn as_size(self) -> usize { - (*self).as_size() - } -} - -macro_rules! gen_as_size { - ($ty:ty) => { - impl AsSize for $ty { - fn as_size(self) -> usize { - self.try_into() - .unwrap_or_else(|_| panic!( - "Unable to convert value to usize: {}_{}", - self, - stringify!($ty))) - } - } - }; - ($($ty:ty),*) => {$(gen_as_size!($ty);)*}; -} - -gen_as_size!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8); - -/// Helper trait for implementing indexing with support for negative indices. -/// -/// # Example -/// ```rust -/// use burn_std::AsIndex; -/// -/// fn example(dim: I, size: usize) -> isize { -/// let dim: usize = dim.expect_dim_index(D); -/// unimplemented!() -/// } -/// ``` -pub trait AsIndex: Debug + Copy + Sized { - /// Converts into an `isize` index. - fn as_index(self) -> isize; - - /// Short-form [`IndexWrap::expect_index(idx, size)`]. - fn expect_elem_index(self, size: usize) -> usize { - IndexWrap::expect_elem(self, size) - } - - /// Short-form [`IndexWrap::expect_dim(idx, size)`]. - fn expect_dim_index(self, size: usize) -> usize { - IndexWrap::expect_dim(self, size) - } -} - -impl AsIndex for &T -where - T: AsIndex, -{ - fn as_index(self) -> isize { - (*self).as_index() - } -} - -macro_rules! gen_as_index { - ($ty:ty) => { - impl AsIndex for $ty { - fn as_index(self) -> isize { - self as isize - } - } - }; - ($($ty:ty),*) => {$(gen_as_index!($ty);)*}; -} - -gen_as_index!(usize, isize, i64, u64, i32, u32, i16, u16, i8, u8); - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_as_size() { - assert_eq!(1_usize.as_size(), 1_usize); - assert_eq!(1_isize.as_size(), 1_usize); - assert_eq!(1_i64.as_size(), 1_usize); - assert_eq!(1_u64.as_size(), 1_usize); - assert_eq!(1_i32.as_size(), 1_usize); - assert_eq!(1_u32.as_size(), 1_usize); - assert_eq!(1_i16.as_size(), 1_usize); - assert_eq!(1_u16.as_size(), 1_usize); - assert_eq!(1_i8.as_size(), 1_usize); - assert_eq!(1_u8.as_size(), 1_usize); - - assert_eq!((&1_usize).as_size(), 1_usize); - } - - #[test] - #[should_panic(expected = "Unable to convert value to usize: -1_isize")] - fn test_as_size_isize_panic() { - (-1_isize).as_size(); - } - #[test] - #[should_panic(expected = "Unable to convert value to usize: -1_i64")] - fn test_as_size_i64() { - (-1_i64).as_size(); - } - - #[test] - #[should_panic(expected = "Unable to convert value to usize: -1_i32")] - fn test_as_size_i32() { - (-1_i32).as_size(); - } - - #[test] - #[should_panic(expected = "Unable to convert value to usize: -1_i16")] - fn test_as_size_i16() { - (-1_i16).as_size(); - } - - #[test] - #[should_panic(expected = "Unable to convert value to usize: -1_i8")] - fn test_as_size_i8() { - (-1_i8).as_size(); - } - - #[test] - fn test_as_index() { - assert_eq!(1_usize.as_index(), 1_isize); - assert_eq!(1_isize.as_index(), 1_isize); - assert_eq!(1_i64.as_index(), 1_isize); - assert_eq!(1_u64.as_index(), 1_isize); - assert_eq!(1_i32.as_index(), 1_isize); - assert_eq!(1_u32.as_index(), 1_isize); - assert_eq!(1_i16.as_index(), 1_isize); - assert_eq!(1_u16.as_index(), 1_isize); - assert_eq!(1_i8.as_index(), 1_isize); - assert_eq!(1_u8.as_index(), 1_isize); - - assert_eq!((&1_usize).as_index(), 1_isize); - - assert_eq!(-1_isize.as_index(), -1_isize); - assert_eq!(-1_i64.as_index(), -1_isize); - assert_eq!(-1_i32.as_index(), -1_isize); - assert_eq!(-1_i16.as_index(), -1_isize); - assert_eq!(-1_i8.as_index(), -1_isize); - } -} diff --git a/crates/burn-std/src/tensor/indexing.rs b/crates/burn-std/src/tensor/indexing.rs deleted file mode 100644 index 187f4ae85d..0000000000 --- a/crates/burn-std/src/tensor/indexing.rs +++ /dev/null @@ -1,321 +0,0 @@ -//! A module for indexing utility machinery. - -use crate::IndexKind; -pub use crate::errors::BoundsError; -#[allow(unused_imports)] -use alloc::format; -#[allow(unused_imports)] -use alloc::string::{String, ToString}; -use core::fmt::Debug; - -pub use crate::tensor::index_conversion::AsIndex; - -/// Wraps an index with negative indexing support. -#[derive(Debug)] -pub struct IndexWrap { - kind: IndexKind, - wrap_scalar: bool, -} - -impl IndexWrap { - /// Get an instance for wrapping negative indices. - pub fn index() -> Self { - Self { - kind: IndexKind::Element, - wrap_scalar: false, - } - } - - /// Get an instance for wrapping negative dimensions. - pub fn dim() -> Self { - Self { - kind: IndexKind::Dimension, - wrap_scalar: false, - } - } - - /// Set the policy for wrapping 0-size ranges. - /// - /// When ``size`` == 0: - /// - if `wrap_scalar`; then ``size == 1`` - /// - otherwise; an error. - pub fn with_wrap_scalar(self, wrap_scalar: bool) -> Self { - Self { - wrap_scalar, - ..self - } - } - - /// Wrap an index with negative indexing support. - pub fn try_wrap(&self, idx: I, size: usize) -> Result { - try_wrap(idx, size, self.kind, self.wrap_scalar) - } - - /// Wrap an index with negative indexing support. - pub fn expect_wrap(&self, idx: I, size: usize) -> usize { - expect_wrap(idx, size, self.kind, self.wrap_scalar) - } - - /// Short-form [`NegativeWrap::index().expect_wrap(idx, size)`]. - pub fn expect_elem(idx: I, size: usize) -> usize { - Self::index().expect_wrap(idx, size) - } - - /// Short-form [`NegativeWrap::dim().expect_wrap(idx, size)`]. - pub fn expect_dim(idx: I, size: usize) -> usize { - Self::dim().expect_wrap(idx, size) - } -} - -/// Wraps an index with negative indexing support. -/// -/// ## Arguments -/// - `idx` - The index to canonicalize. -/// - `size` - The size of the index range. -/// - `kind` - The kind of index (for error messages). -/// - `size_name` - The name of the size (for error messages). -/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1. -/// -/// ## Returns -/// -/// A `Result` of the canonicalized index. -pub fn expect_wrap(idx: I, size: usize, kind: IndexKind, wrap_scalar: bool) -> usize -where - I: AsIndex, -{ - try_wrap(idx, size, kind, wrap_scalar).expect("valid index") -} - -/// Wraps an index with negative indexing support. -/// -/// ## Arguments -/// - `idx` - The index to canonicalize. -/// - `size` - The size of the index range. -/// - `kind` - The kind of index (for error messages). -/// - `size_name` - The name of the size (for error messages). -/// - `wrap_scalar` - If true, treat 0-size ranges as having size 1. -/// -/// ## Returns -/// -/// A `Result` of the canonicalized index. -pub fn try_wrap( - idx: I, - size: usize, - kind: IndexKind, - wrap_scalar: bool, -) -> Result -where - I: AsIndex, -{ - let source_idx = idx.as_index(); - let source_size = size; - - let size = if source_size > 0 { - source_size - } else { - if !wrap_scalar { - return Err(BoundsError::index(kind, source_idx, 0..0)); - } - 1 - }; - - if source_idx >= 0 && (source_idx as usize) < size { - return Ok(source_idx as usize); - } - - let _idx = if source_idx < 0 { - source_idx + size as isize - } else { - source_idx - }; - - if _idx < 0 || (_idx as usize) >= size { - let rank = size as isize; - - return Err(BoundsError::index(kind, source_idx, 0..rank)); - } - - Ok(_idx as usize) -} - -/// Wraps a dimension index to be within the bounds of the dimension size. -/// -/// ## Arguments -/// -/// * `idx` - The dimension index to wrap. -/// * `size` - The size of the dimension. -/// -/// ## Returns -/// -/// The positive wrapped dimension index. -#[inline] -#[must_use] -pub fn wrap_index(idx: I, size: usize) -> usize -where - I: AsIndex, -{ - if size == 0 { - return 0; // Avoid modulo by zero - } - let wrapped = idx.as_index().rem_euclid(size as isize); - if wrapped < 0 { - (wrapped + size as isize) as usize - } else { - wrapped as usize - } -} - -/// Compute the ravel index for the given coordinates. -/// -/// This returns the row-major order raveling: -/// * `strides[-1] = 1` -/// * `strides[i] = strides[i+1] * dims[i+1]` -/// * `dim_strides = coords * strides` -/// * `ravel = sum(dim_strides)` -/// -/// # Arguments -/// - `indices`: the index for each dimension; must be the same length as `shape`. -/// - `shape`: the shape of each dimension; be the same length as `indices`. -/// -/// # Returns -/// - the ravel offset index. -pub fn ravel_index(indices: &[I], shape: &[usize]) -> usize { - assert_eq!( - shape.len(), - indices.len(), - "Coordinate rank mismatch: expected {}, got {}", - shape.len(), - indices.len(), - ); - - let mut ravel_idx = 0; - let mut stride = 1; - - for (i, &dim) in shape.iter().enumerate().rev() { - let idx = indices[i]; - let coord = IndexWrap::index().expect_wrap(idx, dim); - ravel_idx += coord * stride; - stride *= dim; - } - - ravel_idx -} - -#[cfg(test)] -#[allow(clippy::identity_op, reason = "useful for clarity")] -mod tests { - use super::*; - use alloc::vec; - - #[test] - fn test_ravel() { - let shape = vec![2, 3, 4, 5]; - - assert_eq!(ravel_index(&[0, 0, 0, 0], &shape), 0); - assert_eq!( - ravel_index(&[1, 2, 3, 4], &shape), - 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4 - ); - } - - #[test] - fn test_wrap_idx() { - assert_eq!(wrap_index(0, 3), 0_usize); - assert_eq!(wrap_index(3, 3), 0_usize); - assert_eq!(wrap_index(2 * 3, 3), 0_usize); - assert_eq!(wrap_index(0 - 3, 3), 0_usize); - assert_eq!(wrap_index(0 - 2 * 3, 3), 0_usize); - - assert_eq!(wrap_index(1, 3), 1_usize); - assert_eq!(wrap_index(1 + 3, 3), 1_usize); - assert_eq!(wrap_index(1 + 2 * 3, 3), 1_usize); - assert_eq!(wrap_index(1 - 3, 3), 1_usize); - assert_eq!(wrap_index(1 - 2 * 3, 3), 1_usize); - - assert_eq!(wrap_index(2, 3), 2_usize); - assert_eq!(wrap_index(2 + 3, 3), 2_usize); - assert_eq!(wrap_index(2 + 2 * 3, 3), 2_usize); - assert_eq!(wrap_index(2 - 3, 3), 2_usize); - assert_eq!(wrap_index(2 - 2 * 3, 3), 2_usize); - } - - #[test] - fn test_negative_wrap() { - assert_eq!(IndexWrap::index().expect_wrap(0, 3), 0); - assert_eq!(IndexWrap::index().expect_wrap(1, 3), 1); - assert_eq!(IndexWrap::index().expect_wrap(2, 3), 2); - assert_eq!(IndexWrap::index().expect_wrap(-1, 3), 2); - assert_eq!(IndexWrap::index().expect_wrap(-2, 3), 1); - assert_eq!(IndexWrap::index().expect_wrap(-3, 3), 0); - - assert_eq!(IndexWrap::dim().expect_wrap(0, 3), 0); - assert_eq!(IndexWrap::dim().expect_wrap(1, 3), 1); - assert_eq!(IndexWrap::dim().expect_wrap(2, 3), 2); - assert_eq!(IndexWrap::dim().expect_wrap(-1, 3), 2); - assert_eq!(IndexWrap::dim().expect_wrap(-2, 3), 1); - assert_eq!(IndexWrap::dim().expect_wrap(-3, 3), 0); - - assert_eq!( - IndexWrap::index().try_wrap(3, 3), - Err(BoundsError::Index { - kind: IndexKind::Element, - index: 3, - bounds: 0..3, - }) - ); - assert_eq!( - IndexWrap::index().try_wrap(-4, 3), - Err(BoundsError::Index { - kind: IndexKind::Element, - index: -4, - bounds: 0..3, - }) - ); - assert_eq!( - IndexWrap::dim().try_wrap(3, 3), - Err(BoundsError::Index { - kind: IndexKind::Dimension, - index: 3, - bounds: 0..3, - }) - ); - assert_eq!( - IndexWrap::dim().try_wrap(-4, 3), - Err(BoundsError::Index { - kind: IndexKind::Dimension, - index: -4, - bounds: 0..3, - }) - ); - } - - #[test] - fn test_negative_wrap_scalar() { - assert_eq!( - IndexWrap::index().try_wrap(0, 0), - Err(BoundsError::Index { - kind: IndexKind::Element, - index: 0, - bounds: 0..0, - }) - ); - - assert_eq!( - IndexWrap::index().with_wrap_scalar(true).expect_wrap(0, 0), - 0 - ); - assert_eq!( - IndexWrap::index().with_wrap_scalar(true).expect_wrap(-1, 0), - 0 - ); - - assert_eq!( - IndexWrap::index().with_wrap_scalar(false).try_wrap(1, 0), - Err(BoundsError::Index { - kind: IndexKind::Element, - index: 1, - bounds: 0..0, - }) - ); - } -} diff --git a/crates/burn-std/src/tensor/mod.rs b/crates/burn-std/src/tensor/mod.rs index 7fa2ad3b35..19e6802219 100644 --- a/crates/burn-std/src/tensor/mod.rs +++ b/crates/burn-std/src/tensor/mod.rs @@ -1,18 +1,17 @@ pub mod dtype; -pub mod index_conversion; -pub mod indexing; pub mod quantization; pub mod shape; pub mod slice; pub use dtype::*; -pub use index_conversion::*; -pub use indexing::*; pub use quantization::*; pub use shape::*; pub use slice::*; -use alloc::{vec, vec::Vec}; +pub use cubecl_zspace::indexing::{self, *}; +pub use cubecl_zspace::{Strides, metadata::Metadata, strides}; + +use alloc::vec; /// Check if the current tensor is contiguous. /// @@ -26,7 +25,7 @@ pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { return true; } - for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) { + for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) { if expected != stride { return false; } @@ -39,16 +38,15 @@ pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool { /// /// In a contiguous row-major tensor, the stride for each dimension /// equals the product of all dimension sizes to its right. -pub fn contiguous_strides(shape: &[usize]) -> Vec { - let mut strides = Vec::with_capacity(shape.len()); +pub fn contiguous_strides(shape: &[usize]) -> Strides { + let mut strides = strides![0; shape.len()]; let mut current = 1; - for &dim in shape.iter().rev() { - strides.push(current); + for (i, &dim) in shape.iter().enumerate().rev() { + strides[i] = current; current *= dim; } - strides.reverse(); strides } @@ -58,7 +56,7 @@ pub enum ReshapeAction { /// Updating the strides is sufficient to handle the reshape. UpdateStrides { /// The new strides. - strides: Vec, + strides: Strides, }, /// The strides are not compatible, we should recompute the buffer. Recompute, @@ -127,8 +125,8 @@ pub fn broadcast_strides( rank_prev: usize, num_elems: usize, strides: &[usize], -) -> Vec { - let mut strides_new = vec![num_elems; rank_prev + n_new_batch]; +) -> Strides { + let mut strides_new = strides![num_elems; rank_prev + n_new_batch]; for (i, s) in strides.iter().enumerate() { strides_new[i + n_new_batch] = *s; @@ -138,8 +136,8 @@ pub fn broadcast_strides( } /// Calculate the new strides given added split dimensions. -pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Vec { - let mut strides_new = vec![1; shape_new.len()]; +pub fn split_strides(shape: &[usize], strides: &[usize], shape_new: &[usize]) -> Strides { + let mut strides_new = strides![1; shape_new.len()]; let mut old_idx = shape.len() - 1; let mut current_stride = strides[old_idx]; diff --git a/crates/burn-std/src/tensor/quantization.rs b/crates/burn-std/src/tensor/quantization.rs index 4e4361b6a1..7048552769 100644 --- a/crates/burn-std/src/tensor/quantization.rs +++ b/crates/burn-std/src/tensor/quantization.rs @@ -17,7 +17,7 @@ use core::any::TypeId; use num_traits::PrimInt; use serde::{Deserialize, Serialize}; -use crate::{DType, Shape, bytes::Bytes}; +use crate::{DType, Metadata, Shape, bytes::Bytes}; #[derive( Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default, @@ -60,10 +60,8 @@ pub struct QParamTensor { pub offset_start: usize, /// Offset of tensor end from the end of the buffer pub offset_end: usize, - /// Shape of the tensor - pub shape: Shape, - /// Strides of the tensor - pub strides: Vec, + /// Metadata of the tensor + pub metadata: Metadata, /// Data type of the tensor pub dtype: DType, } diff --git a/crates/burn-std/src/tensor/shape.rs b/crates/burn-std/src/tensor/shape.rs index 65eef15e65..10acbd3835 100644 --- a/crates/burn-std/src/tensor/shape.rs +++ b/crates/burn-std/src/tensor/shape.rs @@ -1,173 +1,17 @@ //! Tensor shape definition. -use super::indexing::ravel_index; -use super::{AsIndex, Slice, SliceArg}; -use alloc::format; -use alloc::string::{String, ToString}; +use super::{Slice, SliceArg}; use alloc::vec::Vec; -use core::fmt::{Debug, Display, Formatter}; -use core::str::FromStr; -use core::{ - ops::{Deref, DerefMut, Index, IndexMut, Range}, - slice::{Iter, IterMut, SliceIndex}, -}; -use serde::{Deserialize, Serialize}; -use smallvec::{SmallVec, smallvec}; +use core::ops::Range; pub use crate::errors::ExpressionError; -pub use crate::tensor::index_conversion::AsSize; -/// Shape of a tensor. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub struct Shape { - /// The dimensions of the tensor. - dims: SmallVec<[usize; 5]>, -} - -#[allow(missing_docs)] -#[derive(Debug, Clone, PartialEq, Eq)] -/// Error that can occur when attempting to modify shapes. -pub enum ShapeError { - /// The operands have different ranks. - RankMismatch { left: usize, right: usize }, - /// A pair of dimensions are incompatible for broadcasting. - IncompatibleDims { - left: usize, - right: usize, - dim: usize, - }, - /// Invalid dimension specified for the rank. - OutOfBounds { dim: usize, rank: usize }, - /// A pair of shapes are incompatible for the operation. - IncompatibleShapes { left: Shape, right: Shape }, - /// Invalid shape. - Invalid { reason: String }, -} - -impl ShapeError { - fn empty() -> Self { - Self::Invalid { - reason: "Shape is empty.".into(), - } - } -} - -impl Shape { - /// Constructs a new `Shape`. - pub fn new(dims: [usize; D]) -> Self { - // For backward compat - Self { - dims: SmallVec::from_slice(&dims), - } - } - - /// Returns the total number of elements of a tensor having this shape - pub fn num_elements(&self) -> usize { - self.dims.iter().product() - } - - /// Returns the number of dimensions. - /// - /// Alias for `Shape::rank()`. - pub fn num_dims(&self) -> usize { - self.dims.len() - } - - /// Returns the rank (the number of dimensions). - /// - /// Alias for `Shape::num_dims()`. - pub fn rank(&self) -> usize { - self.num_dims() - } - - // For compat with dims: [usize; D] - /// Returns the dimensions of the tensor as an array. - pub fn dims(&self) -> [usize; D] { - let mut dims = [1; D]; - dims[..D].copy_from_slice(&self.dims[..D]); - dims - } - - /// Change the shape to one dimensional with the same number of elements. - pub fn flatten(mut self) -> Self { - self.dims = SmallVec::from_slice(&[self.num_elements()]); - self - } - - /// Flatten the shape along a given range of dimensions. - /// - /// This function collapses the specified range of dimensions into a single dimension, - /// effectively flattening the tensor in that range. - /// - /// # Arguments - /// - /// - `start_dim`: The starting dimension of the range to be flattened, - /// supports negative indexing. - /// - `end_dim`: The ending dimension of the range to be flattened (inclusive), - /// supports negative indexing. - /// - /// # Returns - /// - /// A new `Shape` instance with the specified range of dimensions flattened. - /// - /// # Example - /// - /// ```rust - /// use burn_std::Shape; - /// - /// fn example() { - /// let shape = Shape::new([2, 3, 4]); - /// - /// let flattened = shape.flatten_dims(1, 2); - /// println!("{flattened}"); - /// // [2, 12] - /// } - /// ``` - pub fn flatten_dims(self, start_dim: impl AsIndex, end_dim: impl AsIndex) -> Self { - let rank = self.rank(); - let start = start_dim.expect_dim_index(rank); - let end = end_dim.expect_dim_index(rank); - - assert!( - start <= end, - "start_dim ({start}) must be <= than end_dim ({end})" - ); - - let existing = self.dims; - - let flattened_size = existing[start..=end].iter().product(); - - let new_rank = rank - (end - start); - let mut dims = smallvec![0; new_rank]; - dims[..start].copy_from_slice(&existing[..start]); - dims[start] = flattened_size; - dims[start + 1..].copy_from_slice(&existing[end + 1..]); - - Self { dims } - } - - /// Compute the ravel index for the given coordinates. - /// - /// This returns the row-major order raveling: - /// * `strides[-1] = 1` - /// * `strides[i] = strides[i+1] * dims[i+1]` - /// * `dim_strides = coords * strides` - /// * `ravel = sum(dim_strides)` - /// - /// # Arguments - /// - `indices`: the index for each dimension; must be the same length as `shape`. - /// - /// # Returns - /// - the ravel offset index. - pub fn ravel_index(&self, indices: &[I]) -> usize { - ravel_index(indices, &self.dims) - } +pub use cubecl_zspace::{Shape, ShapeError, calculate_matmul_output, shape}; +/// Slice-relatedo ops on [`Shape`] +pub trait SliceOps: Sized { /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. - pub fn into_ranges(self) -> Vec> { - self.iter().map(|&d| 0..d).collect() - } - + fn into_ranges(self) -> Vec>; /// Converts slice arguments into an array of slice specifications for the shape. /// /// This method returns an array of `Slice` objects that can be used for slicing operations. @@ -227,157 +71,26 @@ impl Shape { /// - [`Shape::into_ranges`] - Convert to full covering ranges /// /// [`s!`]: crate::s! - pub fn into_slices(self, slices: S) -> Vec + fn into_slices(self, slices: S) -> Vec where - S: SliceArg, - { - slices.into_slices(&self) - } - - /// Construct a vector of the dims. - pub fn to_vec(&self) -> Vec { - self.dims.to_vec() - } - - /// Returns an iterator over the shape dimensions. - pub fn iter(&self) -> Iter<'_, usize> { - self.dims.iter() - } - - /// Mutable iterator over the dimensions. - pub fn iter_mut(&mut self) -> IterMut<'_, usize> { - self.dims.iter_mut() - } - - /// Borrow the underlying dimensions slice. - pub fn as_slice(&self) -> &[usize] { - &self.dims - } - - /// Borrow the underlying dimensions slice mutably. - pub fn as_mut_slice(&mut self) -> &mut [usize] { - &mut self.dims - } - - /// Insert a dimension of `size` at position `index`. - pub fn insert(&mut self, index: usize, size: usize) { - self.dims.insert(index, size); - } - - /// Remove and return the dimension at position `index` from the shape. - pub fn remove(&mut self, index: usize) -> usize { - self.dims.remove(index) - } - - /// Appends a dimension of `size` to the back of the shape. - pub fn push(&mut self, size: usize) { - self.dims.push(size) - } - - /// Extend the shape with the content of another shape or iterator. - pub fn extend(&mut self, iter: impl IntoIterator) { - self.dims.extend(iter) - } - - /// Swap two dimensions in the shape. - pub fn swap(mut self, dim1: usize, dim2: usize) -> Result { - if dim1 > self.rank() { - return Err(ShapeError::OutOfBounds { - dim: dim1, - rank: self.rank(), - }); - } - if dim2 > self.rank() { - return Err(ShapeError::OutOfBounds { - dim: dim2, - rank: self.rank(), - }); - } - self.dims.swap(dim1, dim2); - Ok(self) - } - - /// Reorder the shape dimensions according to the permutation of `axes`. - pub fn permute(mut self, axes: &[usize]) -> Result { - if axes.len() != self.rank() { - return Err(ShapeError::RankMismatch { - left: self.rank(), - right: axes.len(), - }); - } - debug_assert!(axes.iter().all(|i| i < &self.rank())); - - self.dims = axes.iter().map(|&i| self.dims[i]).collect(); - Ok(self) - } - - /// Repeated the specified `dim` a number of `times`. - pub fn repeat(mut self, dim: usize, times: usize) -> Result { - if dim >= self.rank() { - return Err(ShapeError::OutOfBounds { - dim, - rank: self.rank(), - }); - } - - self.dims[dim] *= times; - Ok(self) - } - - /// Returns a new shape where the specified `dim` is reduced to size 1. - pub fn reduce(mut self, dim: usize) -> Result { - if dim >= self.rank() { - return Err(ShapeError::OutOfBounds { - dim, - rank: self.rank(), - }); - } + S: SliceArg; + /// Compute the output shape from the given slices. + fn slice(self, slices: &[Slice]) -> Result; +} - self.dims[dim] = 1; - Ok(self) +impl SliceOps for Shape { + fn into_ranges(self) -> Vec> { + self.iter().map(|&d| 0..d).collect() } - /// Concatenates all shapes into a new one along the given dimension. - pub fn cat<'a, I>(shapes: I, dim: usize) -> Result + fn into_slices(self, slices: S) -> Vec where - I: IntoIterator, + S: SliceArg, { - let mut iter = shapes.into_iter(); - - let first = iter.next().ok_or(ShapeError::empty())?; - - if dim >= first.rank() { - return Err(ShapeError::OutOfBounds { - dim, - rank: first.rank(), - }); - } - - let mut shape = first.clone(); - - for s in iter { - if s.rank() != shape.rank() { - return Err(ShapeError::RankMismatch { - left: shape.rank(), - right: s.rank(), - }); - } - - if s[..dim] != shape[..dim] || s[dim + 1..] != shape[dim + 1..] { - return Err(ShapeError::IncompatibleShapes { - left: shape.clone(), - right: s.clone(), - }); - } - - shape[dim] += s[dim]; - } - - Ok(shape) + slices.into_slices(&self) } - /// Compute the output shape from the given slices. - pub fn slice(mut self, slices: &[Slice]) -> Result { + fn slice(mut self, slices: &[Slice]) -> Result { if slices.len() > self.rank() { return Err(ShapeError::RankMismatch { left: self.rank(), @@ -392,302 +105,6 @@ impl Shape { Ok(self) } - - /// Compute the output shape for binary operations with broadcasting support. - /// - /// - Shapes must be of the same rank (missing dimensions are not handled automatically). - /// - Two dimensions are compatible if they are equal, or one of them is 1. - /// - /// For example, a shape `[1, 1, 2, 4]` can be broadcast into `[7, 6, 2, 4]` - /// because its axes are either equal or 1. On the other hand, a shape `[2, 2]` - /// can *not* be broadcast into `[2, 4]`. - pub fn broadcast(&self, other: &Self) -> Result { - Self::broadcast_many([self, other]) - } - - /// Compute the broadcasted output shape across multiple input shapes. - /// - /// See also [broadcast](Self::broadcast). - pub fn broadcast_many<'a, I>(shapes: I) -> Result - where - I: IntoIterator, - { - let mut iter = shapes.into_iter(); - let mut broadcasted = iter.next().ok_or(ShapeError::empty())?.clone(); - let rank = broadcasted.rank(); - - for shape in iter { - if shape.rank() != rank { - return Err(ShapeError::RankMismatch { - left: rank, - right: shape.rank(), - }); - } - - for (dim, (d_lhs, &d_rhs)) in broadcasted.iter_mut().zip(shape.iter()).enumerate() { - match (*d_lhs, d_rhs) { - (a, b) if a == b => {} // same - (1, b) => *d_lhs = b, // broadcast to rhs - (_a, 1) => {} // keep existing dimension - _ => { - return Err(ShapeError::IncompatibleDims { - left: *d_lhs, - right: d_rhs, - dim, - }); - } - } - } - } - - Ok(broadcasted) - } - - /// Expand this shape to match the target shape, following broadcasting rules. - pub fn expand(&self, target: Shape) -> Result { - let target_rank = target.rank(); - if self.rank() > target_rank { - return Err(ShapeError::RankMismatch { - left: self.rank(), - right: target_rank, - }); - } - - for (i, (dim_target, dim_self)) in target.iter().rev().zip(self.iter().rev()).enumerate() { - if dim_self != dim_target && *dim_self != 1 { - return Err(ShapeError::IncompatibleDims { - left: *dim_self, - right: *dim_target, - dim: target_rank - i - 1, - }); - } - } - - Ok(target) - } - - /// Reshape this shape to the target shape. - pub fn reshape(&self, args: A) -> Result - where - A: AsRef<[T]> + Debug, - T: AsIndex, - { - let args = args.as_ref(); - let mut infer_index = None; - let mut dims = Vec::new(); - - let mut new_size = 1; - - for (idx, &s) in args.iter().enumerate() { - let s = s.as_index(); - if s > 0 { - let s = s as usize; - new_size *= s; - dims.push(s); - } else if s == 0 { - // We need to find the index of the 0 dimensions and - // replace them with the actual dimension value. - let s = self.dims[idx]; - new_size *= s; - dims.push(s); - } else if s == -1 { - match infer_index { - None => { - infer_index = Some(idx); - // Used by / Replaced by handling later. - dims.push(1); - } - Some(_) => { - return Err(ShapeError::Invalid { - reason: "Repeated -1 in reshape".to_string(), - }); - } - } - } else { - return Err(ShapeError::Invalid { - reason: "The given shape cannot contain negative dimensions (other than -1)." - .to_string(), - }); - } - } - - let source_size = self.num_elements(); - match infer_index { - None => { - if source_size != new_size { - return Err(ShapeError::Invalid { - reason: format!( - "The given shape doesn't have the same number of elements as the current shape. Current shape: {self}, target shape: {dims:?}.", - ), - }); - } - } - Some(idx) => { - if !source_size.is_multiple_of(new_size) { - return Err(ShapeError::Invalid { - reason: format!( - "Cannot infer a valid target shape. Current shape: {self}, target dimensions: {args:?}." - ), - }); - } - dims[idx] = source_size / new_size; - } - } - - Ok(dims.into()) - } - - /// Returns the raw inner storage of the shape. Avoid using this where possible, since the storage - /// may change at any time. - pub fn inner_mut(&mut self) -> &mut SmallVec<[usize; 5]> { - &mut self.dims - } -} - -/// Compute the output shape for matrix multiplication with broadcasting support. -/// -/// The last two dimensions are treated as matrices, while preceding dimensions -/// follow broadcast semantics similar to elementwise operations. -pub fn calculate_matmul_output(lhs: &Shape, rhs: &Shape) -> Result { - let rank = lhs.rank(); - if rank != rhs.rank() { - return Err(ShapeError::RankMismatch { - left: rank, - right: rhs.rank(), - }); - } - - if lhs[rank - 1] != rhs[rank - 2] { - return Err(ShapeError::IncompatibleShapes { - left: lhs.clone(), - right: rhs.clone(), - }); - } - - let mut shape = if rank > 2 { - // Broadcast leading dims - Shape::from(&lhs[..rank - 2]).broadcast(&Shape::from(&rhs[..rank - 2]))? - } else { - Shape::new([]) - }; - shape.extend([lhs[rank - 2], rhs[rank - 1]]); - - Ok(shape) -} - -impl Display for Shape { - fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { - self.dims.fmt(f) - } -} - -impl FromStr for Shape { - type Err = crate::ExpressionError; - - fn from_str(source: &str) -> Result { - let mut s = source.trim(); - - const DELIMS: [(&str, &str); 2] = [("[", "]"), ("(", ")")]; - - for (open, close) in DELIMS { - if let Some(p) = s.strip_prefix(open) { - if let Some(p) = p.strip_suffix(close) { - s = p.trim(); - break; - } else { - return Err(crate::ExpressionError::ParseError { - message: "Unbalanced delimiters".to_string(), - source: source.to_string(), - }); - } - } - } - - if s.is_empty() { - return Ok(Shape::new([])); - } - - let dims = - s.split(',') - .map(|dim_str| { - dim_str.trim().parse::().map_err(|_| { - crate::ExpressionError::ParseError { - message: "Unable to parse shape".to_string(), - source: source.to_string(), - } - }) - }) - .collect::, crate::ExpressionError>>()?; - - if dims.is_empty() { - unreachable!("Split should have returned at least one element"); - } - - Ok(Shape { dims }) - } -} - -impl Index for Shape -where - Idx: SliceIndex<[usize]>, -{ - type Output = Idx::Output; - - fn index(&self, index: Idx) -> &Self::Output { - &self.dims[index] - } -} - -impl IndexMut for Shape -where - Idx: SliceIndex<[usize]>, -{ - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { - &mut self.dims[index] - } -} - -// Allow `&shape` to behave like a slice `&[usize]` directly -impl Deref for Shape { - type Target = [usize]; - - fn deref(&self) -> &Self::Target { - &self.dims - } -} - -// Allow `&shape` to behave like a mut slice `&mut [usize]` directly -impl DerefMut for Shape { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.dims - } -} -// Allow `shape.reshape(other_shape)`. -// -// By implementing `AsRef<[usize]>`, `Shape` behaves like a slice of dimensions, -// similar to how `Vec` can be passed to functions expecting a slice. -impl AsRef<[usize]> for Shape { - fn as_ref(&self) -> &[usize] { - &self.dims - } -} - -impl From for Vec { - fn from(shape: Shape) -> Self { - shape.dims.to_vec() - } -} - -impl From for Shape -where - T: IntoIterator, - I: AsSize, -{ - fn from(dims: T) -> Self { - Shape { - dims: dims.into_iter().map(|d| d.as_size()).collect(), - } - } } #[cfg(test)] @@ -695,108 +112,8 @@ where mod tests { use super::*; use crate::s; - use alloc::string::ToString; use alloc::vec; - #[test] - fn test_shape_to_str() { - let shape = Shape::new([2, 3, 4, 5]); - assert_eq!(shape.to_string(), "[2, 3, 4, 5]"); - } - - #[test] - fn test_shape_from_str() { - assert_eq!( - "[2, 3, 4, 5]".parse::().unwrap(), - Shape::new([2, 3, 4, 5]) - ); - assert_eq!( - "(2, 3, 4, 5)".parse::().unwrap(), - Shape::new([2, 3, 4, 5]) - ); - assert_eq!( - "2, 3, 4, 5".parse::().unwrap(), - Shape::new([2, 3, 4, 5]) - ); - - assert_eq!("[2]".parse::().unwrap(), Shape::new([2])); - assert_eq!("(2)".parse::().unwrap(), Shape::new([2])); - assert_eq!("2".parse::().unwrap(), Shape::new([2])); - - assert_eq!("[]".parse::().unwrap(), Shape::new([])); - assert_eq!("".parse::().unwrap(), Shape::new([])); - - assert_eq!( - "[".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unbalanced delimiters".to_string(), - source: "[".to_string() - }) - ); - - assert_eq!( - "[[1]".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unable to parse shape".to_string(), - source: "[[1]".to_string() - }) - ); - assert_eq!( - "[[1]]".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unable to parse shape".to_string(), - source: "[[1]]".to_string() - }) - ); - assert_eq!( - "[1)".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unbalanced delimiters".to_string(), - source: "[1)".to_string() - }) - ); - - assert_eq!( - "]".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unable to parse shape".to_string(), - source: "]".to_string() - }) - ); - - assert_eq!( - "[a]".parse::(), - Err(crate::ExpressionError::ParseError { - message: "Unable to parse shape".to_string(), - source: "[a]".to_string() - }) - ); - } - - #[test] - fn num_dims_and_rank() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - assert_eq!(4, shape.num_dims()); - assert_eq!(4, shape.rank()); - } - - #[test] - fn num_elements() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - assert_eq!(120, shape.num_elements()); - } - - #[test] - #[allow(clippy::into_iter_on_ref)] - fn test_shape_into_iter() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - - assert_eq!(shape.into_iter().sum::(), 14); - } - #[test] fn test_into_ranges() { let dims = [2, 3, 4, 5]; @@ -804,13 +121,6 @@ mod tests { assert_eq!(shape.into_ranges(), vec![0..2, 0..3, 0..4, 0..5]); } - #[test] - fn test_to_vec() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - assert_eq!(shape.to_vec(), vec![2, 3, 4, 5]); - } - #[allow(clippy::single_range_in_vec_init)] #[test] fn test_into_slices() { @@ -833,56 +143,6 @@ mod tests { assert_eq!(slices[1].to_range(3), 2..3); } - #[test] - fn test_shape_index() { - let shape = Shape::new([2, 3, 4, 5]); - - assert_eq!(shape[0], 2); - assert_eq!(shape[1], 3); - assert_eq!(shape[2], 4); - assert_eq!(shape[3], 5); - - // Works with ranges - assert_eq!(shape[1..3], *&[3, 4]); - assert_eq!(shape[1..=2], *&[3, 4]); - assert_eq!(shape[..], *&[2, 3, 4, 5]); - } - - #[test] - fn test_shape_slice_methods() { - let shape = Shape::new([2, 3, 4, 5]); - - let dim = shape.first(); - assert_eq!(dim, Some(&2)); - let dim = shape.last(); - assert_eq!(dim, Some(&5)); - - assert!(!shape.is_empty()); - let shape = Shape::new([]); - assert!(shape.is_empty()); - } - - #[test] - fn test_shape_iter() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - - for (d, sd) in dims.iter().zip(shape.iter()) { - assert_eq!(d, sd); - } - } - - #[test] - fn test_shape_iter_mut() { - let mut shape = Shape::new([2, 3, 4, 5]); - - for d in shape.iter_mut() { - *d += 1; - } - - assert_eq!(shape.as_slice(), &[3, 4, 5, 6]); - } - #[test] fn test_shape_as_slice() { let dims = [2, 3, 4, 5]; @@ -913,288 +173,6 @@ mod tests { assert_eq!(shape, shape_mut) } - #[test] - fn test_shape_flatten() { - let shape = Shape::new([2, 3, 4, 5]); - assert_eq!(shape.num_elements(), 120); - - let shape = shape.flatten(); - assert_eq!(shape.num_elements(), 120); - assert_eq!(shape.as_slice(), &[120]); - } - - #[test] - fn test_ravel() { - let shape = Shape::new([2, 3, 4, 5]); - - assert_eq!(shape.ravel_index(&[0, 0, 0, 0]), 0); - assert_eq!( - shape.ravel_index(&[1, 2, 3, 4]), - 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4 - ); - } - - #[test] - fn test_shape_insert_remove_push() { - let dims = [2, 3, 4, 5]; - let mut shape = Shape::new(dims); - let size = 6; - shape.insert(1, size); - - assert_eq!(shape, Shape::new([2, 6, 3, 4, 5])); - - let removed = shape.remove(1); - assert_eq!(removed, size); - assert_eq!(shape, Shape::new(dims)); - - shape.push(6); - assert_eq!(shape, Shape::new([2, 3, 4, 5, 6])); - } - - #[test] - fn test_shape_swap_permute() { - let dims = [2, 3, 4, 5]; - let shape = Shape::new(dims); - let shape = shape.swap(1, 2).unwrap(); - - assert_eq!(shape.as_slice(), &[2, 4, 3, 5]); - - let shape = shape.permute(&[0, 2, 1, 3]).unwrap(); - assert_eq!(shape, Shape::new(dims)); - } - - #[test] - #[should_panic] - fn test_shape_swap_out_of_bounds() { - let shape = Shape::new([2, 3, 4, 5]); - - shape.swap(0, 4).unwrap(); - } - - #[test] - #[should_panic] - fn test_shape_permute_incomplete() { - let shape = Shape::new([2, 3, 4, 5]); - - shape.permute(&[0, 2, 1]).unwrap(); - } - - #[test] - fn test_shape_repeat() { - let shape = Shape::new([2, 3, 4, 5]); - - let out = shape.repeat(2, 3).unwrap(); - assert_eq!(out, Shape::new([2, 3, 12, 5])); - } - - #[test] - fn test_shape_repeat_invalid() { - let shape = Shape::new([2, 3, 4, 5]); - - let out = shape.repeat(5, 3); - assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 })); - } - - #[test] - fn test_shape_reduce() { - let shape = Shape::new([2, 3, 4, 5]); - - let out = shape.reduce(2).unwrap(); - assert_eq!(out, Shape::new([2, 3, 1, 5])); - } - - #[test] - fn test_shape_reduce_invalid() { - let shape = Shape::new([2, 3, 4, 5]); - - let out = shape.reduce(5); - assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 5, rank: 4 })); - } - - #[test] - fn test_shape_broadcast_binary() { - let lhs = Shape::new([1, 1, 2, 4]); - let rhs = Shape::new([7, 6, 2, 1]); - - let out = lhs.broadcast(&rhs).unwrap(); - assert_eq!(out, Shape::new([7, 6, 2, 4])); - } - - #[test] - fn test_shape_broadcast_rank_mismatch() { - let lhs = Shape::new([1, 2, 4]); - let rhs = Shape::new([7, 6, 2, 4]); - - let out = lhs.broadcast(&rhs); - assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 })); - } - - #[test] - fn test_shape_broadcast_incompatible_dims() { - let lhs = Shape::new([1, 2, 2, 4]); - let rhs = Shape::new([7, 6, 2, 1]); - - let out = lhs.broadcast(&rhs); - assert_eq!( - out, - Err(ShapeError::IncompatibleDims { - left: 2, - right: 6, - dim: 1 - }) - ); - } - - #[test] - fn test_shape_broadcast_many() { - let s1 = Shape::new([1, 1, 2, 4]); - let s2 = Shape::new([7, 1, 2, 1]); - let s3 = Shape::new([7, 6, 1, 1]); - - let out = Shape::broadcast_many([&s1, &s2, &s3]).unwrap(); - assert_eq!(out, Shape::new([7, 6, 2, 4])); - } - - #[test] - fn test_shape_broadcast_many_rank_mismatch() { - let s1 = Shape::new([1, 1, 2, 4]); - let s2 = Shape::new([7, 1, 2, 1]); - let s3 = Shape::new([1, 6, 1]); - - let out = Shape::broadcast_many([&s1, &s2, &s3]); - assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 3 })); - } - - #[test] - fn test_shape_broadcast_many_incompatible_dims() { - let s1 = Shape::new([1, 1, 2, 4]); - let s2 = Shape::new([7, 1, 2, 1]); - let s3 = Shape::new([4, 6, 1, 1]); - - let out = Shape::broadcast_many([&s1, &s2, &s3]); - assert_eq!( - out, - Err(ShapeError::IncompatibleDims { - left: 7, - right: 4, - dim: 0 - }) - ); - } - - #[test] - fn test_shape_broadcast_many_empty() { - let out = Shape::broadcast_many(&[]); - assert_eq!(out, Err(ShapeError::empty())); - } - - #[test] - fn test_shape_matmul_2d() { - let lhs = Shape::new([2, 4]); - let rhs = Shape::new([4, 2]); - let out = calculate_matmul_output(&lhs, &rhs).unwrap(); - assert_eq!(out, Shape::new([2, 2])); - } - - #[test] - fn test_shape_matmul_4d_broadcasted() { - let lhs = Shape::new([1, 3, 2, 4]); - let rhs = Shape::new([2, 1, 4, 2]); - let out = calculate_matmul_output(&lhs, &rhs).unwrap(); - assert_eq!(out, Shape::new([2, 3, 2, 2])); - } - - #[test] - fn test_shape_matmul_invalid_rank() { - let lhs = Shape::new([3, 2, 4]); - let rhs = Shape::new([2, 1, 4, 2]); - let out = calculate_matmul_output(&lhs, &rhs); - assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 4 })); - } - - #[test] - fn test_shape_matmul_invalid_shape() { - let lhs = Shape::new([1, 3, 2, 4]); - let rhs = Shape::new([2, 1, 3, 2]); - let out = calculate_matmul_output(&lhs, &rhs); - assert_eq!( - out, - Err(ShapeError::IncompatibleShapes { - left: lhs, - right: rhs - }) - ); - } - - #[test] - fn test_shape_matmul_invalid_broadcast() { - let lhs = Shape::new([1, 3, 2, 4]); - let rhs = Shape::new([2, 2, 4, 2]); - let out = calculate_matmul_output(&lhs, &rhs); - assert_eq!( - out, - Err(ShapeError::IncompatibleDims { - left: 3, - right: 2, - dim: 1 - }) - ); - } - - #[test] - fn test_shape_cat() { - let s1 = Shape::new([2, 3, 4, 5]); - let s2 = Shape::new([1, 3, 4, 5]); - let s3 = Shape::new([4, 3, 4, 5]); - - let out = Shape::cat(&[s1, s2, s3], 0).unwrap(); - assert_eq!(out, Shape::new([7, 3, 4, 5])); - - let s1 = Shape::new([2, 3, 4, 5]); - let s2 = Shape::new([2, 3, 2, 5]); - let s3 = Shape::new([2, 3, 1, 5]); - - let out = Shape::cat(&[s1, s2, s3], 2).unwrap(); - assert_eq!(out, Shape::new([2, 3, 7, 5])); - } - - #[test] - fn test_shape_cat_empty() { - let out = Shape::cat(&[], 0); - assert_eq!(out, Err(ShapeError::empty())); - } - - #[test] - fn test_shape_cat_dim_out_of_bounds() { - let s1 = Shape::new([2, 3, 4, 5]); - let s2 = Shape::new([2, 3, 4, 5]); - let out = Shape::cat(&[s1, s2], 4); - assert_eq!(out, Err(ShapeError::OutOfBounds { dim: 4, rank: 4 })); - } - - #[test] - fn test_shape_cat_rank_mismatch() { - let s1 = Shape::new([2, 3, 4, 5]); - let s2 = Shape::new([2, 3, 4, 5, 6]); - let out = Shape::cat(&[s1, s2], 0); - assert_eq!(out, Err(ShapeError::RankMismatch { left: 4, right: 5 })); - } - - #[test] - fn test_shape_cat_incompatible_shapes() { - let s1 = Shape::new([2, 3, 4, 5]); - let s2 = Shape::new([1, 3, 4, 5]); - let out = Shape::cat(&[s1.clone(), s2.clone()], 1); - - assert_eq!( - out, - Err(ShapeError::IncompatibleShapes { - left: s1, - right: s2 - }) - ); - } - #[test] fn test_shape_slice_output_shape_basic() { // Test basic slicing with step=1 @@ -1290,83 +268,4 @@ mod tests { let result = original_shape.slice(&slices).unwrap(); assert_eq!(result, Shape::new([3, 3, 2])); } - - #[test] - fn test_shape_expand() { - let shape = Shape::new([1, 3, 1]); - let expanded = Shape::new([2, 3, 4]); - let out = shape.expand(expanded.clone()).unwrap(); - assert_eq!(out, expanded); - } - - #[test] - fn test_shape_expand_higher_rank() { - let shape = Shape::new([1, 4]); - let expanded = Shape::new([2, 3, 4]); - let out = shape.expand(expanded.clone()).unwrap(); - assert_eq!(out, expanded); - } - - #[test] - fn test_shape_expand_invalid_rank() { - let shape = Shape::new([1, 3, 1]); - let expanded = Shape::new([3, 4]); - let out = shape.expand(expanded); - assert_eq!(out, Err(ShapeError::RankMismatch { left: 3, right: 2 })); - } - - #[test] - fn test_shape_expand_incompatible_dims() { - let shape = Shape::new([1, 3, 2]); - let expanded = Shape::new([2, 3, 4]); - let out = shape.expand(expanded); - assert_eq!( - out, - Err(ShapeError::IncompatibleDims { - left: 2, - right: 4, - dim: 2 - }) - ); - } - - #[test] - fn test_shape_reshape() { - let shape = Shape::new([2, 3, 4, 5]); - let reshaped = Shape::new([1, 2, 12, 5]); - let out = shape.reshape(reshaped.clone()).unwrap(); - assert_eq!(out, reshaped); - } - - #[test] - fn test_shape_reshape_invalid() { - let shape = Shape::new([2, 3, 4, 5]); - let reshaped = Shape::new([2, 2, 12, 5]); - let out = shape.reshape(reshaped.clone()); - assert_eq!( - out, - Err(ShapeError::Invalid { - reason: "The given shape doesn't have the same number of elements as the current shape. Current shape: [2, 3, 4, 5], target shape: [2, 2, 12, 5].".into(), - }) - ); - } - - #[test] - fn test_shape_reshape_invalid_inferred() { - let shape = Shape::new([2, 4]); - let out = shape.reshape([-1, 3]); - assert_eq!( - out, - Err(ShapeError::Invalid { - reason: "Cannot infer a valid target shape. Current shape: [2, 4], target dimensions: [-1, 3].".into(), - }) - ); - } - - #[test] - fn test_flatten_dims() { - let shape = Shape::new([2, 3, 4, 5]); - let flattened = shape.flatten_dims(-2, 3); - assert_eq!(flattened, Shape::new([2, 3, 20])); - } } diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index 997b42c773..12807c93b1 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -11,7 +11,7 @@ use alloc::format; use alloc::string::String; use alloc::vec; -use burn_std::stub::RwLock; +use burn_std::{SliceOps, stub::RwLock}; use core::iter::repeat; use core::{fmt::Debug, ops::Range}; use serde::{Deserialize, Deserializer}; @@ -3184,6 +3184,8 @@ where #[cfg(test)] mod tests { + use burn_std::SliceOps; + use crate::{Shape, s}; #[test] diff --git a/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs index 345c10159b..a6aec52eae 100644 --- a/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/hardware_accelerated.rs @@ -12,7 +12,7 @@ use burn_cubecl::{ ops::{into_data_sync, numeric::zeros_client}, tensor::CubeTensor, }; -use burn_tensor::{Shape, cast::ToElement, ops::IntTensorOps}; +use burn_tensor::{Shape, TensorMetadata, cast::ToElement, ops::IntTensorOps}; use cubecl::{features::Plane, prelude::*}; use super::prefix_sum::prefix_sum; @@ -487,14 +487,9 @@ pub fn hardware_accelerated( - client.clone(), - device.clone(), - img.shape.clone(), - I::dtype(), - ); + let labels = zeros_client::(client.clone(), device.clone(), img.shape(), I::dtype()); // Assume 32 wide warp. Currently, larger warps are handled by just exiting everything past 32. // This isn't ideal but we require CUBE_DIM_X == warp_size, and we can't query the actual warp diff --git a/crates/burn-vision/src/backends/cube/connected_components/mod.rs b/crates/burn-vision/src/backends/cube/connected_components/mod.rs index 075b2ac1cc..736ad3b13f 100644 --- a/crates/burn-vision/src/backends/cube/connected_components/mod.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/mod.rs @@ -24,7 +24,7 @@ where I: IntElement, BT: BoolElement, { - let [height, width] = l.shape.dims(); + let [height, width] = l.meta.shape().dims(); let shape = Shape::new([height * width]); let zeros = || { zeros_client::( diff --git a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs index c790469e7f..ad2ffc9edf 100644 --- a/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs +++ b/crates/burn-vision/src/backends/cube/connected_components/prefix_sum.rs @@ -1,4 +1,4 @@ -use burn_tensor::Shape; +use burn_tensor::{Shape, TensorMetadata}; use cubecl::prelude::*; use burn_cubecl::{ @@ -220,12 +220,12 @@ fn count_trailing_zeros(num: u32) -> u32 { pub fn prefix_sum(input: CubeTensor) -> CubeTensor { let client = input.client.clone(); let device = input.device.clone(); - let num_elems = input.shape.num_elements(); - let numbers = *input.shape.last().unwrap(); + let num_elems = input.meta.num_elements(); + let numbers = *input.meta.shape().last().unwrap(); let batches = num_elems / numbers; let input = reshape(input, Shape::new([batches, numbers])); - let out = empty_device::(client.clone(), device.clone(), input.shape.clone()); + let out = empty_device::(client.clone(), device.clone(), input.shape()); let cubes = numbers.div_ceil(PART_SIZE); let cube_dim = CubeDim::new_1d(CUBE_SIZE as u32); diff --git a/examples/custom-cubecl-kernel/src/forward.rs b/examples/custom-cubecl-kernel/src/forward.rs index bac01bd1c2..8cc67ae798 100644 --- a/examples/custom-cubecl-kernel/src/forward.rs +++ b/examples/custom-cubecl-kernel/src/forward.rs @@ -29,15 +29,15 @@ impl Backend let bias = into_contiguous(bias); // Get the matmul relevant shapes. - let ndims = lhs.shape.num_dims(); - let num_rows = lhs.shape[ndims - 2]; - let num_cols = rhs.shape[ndims - 1]; + let ndims = lhs.meta.num_dims(); + let num_rows = lhs.meta.shape()[ndims - 2]; + let num_cols = rhs.meta.shape()[ndims - 1]; // Compute shape of output, while tracking number of batches. let mut num_batches = 1; let mut shape_out = vec![0; ndims]; for i in shape_out.clone().into_iter().take(ndims - 2) { - shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]); + shape_out[i] = usize::max(lhs.meta.shape()[i], rhs.meta.shape()[i]); num_batches *= shape_out[i]; } shape_out[ndims - 2] = num_rows; diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index d14d52e1ea..d9701fc2e4 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -61,15 +61,15 @@ impl Backend let bias = into_contiguous(bias); // Get the matmul relevant shapes. - let ndims = lhs.shape.num_dims(); - let num_rows = lhs.shape[ndims - 2]; - let num_cols = rhs.shape[ndims - 1]; + let ndims = lhs.meta.shape().num_dims(); + let num_rows = lhs.meta.shape()[ndims - 2]; + let num_cols = rhs.meta.shape()[ndims - 1]; // Compute shape of output, while tracking number of batches. let mut num_batches = 1; let mut shape_out = vec![0; ndims]; for i in shape_out.clone().into_iter().take(ndims - 2) { - shape_out[i] = usize::max(lhs.shape[i], rhs.shape[i]); + shape_out[i] = usize::max(lhs.meta.shape()[i], rhs.meta.shape()[i]); num_batches *= shape_out[i]; } shape_out[ndims - 2] = num_rows; From 83fd0e7a63fae44906a371f08a96fac0d7514b94 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:29:21 +0100 Subject: [PATCH 3/8] Revert temp fix --- crates/burn-cubecl-fusion/src/engine/launch/executor.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs index f069ae5be1..acfe5ee71b 100644 --- a/crates/burn-cubecl-fusion/src/engine/launch/executor.rs +++ b/crates/burn-cubecl-fusion/src/engine/launch/executor.rs @@ -71,8 +71,6 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { return Ok(tune_output); } - let mut configs = Vec::with_capacity(plan.blocks.len()); - let mut inputs = GlobalArgsLaunch::default(); let mut outputs = GlobalArgsLaunch::default(); @@ -94,6 +92,8 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { } } + let mut configs = Vec::with_capacity(plan.blocks.len()); + for (block_plan, block) in plan.blocks.into_iter().zip(self.blocks) { let reference = match block_plan.reference { ReferenceSelection::Concrete { layout, .. } => RefLayout::Concrete(layout), @@ -113,8 +113,6 @@ impl<'a, R: Runtime> LaunchPlanExecutor<'a, R> { RefLayout::Virtual(VirtualLayout::Runtime { pos }) } ReferenceSelection::Searching => { - drop(inputs); - drop(outputs); return Err(ExecutionError::new( TraceError::ReferenceNotFound, plan.handle_inputs, From 7c07c1e937d680e2dc825e01d83fe1b85114aee1 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 17 Feb 2026 16:45:25 +0100 Subject: [PATCH 4/8] Rename `ShapeError` to `MetadataError` --- .../src/backend/ops/modules/conv.rs | 18 +++++++++--------- crates/burn-std/src/tensor/shape.rs | 8 ++++---- crates/burn-tensor/src/tensor/api/check.rs | 4 ++-- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/crates/burn-backend/src/backend/ops/modules/conv.rs b/crates/burn-backend/src/backend/ops/modules/conv.rs index 659e08f5de..a4e0666679 100644 --- a/crates/burn-backend/src/backend/ops/modules/conv.rs +++ b/crates/burn-backend/src/backend/ops/modules/conv.rs @@ -1,7 +1,7 @@ #![allow(clippy::single_range_in_vec_init)] use super::{ConvOptions, ConvTransposeOptions}; use crate::{Backend, TensorMetadata, tensor::FloatTensor}; -use burn_std::{Shape, ShapeError, Slice}; +use burn_std::{MetadataError, Shape, Slice}; use alloc::{vec, vec::Vec}; #[cfg(not(feature = "std"))] @@ -16,9 +16,9 @@ pub fn calculate_pool_output_shape( padding: &[usize; N], dilation: &[usize; N], ceil_mode: bool, -) -> Result { +) -> Result { if in_shape.rank() != N + 2 { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); @@ -47,16 +47,16 @@ pub fn calculate_conv_output_shape( stride: &[usize; N], padding: &[usize; N], dilation: &[usize; N], -) -> Result { +) -> Result { if weight_shape.rank() != N + 2 { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: weight_shape.rank(), right: N + 2, }); } if in_shape.rank() != N + 2 { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); @@ -85,16 +85,16 @@ pub fn calculate_conv_transpose_output_shape( padding_out: &[usize; N], dilation: &[usize; N], groups: usize, -) -> Result { +) -> Result { if weight_shape.rank() != N + 2 { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: weight_shape.rank(), right: N + 2, }); } if in_shape.rank() != N + 2 { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: in_shape.rank(), right: N + 2, }); diff --git a/crates/burn-std/src/tensor/shape.rs b/crates/burn-std/src/tensor/shape.rs index 10acbd3835..a28e9fed82 100644 --- a/crates/burn-std/src/tensor/shape.rs +++ b/crates/burn-std/src/tensor/shape.rs @@ -6,7 +6,7 @@ use core::ops::Range; pub use crate::errors::ExpressionError; -pub use cubecl_zspace::{Shape, ShapeError, calculate_matmul_output, shape}; +pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape}; /// Slice-relatedo ops on [`Shape`] pub trait SliceOps: Sized { @@ -75,7 +75,7 @@ pub trait SliceOps: Sized { where S: SliceArg; /// Compute the output shape from the given slices. - fn slice(self, slices: &[Slice]) -> Result; + fn slice(self, slices: &[Slice]) -> Result; } impl SliceOps for Shape { @@ -90,9 +90,9 @@ impl SliceOps for Shape { slices.into_slices(&self) } - fn slice(mut self, slices: &[Slice]) -> Result { + fn slice(mut self, slices: &[Slice]) -> Result { if slices.len() > self.rank() { - return Err(ShapeError::RankMismatch { + return Err(MetadataError::RankMismatch { left: self.rank(), right: slices.len(), }); diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d199b6572a..d408fe175e 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1440,11 +1440,11 @@ pub(crate) mod macros { pub(crate) use check; } -pub(crate) fn unwrap_shape_reshape(result: Result) -> Shape { +pub(crate) fn unwrap_shape_reshape(result: Result) -> Shape { match result { Ok(shape) => shape, // `shape.reshape(new_shape)` should only return `ShapeError::Invalid`. - Err(burn_std::ShapeError::Invalid { reason }) => { + Err(burn_std::MetadataError::Invalid { reason }) => { macros::check!({ TensorCheck::Ok.register("Reshape", crate::check::TensorError::new(reason)) }); From 35043059ab027d2d21727f667aaaac469b89202b Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:56:28 +0100 Subject: [PATCH 5/8] Cleanup --- crates/burn-cubecl/src/kernel/conv/im2col.rs | 6 +++--- crates/burn-std/src/tensor/mod.rs | 4 +--- crates/burn-std/src/tensor/shape.rs | 2 +- crates/burn-tensor/src/tensor/api/check.rs | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/crates/burn-cubecl/src/kernel/conv/im2col.rs b/crates/burn-cubecl/src/kernel/conv/im2col.rs index f1128581b4..ef110842ce 100644 --- a/crates/burn-cubecl/src/kernel/conv/im2col.rs +++ b/crates/burn-cubecl/src/kernel/conv/im2col.rs @@ -2,7 +2,7 @@ use burn_backend::{ DType, ops::{ConvOptions, conv::calculate_conv_output_sizes}, }; -use burn_std::Metadata; +use burn_std::{Metadata, Shape}; use core::iter; use cubecl::{ prelude::*, @@ -139,7 +139,7 @@ fn reshape_input(mut input: CubeTensor) -> CubeTensor { let batch_size = input.meta.shape()[0]; let in_c: usize = input.meta.shape()[dim_c]; - let in_shape = input.meta.shape()[1..dim_c].to_vec(); + let in_shape: Shape = input.meta.shape()[1..dim_c].into(); if !is_spatial_contiguous(input.meta.shape(), input.meta.strides()) { let contiguous = @@ -148,7 +148,7 @@ fn reshape_input(mut input: CubeTensor) -> CubeTensor { input = from_handle(&input.client, &input.device, contiguous, dtype); } *input.meta = Metadata::new( - [batch_size * in_shape.iter().product::(), in_c], // [M, K] + [batch_size * in_shape.num_elements(), in_c], // [M, K] [input.meta.strides()[dim_c - 1], input.meta.strides()[dim_c]], ); input diff --git a/crates/burn-std/src/tensor/mod.rs b/crates/burn-std/src/tensor/mod.rs index 19e6802219..c11d911ef9 100644 --- a/crates/burn-std/src/tensor/mod.rs +++ b/crates/burn-std/src/tensor/mod.rs @@ -11,8 +11,6 @@ pub use slice::*; pub use cubecl_zspace::indexing::{self, *}; pub use cubecl_zspace::{Strides, metadata::Metadata, strides}; -use alloc::vec; - /// Check if the current tensor is contiguous. /// /// A tensor is considered contiguous if its elements are stored in memory @@ -187,7 +185,7 @@ pub fn reshape_analysis( match n_new_batch > 0 { true => { if shape == &shape_new[n_new_batch..shape_new_rank] - && shape_new[0..n_new_batch] == vec![1; n_new_batch] + && shape_new[0..n_new_batch].iter().all(|it| *it == 1) { return ReshapeAnalysis::Broadcasted; } else { diff --git a/crates/burn-std/src/tensor/shape.rs b/crates/burn-std/src/tensor/shape.rs index a28e9fed82..e40df09e47 100644 --- a/crates/burn-std/src/tensor/shape.rs +++ b/crates/burn-std/src/tensor/shape.rs @@ -8,7 +8,7 @@ pub use crate::errors::ExpressionError; pub use cubecl_zspace::{MetadataError, Shape, calculate_matmul_output, shape}; -/// Slice-relatedo ops on [`Shape`] +/// Slice-related ops on [`Shape`] pub trait SliceOps: Sized { /// Convert shape dimensions to full covering ranges (0..dim) for each dimension. fn into_ranges(self) -> Vec>; diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d408fe175e..09bd9ce011 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1443,7 +1443,7 @@ pub(crate) mod macros { pub(crate) fn unwrap_shape_reshape(result: Result) -> Shape { match result { Ok(shape) => shape, - // `shape.reshape(new_shape)` should only return `ShapeError::Invalid`. + // `shape.reshape(new_shape)` should only return `MetadataError::Invalid`. Err(burn_std::MetadataError::Invalid { reason }) => { macros::check!({ TensorCheck::Ok.register("Reshape", crate::check::TensorError::new(reason)) From d218e5c29795bdd9951e06c8efc614ce1b16666f Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 18 Feb 2026 02:49:00 +0100 Subject: [PATCH 6/8] Bump cubecl and cubek rev --- Cargo.lock | 23 +++++++++++++++++++++++ Cargo.toml | 15 ++++++++------- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 63a30509bf..7b90763d55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2078,6 +2078,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "cubecl-core", "cubecl-cpu", @@ -2093,6 +2094,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "backtrace", "bincode", @@ -2130,6 +2132,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2156,6 +2159,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bytemuck", "cubecl-common", @@ -2171,6 +2175,7 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bytemuck", "cubecl-common", @@ -2191,6 +2196,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bytemuck", "cubecl-common", @@ -2208,6 +2214,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bytemuck", "cubecl-common", @@ -2236,6 +2243,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2256,6 +2264,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "cubecl-common", "darling 0.23.0", @@ -2270,6 +2279,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -2280,6 +2290,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "cubecl-common", "cubecl-core", @@ -2296,6 +2307,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "async-channel", "bytemuck", @@ -2325,6 +2337,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2340,6 +2353,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "cubecl-common", "cubecl-core", @@ -2355,6 +2369,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "ash", "async-channel", @@ -2381,6 +2396,7 @@ dependencies = [ [[package]] name = "cubecl-zspace" version = "0.10.0-pre.1" +source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" dependencies = [ "derive-new", "serde", @@ -2390,6 +2406,7 @@ dependencies = [ [[package]] name = "cubek" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "cubecl", "cubek-attention", @@ -2403,6 +2420,7 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "bytemuck", "cubecl", @@ -2416,6 +2434,7 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "bytemuck", "cubecl", @@ -2430,6 +2449,7 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "bytemuck", "cubecl", @@ -2441,6 +2461,7 @@ dependencies = [ [[package]] name = "cubek-quant" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "cubecl", "cubecl-common", @@ -2451,6 +2472,7 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "cubecl", "cubecl-common", @@ -2463,6 +2485,7 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.2.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" dependencies = [ "cubecl", "half", diff --git a/Cargo.toml b/Cargo.toml index 1de6db9040..d6a6c6539c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -179,14 +179,15 @@ portable-atomic = { version = "1.13.1" } portable-atomic-util = { version = "0.2.5", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "11c32460621652077898d3309430711ab3d7944d" } -# cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "b121ed35e19bf9e0935dd563e97cab6e7a76005e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } +cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } +cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" } ### For local development. ### -cubecl = { path = "../cubecl/crates/cubecl", default-features = false } -cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } -cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false } -cubek = { path = "../cubek/crates/cubek", default-features = false } +# cubecl = { path = "../cubecl/crates/cubecl", default-features = false } +# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } +# cubecl-zspace = { path = "../cubecl/crates/cubecl-zspace", default-features = false } +# cubek = { path = "../cubek/crates/cubek", default-features = false } ### For the release. ### # cubecl = { version = "=0.10.0-pre.1", default-features = false } # cubecl-common = { version = "=0.10.0-pre.1", default-features = false } From a066026c8b7cf3994a5e64c9e0740d9d5c3ecf87 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Wed, 18 Feb 2026 16:02:41 +0100 Subject: [PATCH 7/8] Fix doc test --- crates/burn-std/src/tensor/shape.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/burn-std/src/tensor/shape.rs b/crates/burn-std/src/tensor/shape.rs index e40df09e47..7f2cdce0a5 100644 --- a/crates/burn-std/src/tensor/shape.rs +++ b/crates/burn-std/src/tensor/shape.rs @@ -42,7 +42,7 @@ pub trait SliceOps: Sized { /// # Examples /// /// ```rust - /// use burn_std::{Shape, Slice, s}; + /// use burn_std::{Shape, Slice, s, SliceOps}; /// /// fn example() { /// // 1D slicing From 432d733f653120a712c3f4f538fbeb6359827782 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Thu, 19 Feb 2026 14:54:13 +0100 Subject: [PATCH 8/8] Bump cubecl rev --- Cargo.lock | 46 +++++++++++++++++++++++----------------------- Cargo.toml | 8 ++++---- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7b90763d55..9602a70b8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2078,7 +2078,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "cubecl-core", "cubecl-cpu", @@ -2094,7 +2094,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "backtrace", "bincode", @@ -2132,7 +2132,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bitflags 2.10.0", "bytemuck", @@ -2159,7 +2159,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bytemuck", "cubecl-common", @@ -2175,7 +2175,7 @@ dependencies = [ [[package]] name = "cubecl-cpu" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bytemuck", "cubecl-common", @@ -2196,7 +2196,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bytemuck", "cubecl-common", @@ -2214,7 +2214,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bytemuck", "cubecl-common", @@ -2243,7 +2243,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -2264,7 +2264,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "cubecl-common", "darling 0.23.0", @@ -2279,7 +2279,7 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -2290,7 +2290,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "cubecl-common", "cubecl-core", @@ -2307,7 +2307,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "async-channel", "bytemuck", @@ -2337,7 +2337,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "bitflags 2.10.0", "cubecl-common", @@ -2353,7 +2353,7 @@ dependencies = [ [[package]] name = "cubecl-std" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "cubecl-common", "cubecl-core", @@ -2369,7 +2369,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "ash", "async-channel", @@ -2396,7 +2396,7 @@ dependencies = [ [[package]] name = "cubecl-zspace" version = "0.10.0-pre.1" -source = "git+https://github.com/tracel-ai/cubecl?rev=edc576ec0d89d34a2330cdf1174d0fc6d49c14c8#edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" +source = "git+https://github.com/tracel-ai/cubecl?rev=b19859ee693bb02a25e4da2ca53797bb164be140#b19859ee693bb02a25e4da2ca53797bb164be140" dependencies = [ "derive-new", "serde", @@ -2406,7 +2406,7 @@ dependencies = [ [[package]] name = "cubek" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "cubecl", "cubek-attention", @@ -2420,7 +2420,7 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "bytemuck", "cubecl", @@ -2434,7 +2434,7 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "bytemuck", "cubecl", @@ -2449,7 +2449,7 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "bytemuck", "cubecl", @@ -2461,7 +2461,7 @@ dependencies = [ [[package]] name = "cubek-quant" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "cubecl", "cubecl-common", @@ -2472,7 +2472,7 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "cubecl", "cubecl-common", @@ -2485,7 +2485,7 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.2.0-pre.1" -source = "git+https://github.com/tracel-ai/cubek?rev=7b9a1f87d9e0cb984cfcb83fb0f04240513038e7#7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" +source = "git+https://github.com/tracel-ai/cubek?rev=75669db878e78c762ed44f91a1e4c6b8e2b84b2a#75669db878e78c762ed44f91a1e4c6b8e2b84b2a" dependencies = [ "cubecl", "half", diff --git a/Cargo.toml b/Cargo.toml index d6a6c6539c..305206022f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -179,10 +179,10 @@ portable-atomic = { version = "1.13.1" } portable-atomic-util = { version = "0.2.5", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } -cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "edc576ec0d89d34a2330cdf1174d0fc6d49c14c8" } -cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "7b9a1f87d9e0cb984cfcb83fb0f04240513038e7" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b19859ee693bb02a25e4da2ca53797bb164be140" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b19859ee693bb02a25e4da2ca53797bb164be140" } +cubecl-zspace = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b19859ee693bb02a25e4da2ca53797bb164be140" } +cubek = { git = "https://github.com/tracel-ai/cubek", default-features = false, rev = "75669db878e78c762ed44f91a1e4c6b8e2b84b2a" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }