[Feature] reduce fuse on read #2870
Some minor comments, otherwise LGTM!

/edit: oh, looks like `group_norm_forward_affine_false` is failing with precision errors with wgpu (`cargo test --color always --features test-wgpu -p burn-core`)
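Precision failures like this are typically resolved by comparing against a tolerance rather than exact equality, since GPU backends such as wgpu can accumulate slightly different rounding error than CPU backends. A generic sketch of such a check (plain Rust, not burn's actual test helpers):

```rust
/// Returns true when `a` and `b` agree within an absolute OR a relative
/// tolerance. Generic sketch of a precision-tolerant comparison; burn's
/// own assertion helpers differ in API.
fn approx_eq(a: f32, b: f32, abs_tol: f32, rel_tol: f32) -> bool {
    let diff = (a - b).abs();
    diff <= abs_tol || diff <= rel_tol * a.abs().max(b.abs())
}

fn main() {
    // A reference value vs. the same value computed by a GPU kernel,
    // off only in the last few digits (hypothetical numbers).
    let reference = 0.123_456_f32;
    let gpu_value = 0.123_449_f32;
    assert!(approx_eq(reference, gpu_value, 1e-4, 1e-4));

    // A genuinely different value should still fail the check.
    assert!(!approx_eq(reference, 0.2, 1e-4, 1e-4));
    println!("precision check passed");
}
```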
```rust
let tensor = self.tensor.clone() + 5;
let tensor = tensor.log();
let tensor = tensor.tanh();
let tensor = tensor * 3;
tensor.sum_dim(axis);
```
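For intuition, the chain above computes `tanh(ln(x + 5)) * 3` element-wise and then sums along one axis. A scalar sketch in plain Rust (the `fused_chain_sum` helper is hypothetical, not burn's tensor API):

```rust
// Hypothetical scalar version of the op chain: tanh(ln(x + 5)) * 3 per
// element, followed by a sum reduction (the sum_dim step along one axis).
fn fused_chain_sum(row: &[f32]) -> f32 {
    row.iter().map(|&x| ((x + 5.0).ln().tanh()) * 3.0).sum()
}

fn main() {
    // One row of a hypothetical tensor.
    let row = [1.0_f32, 2.0, 3.0];
    let sum = fused_chain_sum(&row);
    // ln(x + 5) lands around 1.79..2.08 here and tanh of that is close
    // to 1, so each element contributes a bit under 3.0.
    println!("fused-chain sum = {sum}");
    assert!(sum > 8.0 && sum < 9.0);
}
```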
Debug changes? Curious what issue you were trying to track/fix by introducing the additional ops before the reduce 😄
`backend-comparison/benches/reduce.rs` (outdated)

```rust
// benchmarks.push(ReduceBenchmark::<B>::new(
//     Instruction::ArgMin(axis),
//     device.clone(),
// ));
```
Uncomment?
`backend-comparison/benches/reduce.rs` (outdated)

```rust
        benchmarks.push(ReduceBenchmark::<B>::new(
            Instruction::SumDim(axis),
            device.clone(),
        ));
    }

    benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));
    // benchmarks.push(ReduceBenchmark::<B>::new(Instruction::Sum, device.clone()));
```
Uncomment?
```rust
#[cube]
pub fn global_len(global: &GlobalArgs, #[comptime] pos: u32) -> u32 {
    let tensor = global.tensors.index(pos);
    tensor.tensor.len()
}
```
Isn't this a duplicate of `global_length`, defined just a couple of lines above? Minus the cast (which seems redundant, actually).
```rust
#[cube]
pub fn global_length(global: &GlobalArgs, #[comptime] pos: u32) -> u32 {
    let tensor = global.tensors.index(pos);
    u32::cast_from(tensor.tensor.len())
}
```
```diff
@@ -84,6 +85,18 @@ mod tests {
     output.into_data().assert_eq(&expected, false);
 }

+#[test]
+fn test_sum_dim_reshape_maybe_fused() {
```
Why do we have `sum_dim` and `mean_dim` tests in the maxmin module? 😅 We should probably move `test_sum_dim_2d()`, `test_mean_dim_2d()`, and this new test to the correct module.
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff             @@
##             main    #2870      +/-   ##
==========================================
- Coverage   82.31%   82.29%    -0.03%
==========================================
  Files         863      867        +4
  Lines      116956   118080     +1124
==========================================
+ Hits        96268    97169      +901
- Misses      20688    20911      +223
==========================================
```

☔ View full report in Codecov by Sentry.