Skip to content

Commit b440d20

Browse files
Remove some unsafe code (#3385)
1 parent 4f8f7a4 commit b440d20

File tree

7 files changed

+32
-35
lines changed

7 files changed

+32
-35
lines changed

crates/burn-cubecl/src/fusion.rs

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use burn_ir::{BackendIr, TensorHandle};
1616
use burn_tensor::{DType, Shape};
1717
use core::marker::PhantomData;
1818
use half::{bf16, f16};
19+
use std::sync::Arc;
1920

2021
impl<R, BT> burn_fusion::Optimization<FusionCubeRuntime<R, BT>> for CubeOptimization<R>
2122
where
@@ -34,11 +35,11 @@ where
3435
Self::ElementWise(op) => op.execute::<BT>(context),
3536
Self::Matmul(op) => op.execute::<BT>(context, |index| {
3637
let operation = execution.operation_within_optimization(index);
37-
Box::new(FallbackOperationUnsafe::new(operation))
38+
Box::new(FallbackOperationWrapper::new(operation))
3839
}),
3940
Self::Reduce(op) => op.execute::<BT>(context, |index| {
4041
let operation = execution.operation_within_optimization(index);
41-
Box::new(FallbackOperationUnsafe::new(operation))
42+
Box::new(FallbackOperationWrapper::new(operation))
4243
}),
4344
}
4445
}
@@ -62,33 +63,21 @@ where
6263
}
6364
}
6465

65-
/// This is only safe because we know the fallback must be executed before the cubecl context is dropped.
66-
///
67-
/// The safer alternatives would require fused operation to be cloned, so that it could
68-
/// escape the lifetime of the context's execution, which doesn't make sense since
69-
/// its only goal is to modify the context it operates on.
70-
struct FallbackOperationUnsafe<O> {
71-
operation: *const O,
66+
struct FallbackOperationWrapper<O: Clone> {
67+
operation: O,
7268
}
7369

74-
unsafe impl<O> Send for FallbackOperationUnsafe<O> {}
75-
unsafe impl<O> Sync for FallbackOperationUnsafe<O> {}
76-
77-
impl<O> FallbackOperationUnsafe<O> {
78-
fn new(op: &O) -> Self {
79-
let ptr = core::ptr::from_ref(op);
80-
81-
Self { operation: ptr }
70+
impl<O: Clone> FallbackOperationWrapper<O> {
71+
fn new(op: O) -> Self {
72+
Self { operation: op }
8273
}
8374
}
8475

8576
impl<R: CubeRuntime, BT: BoolElement> FallbackOperation<R>
86-
for FallbackOperationUnsafe<Box<dyn Operation<FusionCubeRuntime<R, BT>>>>
77+
for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R, BT>>>>
8778
{
8879
fn run(&self, context: &mut burn_fusion::stream::Context<'_, CubeFusionHandle<R>>) {
89-
unsafe {
90-
self.operation.as_ref().unwrap().execute(context.handles);
91-
}
80+
self.operation.as_ref().execute(context.handles);
9281
}
9382
}
9483

crates/burn-fusion/src/client/mutex.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ where
4343
{
4444
self.server
4545
.lock()
46-
.register(streams, repr, Box::new(operation))
46+
.register(streams, repr, Arc::new(operation))
4747
}
4848

4949
fn drain(&self) {

crates/burn-fusion/src/server.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::{
24
FusionBackend, FusionRuntime,
35
stream::{MultiStream, OperationStreams, StreamId, execution::Operation},
@@ -25,7 +27,7 @@ where
2527
&mut self,
2628
streams: OperationStreams,
2729
repr: OperationIr,
28-
operation: Box<dyn Operation<R>>,
30+
operation: Arc<dyn Operation<R>>,
2931
) {
3032
self.streams
3133
.register(streams, repr, operation, &mut self.handles)

crates/burn-fusion/src/stream/execution/ordering.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use super::Operation;
88

99
/// Manage the execution of potentially multiple optimizations and operations out of order.
1010
pub struct OrderedExecution<R: FusionRuntime> {
11-
operations: Vec<Box<dyn Operation<R>>>,
11+
operations: Vec<Arc<dyn Operation<R>>>,
1212
num_executed: usize,
1313
ordering: Option<Arc<Vec<usize>>>,
1414
}
@@ -18,25 +18,25 @@ impl<R: FusionRuntime> OrderedExecution<R> {
1818
///
1919
/// This is useful to implement fallback for optimizations.
2020
#[allow(clippy::borrowed_box)]
21-
pub fn operation_within_optimization(&self, index: usize) -> &Box<dyn Operation<R>> {
21+
pub fn operation_within_optimization(&self, index: usize) -> Arc<dyn Operation<R>> {
2222
match &self.ordering {
2323
Some(val) => {
2424
let index = val[index];
25-
&self.operations[index]
25+
self.operations[index].clone()
2626
}
2727
None => panic!("No ordering provided"),
2828
}
2929
}
3030

31-
pub(crate) fn new(operations: Vec<Box<dyn Operation<R>>>) -> Self {
31+
pub(crate) fn new(operations: Vec<Arc<dyn Operation<R>>>) -> Self {
3232
Self {
3333
operations,
3434
num_executed: 0,
3535
ordering: None,
3636
}
3737
}
3838

39-
pub(crate) fn finish(mut self) -> (Vec<Box<dyn Operation<R>>>, usize) {
39+
pub(crate) fn finish(mut self) -> (Vec<Arc<dyn Operation<R>>>, usize) {
4040
self.operations.drain(0..self.num_executed);
4141
(self.operations, self.num_executed)
4242
}

crates/burn-fusion/src/stream/multi.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use burn_ir::{HandleContainer, OperationIr, TensorId, TensorIr, TensorStatus};
24
use hashbrown::{HashMap, HashSet};
35

@@ -47,7 +49,7 @@ impl<R: FusionRuntime> MultiStream<R> {
4749
&mut self,
4850
streams: OperationStreams,
4951
mut repr: OperationIr,
50-
operation: Box<dyn Operation<R>>,
52+
operation: Arc<dyn Operation<R>>,
5153
handles: &mut HandleContainer<R::FusionHandle>,
5254
) {
5355
let id = self.resolve_streams(&streams, handles, &mut repr);
@@ -133,7 +135,7 @@ impl<R: FusionRuntime> MultiStream<R> {
133135
id: StreamId,
134136
repr: OperationIr,
135137
streams: &OperationStreams,
136-
operation: Box<dyn Operation<R>>,
138+
operation: Arc<dyn Operation<R>>,
137139
handles: &mut HandleContainer<R::FusionHandle>,
138140
) -> usize {
139141
let stream = match self.streams.get_mut(&id) {
@@ -355,7 +357,7 @@ impl<R: FusionRuntime> MultiStream<R> {
355357
current,
356358
};
357359

358-
let op = Box::new(DropOp { id: tensor.id });
360+
let op = Arc::new(DropOp { id: tensor.id });
359361
self.register(streams, OperationIr::Drop(tensor), op, handles);
360362
}
361363
}

crates/burn-fusion/src/stream/queue/base.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use crate::FusionRuntime;
24
use crate::stream::{OperationConverter, OperationStreams, RelativeOps, execution::Operation};
35
use burn_common::id::StreamId;
@@ -17,7 +19,7 @@ pub struct OperationQueue<R: FusionRuntime> {
1719
/// determine which operations can be fused.
1820
pub(crate) relative: Vec<OperationIr>,
1921
pub(crate) converter: OperationConverter,
20-
pub(crate) operations: Vec<Box<dyn Operation<R>>>,
22+
pub(crate) operations: Vec<Arc<dyn Operation<R>>>,
2123
pub(crate) variables: HashMap<TensorId, (StreamId, TensorStatus)>,
2224
}
2325

@@ -47,7 +49,7 @@ impl<R: FusionRuntime> OperationQueue<R> {
4749
pub fn add(
4850
&mut self,
4951
global: OperationIr,
50-
operation: Box<dyn Operation<R>>,
52+
operation: Arc<dyn Operation<R>>,
5153
streams: &OperationStreams,
5254
current: StreamId,
5355
) {

crates/burn-fusion/src/stream/queue/execution.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use burn_ir::{HandleContainer, TensorStatus};
24

35
use crate::{
@@ -84,8 +86,8 @@ impl<'a, R: FusionRuntime> QueueExecution<'a, R> {
8486
optimization: &mut BlockOptimization<R::Optimization>,
8587
converter: &'a mut OperationConverter,
8688
handles: &'a mut HandleContainer<R::FusionHandle>,
87-
operations: Vec<Box<dyn Operation<R>>>,
88-
) -> (Vec<Box<dyn Operation<R>>>, usize) {
89+
operations: Vec<Arc<dyn Operation<R>>>,
90+
) -> (Vec<Arc<dyn Operation<R>>>, usize) {
8991
let execution = OrderedExecution::new(operations);
9092

9193
if matches!(&optimization.strategy, ExecutionStrategy::Composed(..)) {

0 commit comments

Comments
 (0)