-
Notifications
You must be signed in to change notification settings - Fork 645
Add int tensor cast - WIP #3289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
_dtype: burn_tensor::IntDType, | ||
) -> RouterTensor<Self> { | ||
self.register_empty_tensor(shape, DType::I32) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this right that dtype is not used here ?
(Dtype::U8, IntDType::I16) => kernel::cast::<R, u8, i16>(tensor), | ||
(Dtype::U8, IntDType::I8) => kernel::cast::<R, u8, i8>(tensor), | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A macro might be useful in this case, but I don't know anything about them.
Sorry for the late response! Been a little busy recently 😅 this was on the roadmap so I will absolutely have a look when I get a bit of time |
This PR has been marked as stale because it has not been updated for over a month |
@laggui Any update ? |
Pull Request Template
Checklist
cargo run-checks
command has been executed.Changes
As my project is blocked by both #3262 and a TensorFlow-generated ONNX issue, I decided to try implementing this feature.
My implementation is based on the existing float_cast function, so things might be different for integer ones. It is pretty much a work in progress, so there might be errors, and I look forward to receiving any feedback.
Regarding the current state of the implementation, I have added
int_cast
to all backends, but I am struggling to implement it for the ndarray backend.I have added the implementation of IntNdArrayElement (crates/burn-ndarray/src/element.rs), but since all uxx are unsigned, it errors with:
And the Sized trait cannot be removed since it produce other errors for
IntTensorOps
.I also get this error for all usage of
execute_with_int_dtype
and I don't understand it because the conversion betweenNdArrayTensor<I>
toNdArrayTensorInt
is implemented.I also noticed some inconsistencies in the file naming conventions between the backends. For example, in the ndarray backend it is named
int_tensor.rs
, whereas in cubecl it'sint_ops.rs
.What about unifying the naming?
Testing
Do I need to run tests for each backend, and where?