Skip to content

Fuse gather #2793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 54 additions & 34 deletions crates/burn-jit/src/fusion/on_write/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,22 @@ impl TryFuseBuilder {
}
}

fn register(&mut self, add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder) -> bool) -> bool {
fn register(
&mut self,
add_ops: impl FnOnce(&mut FuseOnWriteTraceBuilder) -> Option<()>,
) -> bool {
// Always allow the first operation to be added.
if !self.added_ops {
self.added_ops = true;

if !add_ops(&mut self.builder) {
if add_ops(&mut self.builder).is_none() {
return false;
}
return true;
}

let mut cloned = self.builder.clone();
if !add_ops(&mut cloned) {
if add_ops(&mut cloned).is_none() {
return false;
}

Expand Down Expand Up @@ -217,15 +220,12 @@ impl FuseOnWriteBuilder {
}

if self.builder.register(|build| {
let input = match build.input_reshaped(&desc.input, &desc.out) {
Some(val) => val,
None => return false,
};
let out = build.output(&desc.out);
let input = build.input_reshaped(&desc.input, &desc.out)?;
let out = build.output(&desc.out)?;

build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }));

true
Some(())
}) {
self.num_reshapes += 1;
true
Expand Down Expand Up @@ -353,10 +353,10 @@ impl FuseOnWriteBuilder {
}

self.builder.register(|build| {
let cond = build.input(&desc.mask);
let lhs = build.input(&desc.value);
let rhs = build.input(&desc.tensor);
let out = build.output(&desc.out);
let cond = build.input(&desc.mask)?;
let lhs = build.input(&desc.value)?;
let rhs = build.input(&desc.tensor)?;
let out = build.output(&desc.out)?;

build.register_operation(ElemwiseOp::ConditionalAssign {
cond,
Expand All @@ -365,7 +365,7 @@ impl FuseOnWriteBuilder {
out,
});

true
Some(())
})
}
NumericOperationDescription::MaskFill(desc) => {
Expand All @@ -374,10 +374,10 @@ impl FuseOnWriteBuilder {
}

self.builder.register(|build| {
let cond = build.input(&desc.mask);
let cond = build.input(&desc.mask)?;
let lhs = build.scalar(&desc.value, desc.out.dtype);
let rhs = build.input(&desc.tensor);
let out = build.output(&desc.out);
let rhs = build.input(&desc.tensor)?;
let out = build.output(&desc.out)?;

build.register_operation(ElemwiseOp::ConditionalAssign {
cond,
Expand All @@ -386,7 +386,7 @@ impl FuseOnWriteBuilder {
out,
});

true
Some(())
})
}
NumericOperationDescription::Ones(desc) => {
Expand All @@ -399,11 +399,11 @@ impl FuseOnWriteBuilder {
let input = Arg::Literal(1, precision);

self.builder.register(|build| {
let out = build.output(desc);
let out = build.output(desc)?;

build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }));

true
Some(())
})
}
NumericOperationDescription::Zeros(desc) => {
Expand All @@ -416,11 +416,11 @@ impl FuseOnWriteBuilder {
let input = Arg::Literal(0, precision);

self.builder.register(|build| {
let out = build.output(desc);
let out = build.output(desc)?;

build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }));

true
Some(())
})
}
NumericOperationDescription::Full((desc, elem)) => {
Expand All @@ -430,11 +430,31 @@ impl FuseOnWriteBuilder {

self.builder.register(|build| {
let input = build.scalar(elem, desc.dtype);
let out = build.output(desc);
let out = build.output(desc)?;

build.register_operation(ElemwiseOp::Assign(UnaryElemwiseArgs { input, out }));

true
Some(())
})
}
NumericOperationDescription::Gather(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}

self.builder.register(|build| {
let input = build.input_indexed(&desc.tensor)?;
let indices = build.input(&desc.indices)?;
let output = build.output(&desc.out)?;

build.register_operation(ElemwiseOp::Gather {
input,
indices,
output,
dim: desc.dim as u32,
});

Some(())
})
}
_ => false,
Expand All @@ -450,13 +470,13 @@ impl FuseOnWriteBuilder {
}

self.builder.register(|build| {
let lhs = build.input(&desc.lhs);
let rhs = build.input(&desc.rhs);
let out = build.output(&desc.out);
let lhs = build.input(&desc.lhs)?;
let rhs = build.input(&desc.rhs)?;
let out = build.output(&desc.out)?;

build.register_operation(func(lhs, rhs, out));

true
Some(())
})
}

Expand All @@ -469,10 +489,10 @@ impl FuseOnWriteBuilder {
}

self.builder.register(|build| {
let input = build.input(&desc.input);
let out = build.output(&desc.out);
let input = build.input(&desc.input)?;
let out = build.output(&desc.out)?;
build.register_operation(func(input, out));
true
Some(())
})
}

Expand All @@ -490,13 +510,13 @@ impl FuseOnWriteBuilder {

self.builder.register(|build| {
let elem = desc.lhs.dtype;
let lhs = build.input(&desc.lhs);
let lhs = build.input(&desc.lhs)?;
let rhs = build.scalar(&desc.rhs, elem);
let out = build.output(&desc.out);
let out = build.output(&desc.out)?;

build.register_operation(func(lhs, rhs, out));

true
Some(())
})
}

Expand Down
Loading