Skip to content

Add int tensor cast - WIP #3289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft

Add int tensor cast - WIP #3289

wants to merge 1 commit into from

Conversation

A2va
Copy link

@A2va A2va commented Jun 12, 2025

Pull Request Template

Checklist

  • Confirmed that cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

As my project is blocked by both #3262 and a TensorFlow-generated ONNX issue, I decided to try implementing this feature.

My implementation is based on the existing float_cast function, so things might be different for integer ones. It is pretty much a work in progress, so there might be errors, and I look forward to receiving any feedback.

Regarding the current state of the implementation, I have added int_cast to all backends, but I am struggling to implement it for the ndarray backend.
I have added the implementation of IntNdArrayElement (crates/burn-ndarray/src/element.rs), but since all uxx are unsigned, it errors with:

the trait bound `u64: Signed` is not satisfied
the following other types implement trait `Signed`:
  f32
  f64
  i128
  i16
  i32
  i64
  i8
  isize

And the Sized trait cannot be removed since it produce other errors for IntTensorOps.

I also get this error for all usage of execute_with_int_dtype and I don't understand it because the conversion between NdArrayTensor<I> to NdArrayTensorInt is implemented.

error[E0308]: mismatched types
   --> crates\burn-ndarray\src\tensor.rs:312:14
    |
311 |           match ($lhs, $rhs) {
    |                 ------------ this expression has type `(tensor::NdArrayTensor<I>, tensor::NdArrayTensor<I>)`
312 |               ($crate::NdArrayTensorInt::I64(lhs), $crate::NdArrayTensorInt::I64(rhs)) => {
    |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `NdArrayTensor<I>`, found `NdArrayTensorInt`
    |
   ::: crates\burn-ndarray\src\ops\int_tensor.rs:74:9
    |
74  | /         execute_with_int_dtype!((tensor, source), |tensor, source| {
75  | |             NdArrayMathOps::mask_where(tensor, mask, source)
76  | |         })
    | |__________- in this macro invocation
    |
    = note: expected struct `tensor::NdArrayTensor<I>`
                 found enum `tensor::NdArrayTensorInt`
    = note: this error originates in the macro `execute_with_int_dtype` (in Nightly builds, run with -Z macro-backtrace for more info)

I also noticed some inconsistencies in the file naming conventions between the backends. For example, in the ndarray backend it is named int_tensor.rs, whereas in cubecl it's int_ops.rs.
What about unifying the naming?

Testing

Do I need to run tests for each backend, and where?

@A2va A2va marked this pull request as draft June 12, 2025 19:17
_dtype: burn_tensor::IntDType,
) -> RouterTensor<Self> {
self.register_empty_tensor(shape, DType::I32)
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this right that dtype is not used here ?

(Dtype::U8, IntDType::I16) => kernel::cast::<R, u8, i16>(tensor),
(Dtype::U8, IntDType::I8) => kernel::cast::<R, u8, i8>(tensor),
}
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A macro might be useful in this case, but I don't know anything about them.

@laggui
Copy link
Member

laggui commented Jun 19, 2025

Sorry for the late response! Been a little busy recently 😅 this was on the roadmap so I will absolutely have a look when I get a bit of time

Copy link
Contributor

This PR has been marked as stale because it has not been updated for over a month

@github-actions github-actions bot added the stale The issue or pr has been open for too long label Jul 21, 2025
@A2va
Copy link
Author

A2va commented Jul 25, 2025

@laggui Any update ?

@github-actions github-actions bot removed the stale The issue or pr has been open for too long label Jul 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants