
Commit 1f92ec1

Fix/fusion compilation speed (#3155)
1 parent 7e547c4 commit 1f92ec1

File tree

23 files changed: +338 -581 lines changed


Cargo.lock

Lines changed: 16 additions & 16 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 3 additions & 3 deletions
@@ -156,9 +156,9 @@ portable-atomic = { version = "1.11.0" }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "150829d0876e5ced8e937f18abbc7e3c757e11c7" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "150829d0876e5ced8e937f18abbc7e3c757e11c7" }
-cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "150829d0876e5ced8e937f18abbc7e3c757e11c7" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "48939d2ca47473f0f30526962c3b6fb8c9b558e0" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "48939d2ca47473f0f30526962c3b6fb8c9b558e0" }
+cubecl-std = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "48939d2ca47473f0f30526962c3b6fb8c9b558e0" }
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }

crates/burn-cubecl-fusion/src/base.rs

Lines changed: 6 additions & 2 deletions
@@ -1,10 +1,10 @@
-use std::marker::PhantomData;
-
 use crate::reduce::optimization::{ReduceOptimization, ReduceOptimizationState};
+use std::marker::PhantomData;
 
 use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
 use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};
 
+use burn_fusion::stream::Context;
 use burn_tensor::DType;
 use cubecl::client::ComputeClient;
 use cubecl::ir::Elem;
@@ -35,6 +35,10 @@ pub enum CubeOptimizationState {
     Reduce(ReduceOptimizationState),
 }
 
+pub trait FallbackOperation<R: Runtime>: Send + Sync {
+    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>);
+}
+
 pub(crate) fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
     let mut strides = vec![0; shape.len()];

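The `FallbackOperation` trait introduced above replaces the matmul-specific `MatmulFallbackFn`: a fallback now receives the whole fusion `Context` and is responsible for reading its own input handles and registering its output. Below is a minimal sketch of an implementor; the struct name, the `run_reference_matmul` helper, and the import paths are illustrative assumptions, not part of this commit.

```rust
use burn_cubecl_fusion::{CubeFusionHandle, FallbackOperation}; // illustrative paths
use burn_fusion::stream::Context;
use burn_ir::{BinaryOpIr, TensorStatus};
use cubecl::Runtime;

// Hypothetical fallback that re-runs a plain (non-fused) matmul.
struct MatmulFallback {
    op: BinaryOpIr, // the matmul described by the fused trace
}

impl<R: Runtime> FallbackOperation<R> for MatmulFallback {
    fn run(&self, context: &mut Context<'_, CubeFusionHandle<R>>) {
        // Read the input handles from the fusion context.
        let lhs = context.handles.get_handle(&self.op.lhs.id, &TensorStatus::ReadOnly);
        let rhs = context.handles.get_handle(&self.op.rhs.id, &TensorStatus::ReadOnly);

        // Run the reference (non-fused) matmul and register its output so that
        // later operations and autotune checks can read it back.
        let out = run_reference_matmul::<R>(lhs, rhs); // hypothetical helper
        context.handles.register_handle(self.op.out.id, out);
    }
}
```
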
crates/burn-cubecl-fusion/src/matmul/builder.rs

Lines changed: 1 addition & 11 deletions
@@ -1,6 +1,3 @@
-use std::sync::Arc;
-
-use super::MatmulFallbackFn;
 use burn_fusion::{OptimizationBuilder, OptimizationStatus};
 use burn_ir::{FloatOperationIr, OperationIr};
 use cubecl::Runtime;
@@ -19,15 +16,10 @@ pub struct MatmulBuilder<R: Runtime> {
     builder_fallback: FuseOptimizationBuilder,
     device: R::Device,
     matmul: Option<FusedMatmul>,
-    fallback: Arc<dyn MatmulFallbackFn<R>>,
 }
 
 impl<R: Runtime> MatmulBuilder<R> {
-    pub fn new(
-        device: R::Device,
-        bool_precision: FusePrecision,
-        fallback: Arc<dyn MatmulFallbackFn<R>>,
-    ) -> Self {
+    pub fn new(device: R::Device, bool_precision: FusePrecision) -> Self {
         let client = R::client(&device);
         let props = client.properties();
         let max_bindings = props.hardware_properties().max_bindings;
@@ -43,7 +35,6 @@ impl<R: Runtime> MatmulBuilder<R> {
             builder_fallback: FuseOptimizationBuilder::new(max_bindings, bool_precision, settings),
             device,
             matmul: None,
-            fallback,
         }
     }
 }
@@ -94,7 +85,6 @@ impl<R: Runtime> OptimizationBuilder<CubeOptimization<R>> for MatmulBuilder<R> {
             self.device.clone(),
             self.len(),
             self.matmul.as_ref().unwrap().clone(),
-            self.fallback.clone(),
         );
 
         CubeOptimization::Matmul(matmul)

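With the fallback argument gone from the builder, constructing a `MatmulBuilder` only needs the device and the boolean precision; the fallback is supplied later, when the optimization executes. A sketch of the new call, assuming `device` and `bool_precision` are already in scope and using `WgpuRuntime` purely as an example runtime:

```rust
// `device` and `bool_precision` are assumed to be provided by the caller;
// `WgpuRuntime` is only an example runtime type.
let builder = MatmulBuilder::<WgpuRuntime>::new(device, bool_precision);
```
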
crates/burn-cubecl-fusion/src/matmul/mod.rs

Lines changed: 0 additions & 2 deletions
@@ -4,5 +4,3 @@ pub mod optimization;
 pub(crate) mod args;
 pub(crate) mod spec;
 pub(crate) mod tune;
-
-pub use optimization::MatmulFallbackFn;

crates/burn-cubecl-fusion/src/matmul/optimization.rs

Lines changed: 25 additions & 47 deletions
@@ -1,7 +1,7 @@
 use std::any::TypeId;
-use std::sync::Arc;
 
 use crate::CubeFusionHandle;
+use crate::FallbackOperation;
 use crate::elemwise::optimization::ElemwiseRunner;
 use crate::shared::ir::FusePrecision;
 use crate::shared::ir::RefLayout;
@@ -10,7 +10,7 @@ use crate::shared::trace::TuneOutput;
 use crate::shared::trace::Vectorization;
 
 use burn_fusion::stream::Context;
-use burn_ir::{BinaryOpIr, TensorStatus};
+use burn_ir::BinaryOpIr;
 use cubecl::linalg::matmul::components;
 use cubecl::linalg::matmul::components::MatmulPrecision;
 use cubecl::linalg::matmul::components::MatmulProblem;
@@ -44,15 +44,7 @@ pub struct MatmulOptimization<R: Runtime> {
     pub(crate) len: usize,
     pub(crate) matmul_simple: FusedMatmul,
     pub(crate) matmul_double_buffering: FusedMatmul,
-    fallback: Arc<dyn MatmulFallbackFn<R>>,
-}
-
-pub trait MatmulFallbackFn<R: Runtime>: Send + Sync {
-    fn run(
-        &self,
-        lhs: (CubeFusionHandle<R>, &[usize]),
-        rhs: (CubeFusionHandle<R>, &[usize]),
-    ) -> CubeFusionHandle<R>;
+    fallback: Option<Box<dyn FallbackOperation<R>>>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -73,7 +65,6 @@ impl<R: Runtime> MatmulOptimization<R> {
         device: R::Device,
         len: usize,
         matmul: FusedMatmul,
-        fallback: Arc<dyn MatmulFallbackFn<R>>,
     ) -> Self {
         let mut matmul_simple = matmul.clone();
         let mut matmul_double_buffering = matmul;
@@ -89,11 +80,18 @@ impl<R: Runtime> MatmulOptimization<R> {
             len,
             matmul_simple,
             matmul_double_buffering,
-            fallback,
+            fallback: None,
         }
     }
     /// Execute the optimization.
-    pub fn execute<BT: CubeElement>(&mut self, context: &mut Context<'_, CubeFusionHandle<R>>) {
+    pub fn execute<BT: CubeElement>(
+        &mut self,
+        context: &mut Context<'_, CubeFusionHandle<R>>,
+        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
+    ) {
+        // The index of the fallback matmul is always 0.
+        self.fallback = Some(fallback(0));
+
         #[cfg(feature = "autotune")]
         fused_matmul_autotune::<R, BT>(self, context);
 
@@ -109,11 +107,7 @@ impl<R: Runtime> MatmulOptimization<R> {
     }
 
     /// Create an optimization from its [state](MatmulOptimizationState).
-    pub fn from_state(
-        device: &R::Device,
-        state: MatmulOptimizationState,
-        fallback: Arc<dyn MatmulFallbackFn<R>>,
-    ) -> Self {
+    pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self {
         Self {
             trace: state.trace,
             trace_fallback: state.trace_fallback,
@@ -122,7 +116,7 @@ impl<R: Runtime> MatmulOptimization<R> {
             device: device.clone(),
             matmul_simple: state.matmul_simple.clone(),
             matmul_double_buffering: state.matmul_double_buffering.clone(),
-            fallback,
+            fallback: None,
         }
     }
 
@@ -170,31 +164,11 @@ impl<R: Runtime> MatmulOptimization<R> {
         &self,
         context: &mut Context<'_, CubeFusionHandle<R>>,
     ) -> TuneOutput<R> {
-        let (out_tensor, out_desc) = {
-            let lhs = context
-                .tensors
-                .get(&self.matmul_simple.op.lhs.id)
-                .unwrap()
-                .clone();
-            let rhs = context
-                .tensors
-                .get(&self.matmul_simple.op.rhs.id)
-                .unwrap()
-                .clone();
-            let out = context
-                .tensors
-                .get(&self.matmul_simple.op.out.id)
-                .unwrap()
-                .clone();
-
-            let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly);
-            let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly);
-            let out_handle = self
-                .fallback
-                .run((lhs_handle, &lhs.shape), (rhs_handle, &rhs.shape));
-
-            (out_handle, out)
-        };
+        self.fallback
+            .as_ref()
+            .expect("A fallback operation should be available")
+            .run(context);
+
         #[cfg(feature = "autotune-checks")]
         let mut output = TuneOutput::Checked {
             handles: Default::default(),
@@ -204,12 +178,16 @@ impl<R: Runtime> MatmulOptimization<R> {
 
         #[cfg(feature = "autotune-checks")]
         if let TuneOutput::Checked { handles } = &mut output {
+            let out_desc = context.tensors.get(&self.matmul_simple.op.out.id).unwrap();
+            let handle_out = context
+                .handles
+                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
+
             handles.insert(
                 self.matmul_simple.op.out.id,
-                (out_desc.shape.clone(), out_tensor.clone()),
+                (out_desc.shape.clone(), handle_out.clone()),
             );
         }
-        context.handles.register_handle(out_desc.id, out_tensor);
 
         let output_write = self
             .trace_fallback

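The reworked `execute` defers building the fallback until the optimization actually runs: the caller passes a closure that receives the index of the fused operation needing a fallback (always 0 for matmul) and returns a boxed `FallbackOperation`. A hedged sketch of a call site follows; `make_fallback` is a hypothetical factory and the import paths are illustrative:

```rust
use burn_cubecl_fusion::matmul::optimization::MatmulOptimization; // illustrative paths
use burn_cubecl_fusion::CubeFusionHandle;
use burn_fusion::stream::Context;
use cubecl::prelude::*; // Runtime, CubeElement

fn run_matmul_optimization<R: Runtime, BT: CubeElement>(
    optimization: &mut MatmulOptimization<R>,
    context: &mut Context<'_, CubeFusionHandle<R>>,
) {
    optimization.execute::<BT>(context, |index| {
        // `index` identifies the fused operation that needs a fallback;
        // for the matmul optimization it is always 0.
        make_fallback::<R>(index) // hypothetical factory returning Box<dyn FallbackOperation<R>>
    });
}
```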