
Commit c2f8779

Authored by AshAnand34 and antimora

Adding bitwise ONNX ops (#3120)
* Started integrating bitwise operators into burn
* Started developing the bitwise operators for burn
* Created Python files to generate test ONNX models for the bitwise operators
* Fixed bitshift.rs to account for direction
* Added test ONNX files to the build and added rank inference for the bitwise operations
* Fixed test ONNX files due to wrong node inits and restricted tensor types in bitwise operations
* Fixed test_onnx.rs to include the right ONNX file names for bitshift
* Created ONNX model tests for bitwise operators
* Fixed rank inference
* Fixed unit tests for bitwise operators
* More unit test fixes in bitwise operators
* Created ONNX files for scalar versions of the bitwise operators
* Integrated scalar versions of the ONNX models into the ONNX tests
* Added scalar argtypes to to_burn
* Added bitshift config plus minor fixes
* Fixed formatting errors
* Resolved the latest cubecl version
* Removed unused op_configuration module import
* Refactor bitwise node handling for scalar support: updated bitwise and bitshift node implementations to accept both tensor and scalar inputs via the Type enum, enabling proper handling of scalar operations. Adjusted code generation logic and import registration for scalar cases, and updated ONNX conversion functions to use Type instead of TensorType. Also fixed a test expectation in bitwise_xor and enabled bitwise tests in test_mod.rs.
* Refactor BitShift ONNX test generation and enable tests: rewrote bitshift.py to generate ONNX models directly using the onnx API, supporting both tensor and scalar shift inputs. Updated all BitShift ONNX test models to match the new generation method and re-enabled the bitshift test module in test_mod.rs.
* Fix formatting and logging in node modules
* Refactor bitwise operations to use Tensor methods: replaced direct bitwise operators with the corresponding Tensor methods (e.g. bitwise_and, bitwise_or, bitwise_not, bitwise_xor, bitwise_left_shift, bitwise_right_shift) in test modules for the bitwise node implementations. Also updated input type registration in BitwiseAndNode tests for consistency.
* Revert powf and powi changes
* Fix formatting
* Add missing newline at end of bitshift.rs
* Update ONNX ops support status in documentation: marked BitShift, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, and IsInf as supported in the ONNX ops table for both import and export, reflecting recent changes in operator support.
* Fix imports in bernoulli test and base node module: removed unused tensor imports in the bernoulli test module and added the missing SpaceToDepthNode import in the base node module to resolve compilation issues.
* Refactor BitShift direction to enum type: replaces the string-based direction for BitShift operations with a strongly typed enum in both the burn-import and onnx-ir crates. Updates conversion logic and tests to use the new Direction enum, improving type safety and clarity.
* Refactor bitwise node input handling logic: simplifies and unifies input type matching for BitShiftNode, BitwiseAndNode, BitwiseOrNode, and BitwiseXorNode. Now explicitly panics for unsupported input combinations and removes redundant variable assignments, improving code clarity and error handling.
* Support scalar-tensor inputs for bitwise and bitshift ops: added support for cases where the first input is a scalar and the second is a tensor for BitShift, BitwiseAnd, BitwiseOr, and BitwiseXor nodes. Updated ONNX test models, Python export scripts, Rust test modules, and codegen logic to handle these cases, leveraging commutativity for the bitwise ops and broadcasting for bitshift. Rank inference for these ops now uses broadcasting-aware logic.
* Update .gitignore
* Add scalar support for bitwise and bitshift nodes: extended BitShift, BitwiseAnd, BitwiseOr, and BitwiseXor nodes to support scalar inputs and outputs. Added new ONNX test models and Python scripts for scalar bitwise and bitshift operations, and updated Rust test modules to cover scalar cases. Refactored codegen and conversion logic to handle both tensor and scalar types for these nodes.
* Refactor ONNX test model generation scripts: unified and generalized the Python scripts for generating BitShift and BitwiseAnd ONNX test models, replacing multiple specialized scripts with parameterized functions. Removed redundant scalar-only scripts and updated all related ONNX model files to match the new generation logic.
* Refactor bitshift operation matching logic: simplifies the match statement in BitShiftNode by removing direction from pattern matching and handling it within each arm, improving readability and maintainability.
* Remove output type checks from BitShiftNode

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
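The note above about "leveraging commutativity for bitwise ops and broadcasting for bitshift" can be illustrated with a minimal Python sketch (illustrative only, not code from this commit; the function names are made up): a scalar-tensor bitwise input can reuse the existing tensor-scalar path by swapping operands, while BitShift cannot, because shifting is not commutative.

```python
# Illustrative sketch: bitwise AND is commutative, so a scalar-tensor
# call can be routed through the tensor-scalar path by swapping operands.
def bitwise_and(lhs, rhs):
    if isinstance(lhs, int) and isinstance(rhs, list):
        lhs, rhs = rhs, lhs                       # swap: a & b == b & a
    if isinstance(rhs, int):
        return [a & rhs for a in lhs]             # tensor-scalar path
    return [a & b for a, b in zip(lhs, rhs)]      # tensor-tensor path

print(bitwise_and(3, [1, 2, 3, 4]))  # [1, 2, 3, 0]

# BitShift is NOT commutative (4 << 2 != 2 << 4), so the scalar-tensor
# case needs broadcasting rather than an operand swap.
assert (4 << 2) != (2 << 4)
```

The swap only works for And/Or/Xor; BitShift keeps its operand order and broadcasts the scalar instead.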
1 parent caba944 commit c2f8779

50 files changed: +1742 −22 lines

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -15,3 +15,4 @@ out
 
 # Virtual Environment of Python
 .venv
+uv.lock

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 5 additions & 5 deletions

@@ -25,11 +25,11 @@ functionality.
 | [AveragePool2d][12]      | ✅ | ✅ |
 | [BatchNormalization][14] | ✅ | ✅ |
 | [Bernoulli][15]          | ✅ | ✅ |
-| [BitShift][16]           | ❌ | ❌ |
-| [BitwiseAnd][17]         | ❌ | ❌ |
-| [BitwiseNot][18]         | ❌ | ❌ |
-| [BitwiseOr][19]          | ❌ | ❌ |
-| [BitwiseXor][20]         | ❌ | ❌ |
+| [BitShift][16]           | ✅ | ✅ |
+| [BitwiseAnd][17]         | ✅ | ✅ |
+| [BitwiseNot][18]         | ✅ | ✅ |
+| [BitwiseOr][19]          | ✅ | ✅ |
+| [BitwiseXor][20]         | ✅ | ✅ |
 | [BlackmanWindow][21]     | ✅ | ✅ |
 | [Cast][22]               | ✅ | ✅ |
 | [CastLike][23]           | ✅ | ✅ |

crates/burn-import/onnx-tests/build.rs

Lines changed: 21 additions & 0 deletions

@@ -14,6 +14,27 @@ fn main() {
         .input("tests/avg_pool1d/avg_pool1d.onnx")
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
+        .input("tests/bitshift/bitshift_left.onnx")
+        .input("tests/bitshift/bitshift_left_scalar.onnx")
+        .input("tests/bitshift/scalar_bitshift_left.onnx")
+        .input("tests/bitshift/scalar_bitshift_left_scalar.onnx")
+        .input("tests/bitshift/bitshift_right.onnx")
+        .input("tests/bitshift/bitshift_right_scalar.onnx")
+        .input("tests/bitshift/scalar_bitshift_right.onnx")
+        .input("tests/bitshift/scalar_bitshift_right_scalar.onnx")
+        .input("tests/bitwise_and/bitwise_and.onnx")
+        .input("tests/bitwise_and/bitwise_and_scalar.onnx")
+        .input("tests/bitwise_and/scalar_bitwise_and.onnx")
+        .input("tests/bitwise_and/scalar_bitwise_and_scalar.onnx")
+        .input("tests/bitwise_not/bitwise_not.onnx")
+        .input("tests/bitwise_or/bitwise_or.onnx")
+        .input("tests/bitwise_or/bitwise_or_scalar.onnx")
+        .input("tests/bitwise_or/scalar_bitwise_or.onnx")
+        .input("tests/bitwise_or/scalar_bitwise_or_scalar.onnx")
+        .input("tests/bitwise_xor/bitwise_xor.onnx")
+        .input("tests/bitwise_xor/bitwise_xor_scalar.onnx")
+        .input("tests/bitwise_xor/scalar_bitwise_xor.onnx")
+        .input("tests/bitwise_xor/scalar_bitwise_xor_scalar.onnx")
         .input("tests/bernoulli/bernoulli.onnx")
         .input("tests/cast/cast.onnx")
         .input("tests/ceil/ceil.onnx")

crates/burn-import/onnx-tests/tests/bernoulli/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@ include_models!(bernoulli);
 mod tests {
     use super::*;
     use burn::tensor::Shape;
-    use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem};
+    use burn::tensor::Tensor;
 
     type Backend = burn_ndarray::NdArray<f32>;
crates/burn-import/onnx-tests/tests/bitshift/bitshift.py

Lines changed: 84 additions & 0 deletions (new file)

#!/usr/bin/env python3
# used to generate all bitshift ONNX models

import onnx

def build_model(name, input1_shape, input2_shape, output_shape, direction):
    """
    Build a BitShift ONNX model with specified input/output shapes and direction.

    Args:
        name: Name of the model (used for file naming)
        input1_shape: Shape of first input ([] for scalar)
        input2_shape: Shape of second input ([] for scalar)
        output_shape: Shape of output ([] for scalar)
        direction: "LEFT" or "RIGHT"
    """
    op_type = "BitShift"

    nodes = [
        onnx.helper.make_node(
            op_type,
            inputs=["input1", "input2"],
            outputs=["output"],
            name=f"/{op_type}",
            direction=direction
        ),
    ]

    inputs = [
        onnx.helper.make_value_info(
            name="input1",
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=onnx.TensorProto.INT32, shape=input1_shape
            ),
        ),
        onnx.helper.make_value_info(
            name="input2",
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=onnx.TensorProto.INT32, shape=input2_shape
            ),
        ),
    ]

    outputs = [
        onnx.helper.make_value_info(
            name="output",
            type_proto=onnx.helper.make_tensor_type_proto(
                elem_type=onnx.TensorProto.INT32, shape=output_shape
            ),
        )
    ]

    model = onnx.helper.make_model(
        ir_version=8,
        opset_imports=[onnx.helper.make_operatorsetid("", 18)],
        graph=onnx.helper.make_graph(
            name="main_graph",
            nodes=nodes,
            inputs=inputs,
            outputs=outputs,
            initializer=[]
        ),
    )

    onnx.checker.check_model(model)
    onnx.save(model, f"{name}.onnx")
    print(f"Finished exporting model to {name}.onnx")

if __name__ == "__main__":
    # Define all model configurations
    configs = [
        # (name, input1_shape, input2_shape, output_shape, direction)
        ("bitshift_left", [4], [4], [4], "LEFT"),
        ("bitshift_right", [4], [4], [4], "RIGHT"),
        ("bitshift_left_scalar", [4], [], [4], "LEFT"),
        ("bitshift_right_scalar", [4], [], [4], "RIGHT"),
        ("scalar_bitshift_left", [], [4], [4], "LEFT"),
        ("scalar_bitshift_right", [], [4], [4], "RIGHT"),
        ("scalar_bitshift_left_scalar", [], [], [], "LEFT"),
        ("scalar_bitshift_right_scalar", [], [], [], "RIGHT"),
    ]

    for config in configs:
        build_model(*config)
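As a quick reference for what the generated models compute, here is a plain-Python sketch (illustrative only, not part of the generator) of the ONNX BitShift semantics used by the configurations above: an element-wise shift whose side is chosen by the `direction` attribute.

```python
# Reference semantics of ONNX BitShift: element-wise shift,
# with the side selected by the `direction` attribute.
def bitshift_ref(x, amounts, direction):
    if direction == "LEFT":
        return [a << b for a, b in zip(x, amounts)]
    return [a >> b for a, b in zip(x, amounts)]

print(bitshift_ref([1, 2, 3, 4], [1, 1, 2, 2], "LEFT"))   # [2, 4, 12, 16]
print(bitshift_ref([1, 2, 3, 4], [1, 1, 2, 2], "RIGHT"))  # [0, 1, 0, 1]
```

These values match the expected outputs asserted in the Rust test module for this node.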
4 binary files (.onnx models) not shown.
crates/burn-import/onnx-tests/tests/bitshift/mod.rs

Lines changed: 132 additions & 0 deletions (new file)

// Include the models for this node type
use crate::include_models;
include_models!(
    bitshift_left,
    bitshift_left_scalar,
    scalar_bitshift_left,
    scalar_bitshift_left_scalar,
    bitshift_right,
    bitshift_right_scalar,
    scalar_bitshift_right,
    scalar_bitshift_right_scalar
);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Int, Tensor, TensorData};

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn bitshift_left_tensors() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_left::Model<Backend> = bitshift_left::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(input1, input2);
        let expected = TensorData::from([2i64, 4, 12, 16]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_left_scalar_tensor() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_left_scalar::Model<Backend> = bitshift_left_scalar::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let scalar = 2;
        let output = model.forward(input1, scalar);
        let expected = TensorData::from([4i64, 8, 12, 16]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_right_tensors() {
        let device = Default::default();
        let model: bitshift_right::Model<Backend> = bitshift_right::Model::new(&device);

        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let input2 = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(input1, input2);
        let expected = TensorData::from([0i64, 1, 0, 1]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn bitshift_right_scalar_tensor() {
        // Initialize the model with weights (loaded from the exported file)
        let device = Default::default();
        let model: bitshift_right_scalar::Model<Backend> =
            bitshift_right_scalar::Model::new(&device);
        // Run the model
        let input1 = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let scalar = 2;
        let output = model.forward(input1, scalar);
        let expected = TensorData::from([0i64, 0, 0, 1]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn scalar_bitshift_left_tensor() {
        let device = Default::default();
        let model: scalar_bitshift_left::Model<Backend> = scalar_bitshift_left::Model::new(&device);
        // Run the model
        let scalar = 4;
        let shift_amounts = Tensor::<Backend, 1, Int>::from_ints([1, 1, 2, 2], &device);
        let output = model.forward(scalar, shift_amounts);
        // 4 << 1 = 8, 4 << 1 = 8, 4 << 2 = 16, 4 << 2 = 16
        let expected = TensorData::from([8i64, 8, 16, 16]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn scalar_bitshift_right_tensor() {
        let device = Default::default();
        let model: scalar_bitshift_right::Model<Backend> =
            scalar_bitshift_right::Model::new(&device);
        // Run the model
        let scalar = 8;
        let shift_amounts = Tensor::<Backend, 1, Int>::from_ints([1, 2, 3, 4], &device);
        let output = model.forward(scalar, shift_amounts);
        // 8 >> 1 = 4, 8 >> 2 = 2, 8 >> 3 = 1, 8 >> 4 = 0
        let expected = TensorData::from([4i64, 2, 1, 0]);

        output.to_data().assert_eq(&expected, true);
    }

    #[test]
    fn scalar_bitshift_left_scalar() {
        let device = Default::default();
        let model: scalar_bitshift_left_scalar::Model<Backend> =
            scalar_bitshift_left_scalar::Model::new(&device);
        // Run the model
        let lhs = 4;
        let rhs = 2;
        let output = model.forward(lhs, rhs);
        // 4 << 2 = 16
        assert_eq!(output, 16);
    }

    #[test]
    fn scalar_bitshift_right_scalar() {
        let device = Default::default();
        let model: scalar_bitshift_right_scalar::Model<Backend> =
            scalar_bitshift_right_scalar::Model::new(&device);
        // Run the model
        let lhs = 16;
        let rhs = 2;
        let output = model.forward(lhs, rhs);
        // 16 >> 2 = 4
        assert_eq!(output, 4);
    }
}
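The scalar-tensor expectations in the tests above follow ordinary integer broadcasting; a small numpy sketch (assuming numpy is installed; not part of this commit) reproduces the same values:

```python
import numpy as np

x = np.array([1, 2, 3, 4], dtype=np.int32)

# bitshift_left_scalar: tensor << scalar (the scalar broadcasts)
print((x << np.int32(2)).tolist())  # [4, 8, 12, 16]

# scalar_bitshift_right: scalar >> tensor (broadcast the other way)
shifts = np.array([1, 2, 3, 4], dtype=np.int32)
print((np.int32(8) >> shifts).tolist())  # [4, 2, 1, 0]
```

Both results match the `expected` TensorData values asserted in `bitshift_left_scalar_tensor` and `scalar_bitshift_right_tensor`.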
