Merged
Changes from 35 of 41 commits.
Commits (all by nathanielsimard):

cf625c1  Refactor naming (Feb 21, 2025)
f896944  Build the args (Feb 21, 2025)
af32941  WIP (Feb 21, 2025)
762ae7f  WIP (Feb 24, 2025)
ab343b0  WIP: it executes (Feb 25, 2025)
c359c89  Multi execution (Feb 25, 2025)
8cd6acc  WIP (Feb 25, 2025)
416e220  WIP (Feb 26, 2025)
1a7e61d  Output offset (Feb 26, 2025)
1919dc2  Fix output offset (Feb 26, 2025)
ca549e5  WIP (Feb 26, 2025)
6fe9aa1  Better error handling around missing reference tensors (Feb 27, 2025)
e3d8017  Fix inplace (Feb 27, 2025)
a4efa79  WIP (Feb 28, 2025)
1fd3cd0  Add tuner (Feb 28, 2025)
0c1b90d  Support more instructions (Feb 28, 2025)
6aeb2f1  Clippy (Mar 3, 2025)
933652a  Lock (Mar 3, 2025)
88e202a  Add state (Mar 3, 2025)
4491690  Merge branch 'main' into feat/fuse-on-read (Mar 4, 2025)
9b42fa8  Update CubeCL (Mar 4, 2025)
2f97206  Clippy (Mar 4, 2025)
a4b6afa  Add strategy validation (Mar 4, 2025)
11cb906  Clippy (Mar 5, 2025)
04b78a8  Test CI (Mar 5, 2025)
e2fa98d  Only reduce + reduce shared plane (Mar 5, 2025)
601e094  Put ref info in locals (Mar 5, 2025)
b89bb62  Store ref metadata in arrays (Mar 5, 2025)
9aad422  Fix bug (Mar 5, 2025)
e464142  Cleanup (Mar 5, 2025)
0a28d83  Activate fuse features (Mar 5, 2025)
8c80064  Clippy (Mar 6, 2025)
11b746c  Update rev (Mar 6, 2025)
e1b0039  Remove println (Mar 6, 2025)
4264ffb  Remove println (Mar 6, 2025)
a6d8fd6  Fix precision (Mar 6, 2025)
466485d  Update bench (Mar 6, 2025)
c736943  Fixes (Mar 6, 2025)
f9b9e98  Fix fallback (Mar 6, 2025)
05dfd4a  Enable int reduce in fusion (Mar 6, 2025)
e0da405  Fix Virtual Reference Tensor Broadcasted (Mar 6, 2025)
35 changes: 19 additions & 16 deletions Cargo.lock

(Generated file; diff not rendered.)

6 changes: 4 additions & 2 deletions Cargo.toml
@@ -160,11 +160,13 @@ portable-atomic = { version = "1.11.0" }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e4fa42bebc3348b8912854298f3ec8e4d2d23529" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e4fa42bebc3348b8912854298f3ec8e4d2d23529" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
+cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "32367c1fb6898beea79e175f27173b26ec8e5a69" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
+# cubecl-std = { path = "../cubecl/crates/cubecl-std", default-features = false }
 ### For the release. ###
 # cubecl = { version = "0.4.0", default-features = false }
 # cubecl-common = { version = "0.4.0", default-features = false }
16 changes: 10 additions & 6 deletions backend-comparison/benches/reduce.rs
@@ -39,7 +39,11 @@ impl<B: Backend> Benchmark for ReduceBenchmark<B> {
                 self.tensor.clone().argmin(axis);
             }
             Instruction::SumDim(axis) => {
-                self.tensor.clone().sum_dim(axis);
+                let tensor = self.tensor.clone() + 5;
+                let tensor = tensor.log();
+                let tensor = tensor.tanh();
+                let tensor = tensor * 3;
+                tensor.sum_dim(axis);

Review comment on lines +47 to +51 (Member):
Debug changes? Curious what issue you were trying to track/fix by introducing the additional ops before the reduce 😄

             }
             Instruction::Sum => {
                 self.tensor.clone().sum();
@@ -74,18 +78,18 @@ fn bench<B: Backend>(
     let mut benchmarks = Vec::new();

     for axis in 0..3 {
-        benchmarks.push(ReduceBenchmark::<B>::new(
-            Instruction::ArgMin(axis),
-            device.clone(),
-        ));
+        // benchmarks.push(ReduceBenchmark::<B>::new(
+        //     Instruction::ArgMin(axis),
+        //     device.clone(),
+        // ));

Review comment (Member): Uncomment?

         benchmarks.push(ReduceBenchmark::<B>::new(
             Instruction::SumDim(axis),
             device.clone(),
         ));
     }

-    benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));
+    // benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));

Review comment (Member): Uncomment?

     save::<B>(
         benchmarks.into_iter().map(run_benchmark).collect(),
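The chain added above is a fuse-on-read workload: several cheap elementwise ops feeding a single reduction. A minimal sketch of it as a standalone function, using only the tensor ops that appear in the diff (the helper name and rank-3 signature are assumptions for illustration, and the fusion behavior described is inferred from the branch name feat/fuse-on-read):

use burn::tensor::{backend::Backend, Tensor};

// Hypothetical helper mirroring the benchmark body above. Each step is
// elementwise and therefore fusable; with fuse-on-read, a fusion backend
// can apply them while the reduction reads its input instead of
// materializing four intermediate tensors.
fn fused_sum_dim<B: Backend>(input: Tensor<B, 3>, axis: usize) -> Tensor<B, 3> {
    let t = input + 5;   // elementwise add
    let t = t.log();     // elementwise log
    let t = t.tanh();    // elementwise tanh
    let t = t * 3;       // elementwise mul
    t.sum_dim(axis)      // reduction; keeps the reduced dim with size 1
}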
2 changes: 1 addition & 1 deletion backend-comparison/src/lib.rs
@@ -58,7 +58,7 @@ macro_rules! bench_on_backend {
     };
     ($fn_name:ident) => {
         use std::env;
-        backend_comparison::init_log().unwrap();
+        // backend_comparison::init_log().unwrap();

         let args: Vec<String> = env::args().collect();
         let url = backend_comparison::get_sharing_url(&args);
2 changes: 1 addition & 1 deletion crates/burn-cubecl-fusion/Cargo.toml
@@ -24,7 +24,7 @@ burn-ir = { path = "../burn-ir", version = "0.17.0", default-features = false }
 burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [
     "cubecl",
 ] }
-cubecl = { workspace = true, features = ["linalg"] }
+cubecl = { workspace = true, features = ["linalg", "reduce"] }

 half = { workspace = true }
 serde = { workspace = true }
4 changes: 4 additions & 0 deletions crates/burn-cubecl-fusion/src/base.rs
@@ -1,5 +1,7 @@
 use std::marker::PhantomData;

+use crate::reduce::optimization::{ReduceOptimization, ReduceOptimizationState};
+
 use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
 use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};

@@ -18,6 +20,7 @@ pub enum CubeOptimization<R: Runtime> {
     ElementWise(ElemwiseOptimization<R>),
     /// Matrix multiplication optimization.
     Matmul(MatmulOptimization<R>),
+    Reduce(ReduceOptimization<R>),
 }

 /// Fusion optimization state type for cubecl.
@@ -29,6 +32,7 @@ pub enum CubeOptimizationState {
     ElementWise(ElemwiseOptimizationState),
     /// Matrix multiplication optimization state.
     Matmul(MatmulOptimizationState),
+    Reduce(ReduceOptimizationState),
 }

 pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
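The two Reduce variants slot into the existing pattern: each optimization has a runtime form and a serializable state twin. A hedged sketch of dispatch over the extended enum (the execute method and ExecutionContext type are placeholder names for illustration, not the crate's confirmed API):

// Illustrative only: `execute` and `ExecutionContext` are placeholders;
// the real entry points are not shown in this diff.
fn run<R: Runtime>(opt: &mut CubeOptimization<R>, ctx: &mut ExecutionContext) {
    match opt {
        CubeOptimization::ElementWise(op) => op.execute(ctx),
        CubeOptimization::Matmul(op) => op.execute(ctx),
        CubeOptimization::Reduce(op) => op.execute(ctx), // new in this PR
    }
}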
9 changes: 5 additions & 4 deletions crates/burn-cubecl-fusion/src/elemwise/builder.rs
@@ -2,15 +2,15 @@ use burn_fusion::OptimizationBuilder;
 use cubecl::Runtime;

 use crate::{
-    on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision, settings::FuseSettings},
+    shared::{builder::FuseBuilder, ir::ElemwisePrecision, settings::FuseSettings},
     CubeOptimization,
 };

 use super::optimization::ElemwiseOptimization;

 /// Fused element wise operations that are normally memory bound.
 pub struct ElementWiseBuilder<R: Runtime> {
-    builder: FuseOnWriteBuilder,
+    builder: FuseBuilder,
     device: R::Device,
 }

@@ -21,13 +21,14 @@ impl<R: Runtime> ElementWiseBuilder<R> {
         let max_bindings = props.hardware_properties().max_bindings;

         Self {
-            builder: FuseOnWriteBuilder::new(
+            builder: FuseBuilder::new(
                 max_bindings,
                 bool_precision,
                 FuseSettings {
                     broadcast: true,
                     output_shape_updates: true,
-                    inplace: true,
+                    inplace: false,
+                    vectorization: true,
                 },
             ),
             device,
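From the hunk above, the fuse settings now carry at least four flags; element-wise fusion turns in-place reuse off and vectorization on. An inferred sketch of the struct (the real definition lives in crate::shared::settings and may hold more fields; the comments are interpretive, not from the source):

// Inferred from the fields set above; illustrative only.
pub struct FuseSettings {
    pub broadcast: bool,            // allow fusing operands that broadcast
    pub output_shape_updates: bool, // let fused outputs update their shapes
    pub inplace: bool,              // reuse input buffers for outputs
    pub vectorization: bool,        // enable vectorized loads/stores
}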
41 changes: 19 additions & 22 deletions crates/burn-cubecl-fusion/src/elemwise/optimization.rs
@@ -1,19 +1,22 @@
-use crate::on_write::ir::GlobalArgs;
-use crate::on_write::{io::global_length, kernel::fuse_on_write};
+use crate::shared::io::ref_len;
+use crate::shared::ir::{GlobalArgs, RefLayout};
+use crate::shared::kernel::fuse_on_write;
+use crate::shared::kernel::init_locals;
+use crate::shared::trace::Vectorization;
 use crate::CubeFusionHandle;
 use burn_fusion::stream::Context;
 use cubecl::{calculate_cube_count_elemwise, client::ComputeClient, prelude::*, CubeDim};
 use serde::{Deserialize, Serialize};

-use crate::on_write::{
+use crate::shared::{
     ir::{Arg, ElemwiseConfig, GlobalArgsLaunch},
-    trace::{FuseOnWriteTrace, TraceRunner},
+    trace::{FuseTrace, TraceRunner},
 };

 #[derive(new)]
 /// Fuse element wise operations into a single kernel.
 pub struct ElemwiseOptimization<R: Runtime> {
-    trace: FuseOnWriteTrace,
+    trace: FuseTrace,
     client: ComputeClient<R::Server, R::Channel>,
     device: R::Device,
     len: usize,
@@ -22,7 +25,7 @@ pub struct ElemwiseOptimization<R: Runtime> {
 #[derive(Serialize, Deserialize)]
 /// State for the [elemwise optimization](ElemwiseOptimization).
 pub struct ElemwiseOptimizationState {
-    trace: FuseOnWriteTrace,
+    trace: FuseTrace,
     len: usize,
 }

@@ -60,6 +63,7 @@ impl<R: Runtime> ElemwiseOptimization<R> {

 pub struct ElemwiseRunner;

+impl<R: Runtime> Vectorization<R> for ElemwiseRunner {}
 impl<R: Runtime> TraceRunner<R> for ElemwiseRunner {
     type Error = (); // No error possible

@@ -70,17 +74,13 @@
         outputs: GlobalArgsLaunch<'a, R>,
         config: &'a ElemwiseConfig,
     ) -> Result<(), Self::Error> {
-        let arg = match config.ref_layout {
-            Arg::Input(index, _, _) => inputs.tensors.values.get(index as usize),
-            Arg::Output(index, _, _) => outputs.tensors.values.get(index as usize),
-            _ => panic!("Invalid value"),
-        };
-        let shape = match arg {
-            Some(val) => match &val.tensor {
-                TensorArg::Handle { handle, .. } => handle.shape,
-                TensorArg::Alias { .. } => panic!("Can't be an alias, got {val:?}"),
+        let shape = match &config.ref_layout {
+            RefLayout::Concrete(arg) => match arg {
+                Arg::Input(..) => inputs.shape_ref(&config.ref_layout),
+                Arg::Output(..) => outputs.shape_ref(&config.ref_layout),
+                _ => panic!("Invalid concrete ref layout"),
             },
-            None => panic!("Invalid argument"),
+            RefLayout::Virtual(_) => inputs.shape_ref(&config.ref_layout),
         };
         let total_elem = shape.iter().product::<usize>() / config.width as usize;
         let cube_dim = CubeDim::default();
@@ -112,13 +112,10 @@ fn elemwise_fuse(
     let args = comptime![Sequence::<Arg>::new()];
     let pos = ABSOLUTE_POS;

-    let length = match comptime![config.ref_layout.clone()] {
-        Arg::Input(index, _, _) => global_length(inputs, index),
-        Arg::Output(index, _, _) => global_length(outputs, index),
-        _ => comptime![panic!("Invalid ref layout.")],
-    };
+    let mut locals = init_locals(inputs, outputs, config);
+    let length = ref_len(inputs, outputs, &locals, config);

     if pos < length {
-        fuse_on_write::<f32>(inputs, outputs, pos, values, args, config)
+        fuse_on_write::<f32>(inputs, outputs, &mut locals, pos, values, args, config)
     }
 }
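The rewritten shape lookup implies the new reference-layout abstraction: instead of pointing directly at a tensor handle, the config's ref layout is either a concrete input/output argument or a virtual layout resolved from the inputs (e.g. for broadcasting, per the "Fix Virtual Reference Tensor Broadcasted" commit). A rough reconstruction inferred from the match arms above (the actual definition lives in crate::shared::ir; the Virtual payload type is never named in this diff, so a placeholder is used):

// Inferred shape of RefLayout; illustrative, not the crate's definition.
pub enum RefLayout {
    // A real tensor provides the reference shape: an input or output `Arg`.
    Concrete(Arg),
    // A computed reference layout with no backing tensor; `VirtualLayout`
    // is a placeholder name.
    Virtual(VirtualLayout),
}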
3 changes: 2 additions & 1 deletion crates/burn-cubecl-fusion/src/lib.rs
@@ -3,10 +3,11 @@ extern crate derive_new;

 pub mod elemwise;
 pub mod matmul;
+pub mod reduce;

 mod base;

-pub(crate) mod on_write;
+pub(crate) mod shared;
 pub(crate) mod tune;

 pub use base::*;