diff --git a/Cargo.lock b/Cargo.lock index 98fcee833f..e2f6dcb39d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1027,7 +1027,7 @@ dependencies = [ "byteorder", "candle-kernels", "candle-metal-kernels", - "cudarc", + "cudarc 0.13.9", "gemm 0.17.1", "half", "libc", @@ -1496,9 +1496,8 @@ dependencies = [ [[package]] name = "cubecl" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1e438056cf7c25b3adde38240b89842e1c924b8e914731c82ad81161d23e6ff" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1513,9 +1512,8 @@ dependencies = [ [[package]] name = "cubecl-common" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79251bfc7f067ac9038232fe38a317adc2f31cb2fc3800e69fd409ccac7abc1f" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1537,9 +1535,8 @@ dependencies = [ [[package]] name = "cubecl-core" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03bf4211cdbd68bb0fb8291e0ed825c13da0d1ac01b7c02dce3cee44a6138be" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bitflags 2.9.0", "bytemuck", @@ -1561,9 +1558,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5eef85cbcc34be7e25fc9d39edf99ed68559862dbf25c1877ebdf4a9595d31b" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bytemuck", "cubecl-common", @@ -1576,16 +1572,15 @@ dependencies = [ [[package]] name = "cubecl-cuda" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71e091e4e3a3900faff440aec4053805ec4456f94f4acc4afe8e6b27519c6d16" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bytemuck", "cubecl-common", "cubecl-core", "cubecl-cpp", "cubecl-runtime", - "cudarc", + "cudarc 0.16.2", "derive-new 0.6.0", "half", "log", @@ -1594,9 +1589,8 @@ dependencies = [ [[package]] name = "cubecl-hip" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2f8c00207517de61cccdc4ca2724bc1db9dab94840beaf4329e43cead3bc4a" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bytemuck", "cubecl-common", @@ -1622,9 +1616,8 @@ dependencies = [ [[package]] name = "cubecl-ir" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e096d77646590f0180ed4ce1aa7df4ecc7219f3c4616e9fe72d93ab63a352855" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-common", "cubecl-macros-internal", @@ -1641,9 +1634,8 @@ dependencies = [ [[package]] name = "cubecl-linalg" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75aacf86f6004c274e63589aed55c5edcbcdf1b292eaf4ce2c1688c04c41a194" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bytemuck", "cubecl-common", @@ -1657,9 +1649,8 @@ dependencies = [ [[package]] name = "cubecl-macros" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd74622b5c8cb161e3f7fa0b2b751784ef89ab45acfa355f511eb2219dde337e" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-common", "darling", @@ -1673,9 +1664,8 @@ dependencies = [ [[package]] name = "cubecl-macros-internal" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a89898212c1eaba0e2f0dffcadc9790b20b75d2ec8836da084370b043be2623" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "darling", "proc-macro2", @@ -1685,9 +1675,8 @@ dependencies = [ [[package]] name = "cubecl-opt" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dacd2d9f1f4891bd7a6a81fc5885852d8ee0725d7197a879d5c7ee8ba4eb641" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1702,9 +1691,8 @@ dependencies = [ [[package]] name = "cubecl-reduce" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afbdfe03e7e3ca71f61890ebebc6b4390494204b545e6f6bf51a43755449073" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1715,9 +1703,8 @@ dependencies = [ [[package]] name = "cubecl-runtime" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "385234520c9e392382737f32ad372b05f345656eb798ba00b72d2722c68b698c" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "async-channel", "bytemuck", @@ -1738,9 +1725,8 @@ dependencies = [ [[package]] name = "cubecl-spirv" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830c5bd4314947eefc4c5fe850e2d9de3d8e9d8a3eeac7a5544c2ca30713c119" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "bitflags 2.9.0", "cubecl-common", @@ -1754,9 +1740,8 @@ dependencies = [ [[package]] name = "cubecl-std" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38868eea6fdc183feb3c46bcf5e666c78e6cf0ddca2c4f3a877785cc0eabd71e" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1766,9 +1751,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77fa2dcfaa6d75cfbc5ff05cafe99ec4a7fb7c0fa7197917e0fd20f5b90979fe" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=0e13dfd104a18f53e3d25608d59186a5e850d855#0e13dfd104a18f53e3d25608d59186a5e850d855" dependencies = [ "ash", "async-channel", @@ -1798,6 +1782,15 @@ dependencies = [ "libloading", ] +[[package]] +name = "cudarc" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4ed411343abcb4dd6fd1fbc32db3533d76c2af0fd40735a9e5e39e778a81254" +dependencies = [ + "libloading", +] + [[package]] name = "custom-csv-dataset" version = "0.18.0" @@ -2002,9 +1995,9 @@ dependencies = [ [[package]] name = "derive_more" -version = "0.99.19" +version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "proc-macro2", "quote", @@ -5283,7 +5276,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.8.24", + "zerocopy 0.8.25", ] [[package]] @@ -7160,9 +7153,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +checksum = "900f6c86a685850b1bc9f6223b20125115ee3f31e01207d81655bbcc0aea9231" dependencies = [ "serde", "serde_spanned", @@ -7172,18 +7165,18 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" dependencies = [ "serde", ] [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485" dependencies = [ "indexmap", "serde", @@ -7243,7 +7236,7 @@ checksum = "c242cadfa35f2fdb4a812746cfc8037370c865eb8b95c026bc1ce1e68aa49be6" dependencies = [ "anyhow", "clap", - "derive_more 0.99.19", + "derive_more 0.99.20", "env_logger", "log", "rand 0.8.5", @@ -7404,7 +7397,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "50758486d7941f8b0a636ba7e29455c07071f41590beac1fd307ec893e8db69a" dependencies = [ - "cudarc", + "cudarc 0.13.9", "half", "serde", "thiserror 1.0.69", @@ -8382,9 +8375,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" +checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5" dependencies = [ "memchr", ] @@ -8504,11 +8497,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.24", + "zerocopy-derive 0.8.25", ] [[package]] @@ -8524,9 +8517,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 42b41390fd..e7209020ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -156,17 +156,17 @@ portable-atomic = { version = "1.11.0" } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a34fa342a4914a7a5ce962a4ddaef661ba557c5" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a34fa342a4914a7a5ce962a4ddaef661ba557c5" } -# cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "9a34fa342a4914a7a5ce962a4ddaef661ba557c5" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0e13dfd104a18f53e3d25608d59186a5e850d855" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0e13dfd104a18f53e3d25608d59186a5e850d855" } +cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "0e13dfd104a18f53e3d25608d59186a5e850d855" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } # cubecl-std = { path = "../cubecl/crates/cubecl-std", default-features = false } ### For the release. ### -cubecl = { version = "0.5.0", default-features = false } -cubecl-common = { version = "0.5.0", default-features = false } -cubecl-std = { version = "0.5.0", default-features = false } +# cubecl = { version = "0.5.0", default-features = false } +# cubecl-common = { version = "0.5.0", default-features = false } +# cubecl-std = { version = "0.5.0", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=1.1.9" } diff --git a/crates/burn-cubecl-fusion/src/shared/io.rs b/crates/burn-cubecl-fusion/src/shared/io.rs index 88534e87c4..94696f63c9 100644 --- a/crates/burn-cubecl-fusion/src/shared/io.rs +++ b/crates/burn-cubecl-fusion/src/shared/io.rs @@ -539,6 +539,7 @@ pub(crate) fn swap_dims_transform(i: &I, dims: (u32, u32)) -> } #[cube] +#[allow(clippy::clone_on_copy)] /// The index the input tensor would be at if it was contiguous. fn reshaped_index( inputs: &GlobalArgs, @@ -552,10 +553,10 @@ fn reshaped_index( #[unroll] for r in 0..rank { - let i = comptime![reverse_index(rank, r)]; + let i = reverse_index(rank, r); let arg = comptime![shape.index(i.clone())]; let shape_i = read_scalar_shape(inputs, comptime![arg.clone()]); - let ogwl = index / locals.ref_strides[comptime![i.clone()]]; + let ogwl = index / locals.ref_strides[i]; offset += ogwl % shape_i * stride_curr; @@ -566,6 +567,7 @@ fn reshaped_index( } #[cube] +#[allow(clippy::clone_on_copy)] fn reshaped_index_to_original_index( original: &Tensor>, index_reshaped: u32, @@ -576,7 +578,7 @@ fn reshaped_index_to_original_index( #[unroll] for r in 0..rank { - let i = comptime![reverse_index(rank, r)]; + let i = reverse_index(rank, r); let shape = original.shape(comptime![i.clone()]); let stride = original.stride(i); @@ -589,17 +591,25 @@ fn reshaped_index_to_original_index( offset / original.line_size() } -pub(crate) fn reverse_index>>( - rank: u32, - iter: Elem, -) -> ExpandElementTyped { - let elem = iter.into(); - let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); - let result = rank - elem - 1; - let scalar: Variable = result.into(); - let expand: ExpandElement = ExpandElement::Plain(scalar); - - expand.into() +pub(crate) fn reverse_index(_rank: u32, _iter: u32) -> u32 { + unexpanded!() +} + +pub(crate) mod reverse_index { + use super::*; + + pub(crate) fn expand( + _scope: &mut Scope, + rank: u32, + elem: ExpandElementTyped, + ) -> ExpandElementTyped { + let elem = elem.constant().map(|cons| cons.as_u32()).unwrap(); + let result = rank - elem - 1; + let scalar: Variable = result.into(); + let expand: ExpandElement = ExpandElement::Plain(scalar); + + expand.into() + } } /// Generic way to construct any [`CubePrimitive`] from an int. Used for fusion. diff --git a/crates/burn-cubecl-fusion/src/shared/kernel.rs b/crates/burn-cubecl-fusion/src/shared/kernel.rs index 60041c1a70..fa2ddf9126 100644 --- a/crates/burn-cubecl-fusion/src/shared/kernel.rs +++ b/crates/burn-cubecl-fusion/src/shared/kernel.rs @@ -3,6 +3,7 @@ use crate::shared::DYN_ELEM_ID; use super::io::*; use super::ir::*; use cubecl::prelude::*; +use cubecl::unexpanded; #[cube] /// Fuse element-wise operations at the given write position. @@ -110,7 +111,7 @@ pub fn init_locals( #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { - let reverse = comptime![reverse_index(config.rank, comptime![i.clone()])]; + let reverse = reverse_index(config.rank, i); let swap = comptime![swap_dims_transform(comptime![&reverse], dims)]; let shape = layout.tensor.shape(comptime![swap.clone()]); @@ -132,8 +133,8 @@ pub fn init_locals( #[unroll] #[allow(clippy::clone_on_copy)] for i in 0..config.rank { - let reverse = comptime![reverse_index(config.rank, comptime![i.clone()])]; - let reverse_u32_comptime = comptime!(unwrap_const_u32(reverse)); + let reverse = reverse_index(config.rank, i); + let reverse_u32_comptime = unwrap_const_u32(reverse); let arg = comptime![Arg::ScalarShape(start + reverse_u32_comptime)]; let shape = read_scalar_shape(inputs, comptime![arg.clone()]); @@ -149,8 +150,16 @@ pub fn init_locals( } } -fn unwrap_const_u32(elem: ExpandElementTyped) -> u32 { - elem.constant().map(|cons| cons.as_u32()).unwrap() +fn unwrap_const_u32(_elem: u32) -> u32 { + unexpanded!() +} + +mod unwrap_const_u32 { + use super::*; + + pub(crate) fn expand(_scope: &mut Scope, elem: ExpandElementTyped) -> u32 { + elem.constant().map(|cons| cons.as_u32()).unwrap() + } } #[cube] diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs index 1196a2f1f4..89c8cf4c2a 100644 --- a/crates/burn-cubecl/src/ops/base.rs +++ b/crates/burn-cubecl/src/ops/base.rs @@ -11,6 +11,7 @@ pub(crate) fn from_data(data: TensorData, device: &R::Device) -> } pub(crate) async fn into_data(tensor: CubeTensor) -> TensorData { + println!("Into data"); let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;