@@ -1,7 +1,7 @@
 use std::any::TypeId;
-use std::sync::Arc;
 
 use crate::CubeFusionHandle;
+use crate::FallbackOperation;
 use crate::elemwise::optimization::ElemwiseRunner;
 use crate::shared::ir::FusePrecision;
 use crate::shared::ir::RefLayout;
@@ -10,7 +10,7 @@ use crate::shared::trace::TuneOutput;
 use crate::shared::trace::Vectorization;
 
 use burn_fusion::stream::Context;
-use burn_ir::{BinaryOpIr, TensorStatus};
+use burn_ir::BinaryOpIr;
 use cubecl::linalg::matmul::components;
 use cubecl::linalg::matmul::components::MatmulPrecision;
 use cubecl::linalg::matmul::components::MatmulProblem;
@@ -44,15 +44,7 @@ pub struct MatmulOptimization<R: Runtime> {
     pub(crate) len: usize,
     pub(crate) matmul_simple: FusedMatmul,
     pub(crate) matmul_double_buffering: FusedMatmul,
-    fallback: Arc<dyn MatmulFallbackFn<R>>,
-}
-
-pub trait MatmulFallbackFn<R: Runtime>: Send + Sync {
-    fn run(
-        &self,
-        lhs: (CubeFusionHandle<R>, &[usize]),
-        rhs: (CubeFusionHandle<R>, &[usize]),
-    ) -> CubeFusionHandle<R>;
+    fallback: Option<Box<dyn FallbackOperation<R>>>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -73,7 +65,6 @@ impl<R: Runtime> MatmulOptimization<R> {
         device: R::Device,
         len: usize,
         matmul: FusedMatmul,
-        fallback: Arc<dyn MatmulFallbackFn<R>>,
     ) -> Self {
         let mut matmul_simple = matmul.clone();
         let mut matmul_double_buffering = matmul;
@@ -89,11 +80,18 @@ impl<R: Runtime> MatmulOptimization<R> {
             len,
             matmul_simple,
             matmul_double_buffering,
-            fallback,
+            fallback: None,
         }
     }
     /// Execute the optimization.
-    pub fn execute<BT: CubeElement>(&mut self, context: &mut Context<'_, CubeFusionHandle<R>>) {
+    pub fn execute<BT: CubeElement>(
+        &mut self,
+        context: &mut Context<'_, CubeFusionHandle<R>>,
+        fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
+    ) {
+        // The index of the fallback matmul is always 0.
+        self.fallback = Some(fallback(0));
+
         #[cfg(feature = "autotune")]
         fused_matmul_autotune::<R, BT>(self, context);
 
@@ -109,11 +107,7 @@ impl<R: Runtime> MatmulOptimization<R> {
     }
 
     /// Create an optimization from its [state](MatmulOptimizationState).
-    pub fn from_state(
-        device: &R::Device,
-        state: MatmulOptimizationState,
-        fallback: Arc<dyn MatmulFallbackFn<R>>,
-    ) -> Self {
+    pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self {
         Self {
             trace: state.trace,
             trace_fallback: state.trace_fallback,
@@ -122,7 +116,7 @@ impl<R: Runtime> MatmulOptimization<R> {
             device: device.clone(),
             matmul_simple: state.matmul_simple.clone(),
             matmul_double_buffering: state.matmul_double_buffering.clone(),
-            fallback,
+            fallback: None,
         }
     }
 
@@ -170,31 +164,11 @@ impl<R: Runtime> MatmulOptimization<R> {
         &self,
         context: &mut Context<'_, CubeFusionHandle<R>>,
     ) -> TuneOutput<R> {
-        let (out_tensor, out_desc) = {
-            let lhs = context
-                .tensors
-                .get(&self.matmul_simple.op.lhs.id)
-                .unwrap()
-                .clone();
-            let rhs = context
-                .tensors
-                .get(&self.matmul_simple.op.rhs.id)
-                .unwrap()
-                .clone();
-            let out = context
-                .tensors
-                .get(&self.matmul_simple.op.out.id)
-                .unwrap()
-                .clone();
-
-            let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly);
-            let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly);
-            let out_handle = self
-                .fallback
-                .run((lhs_handle, &lhs.shape), (rhs_handle, &rhs.shape));
-
-            (out_handle, out)
-        };
+        self.fallback
+            .as_ref()
+            .expect("A fallback operation should be available")
+            .run(context);
+
         #[cfg(feature = "autotune-checks")]
         let mut output = TuneOutput::Checked {
             handles: Default::default(),
@@ -204,12 +178,16 @@ impl<R: Runtime> MatmulOptimization<R> {
 
         #[cfg(feature = "autotune-checks")]
         if let TuneOutput::Checked { handles } = &mut output {
+            let out_desc = context.tensors.get(&self.matmul_simple.op.out.id).unwrap();
+            let handle_out = context
+                .handles
+                .get_handle(&out_desc.id, &burn_ir::TensorStatus::ReadOnly);
+
             handles.insert(
                 self.matmul_simple.op.out.id,
-                (out_desc.shape.clone(), out_tensor.clone()),
+                (out_desc.shape.clone(), handle_out.clone()),
             );
         }
-        context.handles.register_handle(out_desc.id, out_tensor);
 
         let output_write = self
             .trace_fallback
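
For context, the net effect on callers: the fallback is no longer stored at construction time behind an Arc<dyn MatmulFallbackFn<R>>; instead, execute receives a closure that builds a FallbackOperation for a given operation index. Below is a rough caller-side sketch of the new flow; the helper name, the use paths, and the make_fallback source are assumptions for illustration, only the execute signature comes from the diff above.

    use burn_fusion::stream::Context;

    // Assumed import paths; in the real crate these items may live elsewhere.
    use crate::CubeFusionHandle;
    use crate::FallbackOperation;
    use crate::matmul::optimization::MatmulOptimization;

    // Hypothetical helper, not part of this commit: it only shows how the
    // reworked `execute` is driven once the caller can build fallbacks lazily.
    fn autotune_fused_matmul<R: cubecl::Runtime, BT: CubeElement>(
        optimization: &mut MatmulOptimization<R>,
        context: &mut Context<'_, CubeFusionHandle<R>>,
        make_fallback: impl FnOnce(usize) -> Box<dyn FallbackOperation<R>>,
    ) {
        // The closure receives the index of the operation that needs a fallback;
        // for this optimization it is always 0 (the fused matmul itself).
        optimization.execute::<BT>(context, make_fallback);
    }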