
ONNX import tests fail or do not compile when switching backend from Ndarray to Autodiff<Wgpu> #3370

@lucianyao

Description


Describe the bug
When switching the backend from burn_ndarray::NdArray to burn_autodiff::Autodiff<burn_wgpu::Wgpu> and then compiling and running the tests under crates/burn-import/onnx-tests/, some ONNX import tests fail to compile, and most of the remaining tests fail at runtime.

To Reproduce

  1. In crates/burn-import/onnx-tests/Cargo.toml, add:
[features]
backend-autodiff-wgpu = ["burn-autodiff", "burn-wgpu"]

[dependencies]
burn-autodiff = { path = "../../burn-autodiff", version = "0.18.0", default-features = false, optional = true }
burn-wgpu = { path = "../../burn-wgpu", version = "0.18.0", optional = true, default-features = false }
  2. In crates/burn-import/onnx-tests/tests/test_mod.rs, add:
#[cfg(feature = "backend-autodiff-wgpu")]
type Backend = burn_autodiff::Autodiff<burn_wgpu::Wgpu>;

#[cfg(not(feature = "backend-autodiff-wgpu"))]
type Backend = burn_ndarray::NdArray<f32>;
  3. Under crates/burn-import/onnx-tests/tests/, replace:
type Backend = burn_ndarray::NdArray<f32>;

with:

use super::super::Backend;

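in every mod.rs file, for example by running the following command from crates/burn-import/onnx-tests/tests/ (sed -i '' is the BSD/macOS form; GNU sed takes -i without the empty argument):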

find ./ -name mod.rs -exec sed -i '' 's|^[[:space:]]*type Backend = burn_ndarray::NdArray<f32>;|use super::super::Backend;|' {} +
  4. Under crates/burn-import/onnx-tests/, run:
cargo test --test test_mod --features backend-autodiff-wgpu
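Steps 1 to 3 rely on two Rust mechanisms: #[cfg(feature = "backend-autodiff-wgpu")] keeps exactly one Backend alias at the test crate root (test_mod.rs), and use super::super::Backend; inside each nested tests module walks up two levels to that root. The test names in the output below (e.g. argmax::tests::argmax) show that each operator module contains an inner tests module. A minimal sketch of how the alias resolves, using hypothetical stand-in types (NdArrayStandIn, AutodiffWgpuStandIn) in place of the real Burn backends:

```rust
// Hypothetical stand-ins for the real backend types; only the alias mechanics matter here.
#[allow(dead_code)]
struct NdArrayStandIn;
#[allow(dead_code)]
struct AutodiffWgpuStandIn;

// Crate root (test_mod.rs): exactly one alias survives cfg evaluation.
#[cfg(feature = "backend-autodiff-wgpu")]
type Backend = AutodiffWgpuStandIn;

#[cfg(not(feature = "backend-autodiff-wgpu"))]
type Backend = NdArrayStandIn;

// Mirrors tests/argmax/mod.rs, which contains an inner `tests` module.
mod argmax {
    pub mod tests {
        // super = argmax, super::super = crate root, where `Backend` is defined.
        use super::super::Backend;

        // Reports which concrete type the alias resolved to.
        pub fn backend_name() -> &'static str {
            std::any::type_name::<Backend>()
        }
    }
}

fn main() {
    println!("{}", argmax::tests::backend_name());
}
```

Built without the feature, this prints a path ending in NdArrayStandIn. Writing use crate::Backend; instead would resolve to the same alias without depending on how deeply the tests module is nested.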

Running the command in step 4 fails to compile with:

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:41:9
   |
41 |         assert!(f_output.equal(f_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:42:9
   |
42 |         assert!(i_output.equal(i_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`

error[E0308]: mismatched types
  --> crates/burn-import/onnx-tests/tests/constant_of_shape/mod.rs:43:9
   |
43 |         assert!(b_output.equal(b_expected).all().into_scalar());
   |         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `bool`, found `u32`
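All three errors share the cause visible in the compiler output: into_scalar() on the reduced boolean tensor yields bool under NdArray but u32 under Autodiff<Wgpu>, and assert! accepts only bool. One backend-agnostic way to write such an assertion is to compare the scalar against its type's default value. The sketch below is plain Rust with a hypothetical helper, scalar_is_true; it does not use Burn's API, and it assumes the u32 representation encodes true as a nonzero value.

```rust
/// Hypothetical helper: interprets a backend's boolean scalar as a Rust bool.
/// Works for any element type whose default value means "false"
/// (false for bool, 0 for u32), so the same assertion compiles whether
/// the backend's boolean element is bool or u32.
fn scalar_is_true<T: PartialEq + Default>(value: T) -> bool {
    value != T::default()
}

fn main() {
    // NdArray-style result: into_scalar() already yields a bool.
    let ndarray_like: bool = true;
    // Wgpu-style result: the same reduction yields a u32 (assumed: nonzero means true).
    let wgpu_like: u32 = 1;

    assert!(scalar_is_true(ndarray_like));
    assert!(scalar_is_true(wgpu_like));
    assert!(!scalar_is_true(0u32));
    println!("both representations accepted");
}
```

In the failing test this would read assert!(scalar_is_true(f_output.equal(f_expected).all().into_scalar()));, which compiles for both backends because bool and u32 each implement PartialEq and Default.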

After commenting out pub mod constant_of_shape; in test_mod.rs and running step 4 again, the remaining tests compile, but most of them fail:

running 134 tests
test argmax::tests::argmax ... FAILED
test and::tests::and ... FAILED
test ceil::tests::ceil_test ... ok
test argmin::tests::argmin ... FAILED
test add::tests::add_scalar_to_int_tensor_and_int_tensor_to_int_tensor ... FAILED
......
test result: FAILED. 50 passed; 84 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.36s

Expected behavior
All tests compile and pass with the Autodiff<Wgpu> backend.

Desktop

  • Device: Apple M1 Pro
  • OS: macOS
  • Version: 14.2.1 (23C71)

Additional context
The git repo version is:

commit b42a8b6556bf0f87077cdd0b2b5c0756f90ec8cb (HEAD -> main, origin/main, origin/HEAD)
Author: Jimmy Johnson <catch22@fastmail.net>
Date:   Fri Jul 11 15:13:38 2025 +0200

    Updating documentation description for nonzero and nonzero_async (#3368)

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), onnx
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests