Describe the bug
After switching the test backend from burn_ndarray::NdArray<f32> to burn_autodiff::Autodiff<burn_wgpu::Wgpu>, the tests under crates/burn-import/onnx-tests/ no longer pass: the constant_of_shape tests fail to compile, and most of the remaining ONNX import tests fail at runtime.
To Reproduce
1. In crates/burn-import/onnx-tests/Cargo.toml, add the following entries to the [features] and [dependencies] sections:
[features]
backend-autodiff-wgpu = ["burn-autodiff", "burn-wgpu"]
[dependencies]
burn-autodiff = { path = "../../burn-autodiff", version = "0.18.0", default-features = false, optional = true }
burn-wgpu = { path = "../../burn-wgpu", version = "0.18.0", optional = true, default-features = false }
2. In crates/burn-import/onnx-tests/tests/test_mod.rs, add:
#[cfg(feature = "backend-autodiff-wgpu")]
type Backend = burn_autodiff::Autodiff<burn_wgpu::Wgpu>;
#[cfg(not(feature = "backend-autodiff-wgpu"))]
type Backend = burn_ndarray::NdArray<f32>;
3. In every mod.rs file under crates/burn-import/onnx-tests/tests/, replace:
type Backend = burn_ndarray::NdArray<f32>;
with:
use super::super::Backend;
I did this by running the following from crates/burn-import/onnx-tests/tests/ (sed -i '' is the BSD/macOS form; with GNU sed, drop the ''):
find ./ -name mod.rs -exec sed -i '' 's|^[[:space:]]*type Backend = burn_ndarray::NdArray<f32>;|use super::super::Backend;|' {} +
4. From crates/burn-import/onnx-tests/, run:
cargo test --test test_mod --features backend-autodiff-wgpu
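For context on steps 2 and 3: the cfg-gated alias lives at the root of the test_mod integration-test crate, and each test module (for example argmax::tests) reaches it through super::super. The sketch below reproduces that layout in a single file so it compiles without burn; NdArrayStandIn and AutodiffWgpuStandIn are hypothetical unit structs standing in for the two real backend types. Building with rustc --cfg 'feature="backend-autodiff-wgpu"' selects the other branch.

```rust
use std::any::type_name;

// Hypothetical stand-ins for burn_ndarray::NdArray<f32> and
// burn_autodiff::Autodiff<burn_wgpu::Wgpu>; not the real types.
#[allow(dead_code)]
pub struct NdArrayStandIn;
#[allow(dead_code)]
pub struct AutodiffWgpuStandIn;

// Crate-root alias chosen at compile time, as in test_mod.rs (step 2).
#[cfg(feature = "backend-autodiff-wgpu")]
pub type Backend = AutodiffWgpuStandIn;
#[cfg(not(feature = "backend-autodiff-wgpu"))]
pub type Backend = NdArrayStandIn;

// Mirrors a test directory such as constant_of_shape/mod.rs.
pub mod constant_of_shape {
    pub mod tests {
        // `super` is constant_of_shape; `super::super` is the crate root,
        // which is where step 3's import points.
        use super::super::Backend;

        // Reports which concrete type the alias resolved to.
        pub fn backend_name() -> &'static str {
            super::super::type_name::<Backend>()
        }
    }
}

fn main() {
    println!("{}", constant_of_shape::tests::backend_name());
}
```

Because the alias is resolved per build, every test module picks up the same backend without editing each file again when switching features.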
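Where the BSD sed invocation from step 3 is inconvenient, the same substitution can be sketched as a small Rust program. This is an illustration, not part of the repository; rewrite_backend_alias and rewrite_tree are hypothetical helpers. Like the sed pattern, the match starts at the beginning of the line and consumes leading whitespace, so a rewritten line loses its indentation (cargo fmt can restore it). The program only touches the filesystem when a directory is passed on the command line.

```rust
use std::fs;
use std::io;
use std::path::Path;

const OLD: &str = "type Backend = burn_ndarray::NdArray<f32>;";
const NEW: &str = "use super::super::Backend;";

/// Line-by-line equivalent of
/// s|^[[:space:]]*type Backend = burn_ndarray::NdArray<f32>;|use super::super::Backend;|
/// Splitting on '\n' (not `lines()`) keeps trailing newlines and any '\r' intact.
fn rewrite_backend_alias(src: &str) -> String {
    src.split('\n')
        .map(|line| match line.trim_start().strip_prefix(OLD) {
            // Leading whitespace is dropped, as with the sed pattern;
            // anything after the semicolon is kept.
            Some(rest) => format!("{}{}", NEW, rest),
            None => line.to_string(),
        })
        .collect::<Vec<_>>()
        .join("\n")
}

/// Recursively rewrites every file named mod.rs under `dir`,
/// returning how many files actually changed.
fn rewrite_tree(dir: &Path) -> io::Result<usize> {
    let mut changed = 0;
    for entry in fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_dir() {
            changed += rewrite_tree(&path)?;
        } else if path.file_name().map_or(false, |n| n == "mod.rs") {
            let src = fs::read_to_string(&path)?;
            let new_src = rewrite_backend_alias(&src);
            if new_src != src {
                fs::write(&path, new_src)?;
                changed += 1;
            }
        }
    }
    Ok(changed)
}

fn main() -> io::Result<()> {
    // e.g. `cargo run -- crates/burn-import/onnx-tests/tests`
    if let Some(dir) = std::env::args().nth(1) {
        let n = rewrite_tree(Path::new(&dir))?;
        println!("rewrote {n} mod.rs file(s)");
    }
    Ok(())
}
```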
Step 4 fails to compile with:
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
|
41 | assert!(f_output.equal(f_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
|
42 | assert!(i_output.equal(i_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
error[E0308]: mismatched types
--> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
|
43 | assert!(b_output.equal(b_expected).all().into_scalar());
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
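All three errors have the same shape: with this backend, into_scalar() on the boolean tensor produced by .all() yields a u32, not a bool, so the value cannot be passed directly to assert!. The sketch below models that situation with hypothetical names (BoolLike, as_bool and all_equal are not burn's API) and shows one backend-agnostic way to normalize the scalar before asserting: treat any non-zero integer as true. An actual fix in the tests would use whatever element conversion burn itself provides.

```rust
/// Hypothetical trait modelling a backend's boolean element type:
/// some backends store bools as `bool`, others as an integer such as `u32`.
trait BoolLike: Copy {
    fn as_bool(self) -> bool;
}

impl BoolLike for bool {
    fn as_bool(self) -> bool {
        self
    }
}

impl BoolLike for u32 {
    // Integer-backed booleans: zero is false, anything else is true.
    fn as_bool(self) -> bool {
        self != 0
    }
}

/// Hypothetical stand-in for `output.equal(expected).all().into_scalar()`:
/// reduces element-wise equality to one backend-typed boolean scalar.
fn all_equal<E: BoolLike>(lhs: &[f32], rhs: &[f32], make: impl Fn(bool) -> E) -> E {
    let all = lhs.len() == rhs.len() && lhs.iter().zip(rhs).all(|(a, b)| a == b);
    make(all)
}

fn main() {
    let output = [1.0f32, 2.0, 3.0];
    let expected = [1.0f32, 2.0, 3.0];

    // NdArray-like backend: the scalar is already a bool.
    let nd: bool = all_equal(&output, &expected, |b: bool| b);
    // Wgpu-like backend: the scalar is a u32, so `assert!(wg)` would be E0308.
    let wg: u32 = all_equal(&output, &expected, |b: bool| b as u32);

    // Normalizing through the trait compiles for both element types.
    assert!(nd.as_bool());
    assert!(wg.as_bool());
}
```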
After commenting out pub mod constant_of_shape; in test_mod.rs and running step 4 again, the remaining tests compile, but most of them fail:
running 134 tests
test argmax::tests::argmax ... FAILED
test and::tests::and ... FAILED
test ceil::tests::ceil_test ... ok
test argmin::tests::argmin ... FAILED
test add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor ... FAILED
......
test result: FAILED. 50 passed; 84 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.36s
Expected behavior
All tests compile and pass with the Autodiff<Wgpu> backend.
Desktop
- Device: Apple M1 Pro
- OS: macOS
- Version: 14.2.1 (23C71)
Additional context
Reproduced on the following commit:
commit b42a8b6556bf0f87077cdd0b2b5c0756f90ec8cb (HEAD -> main, origin/main, origin/HEAD)
Author: Jimmy Johnson <catch22@fastmail.net>
Date: Fri Jul 11 15:13:38 2025 +0200
Updating documentation description for nonzero and nonzero_async (#3368)