Skip to content

Commit 736899d

Browse files
tye-singwaHelios113
authored andcommitted
Add support onnx size (tracel-ai#3301)
* feat: support onnx size * test: remove unnecessary args for test
1 parent 803bb95 commit 736899d

File tree

10 files changed

+167
-2
lines changed

10 files changed

+167
-2
lines changed

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ functionality.
177177
| [Sign][163] |||
178178
| [Sin][164] |||
179179
| [Sinh][165] |||
180-
| [Size][166] | ||
180+
| [Size][166] | ||
181181
| [Slice][167] |||
182182
| [Softmax][168] |||
183183
| [SoftmaxCrossEntropyLoss][169] |||

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ fn main() {
120120
.input("tests/sign/sign.onnx")
121121
.input("tests/sin/sin.onnx")
122122
.input("tests/sinh/sinh.onnx")
123+
.input("tests/size/size.onnx")
123124
.input("tests/slice/slice.onnx")
124125
.input("tests/slice/slice_shape.onnx")
125126
.input("tests/softmax/softmax.onnx")
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Import the shared macro
2+
use crate::include_models;
3+
include_models!(size);
4+
5+
#[cfg(test)]
6+
mod tests {
7+
use super::*;
8+
use burn::tensor::{Tensor, TensorData};
9+
10+
type Backend = burn_ndarray::NdArray<f32>;
11+
12+
#[test]
13+
fn size() {
14+
let model: size::Model<Backend> = size::Model::default();
15+
let device = Default::default();
16+
17+
let input =
18+
Tensor::<Backend, 1>::arange(0..(1 * 2 * 3 * 4 * 5), &device).reshape([1, 2, 3, 4, 5]);
19+
let output = model.forward(input);
20+
let expected = TensorData::from([120]);
21+
22+
output.to_data().assert_eq(&expected, true);
23+
}
24+
}
125 Bytes
Binary file not shown.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: size.onnx
4+
5+
import numpy as np
6+
import onnx
7+
import onnx.helper
8+
from onnx import TensorProto
9+
from onnx.reference import ReferenceEvaluator
10+
11+
12+
def build_model():
13+
# Define the graph inputs and outputs
14+
input = onnx.helper.make_tensor_value_info(
15+
'input', TensorProto.FLOAT, [2, 6, 2, 3])
16+
output = onnx.helper.make_tensor_value_info(
17+
'output', TensorProto.FLOAT, [1])
18+
19+
# Create the Size node
20+
size = onnx.helper.make_node(
21+
"Size",
22+
inputs=["input"],
23+
outputs=["output"],
24+
name="SizeNode",
25+
)
26+
27+
# Create the graph
28+
graph = onnx.helper.make_graph(
29+
[size],
30+
'SizeModel',
31+
[input],
32+
[output],
33+
)
34+
35+
# Create the model
36+
model = onnx.helper.make_model(
37+
opset_imports=[onnx.helper.make_operatorsetid("", 16)],
38+
graph=graph,
39+
producer_name='ONNX_Generator',
40+
)
41+
42+
return model
43+
44+
45+
if __name__ == "__main__":
46+
# Set seed and precision
47+
np.random.seed(42)
48+
np.set_printoptions(precision=8)
49+
50+
# Build model
51+
test_input = np.arange(1*2*3*4*5).reshape(1, 2, 3, 4, 5)
52+
onnx_model = build_model()
53+
file_name = "size.onnx"
54+
55+
# Ensure valid ONNX and save
56+
onnx.checker.check_model(onnx_model)
57+
onnx.save(onnx_model, file_name)
58+
print(f"Finished exporting model to {file_name}")
59+
60+
# Output some test data for use in the test
61+
print(f"Test input data shape: {test_input.shape}")
62+
session = ReferenceEvaluator(file_name, verbose=1)
63+
test_output, = session.run(None, {"input": test_input})
64+
print(f"Test output: {repr(test_output)}")

crates/burn-import/src/burn/node/unary.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub enum UnaryNodeKind {
5151
Tanh,
5252
Transpose,
5353
Sign,
54+
Size,
5455
}
5556

5657
impl UnaryNodeKind {
@@ -86,6 +87,7 @@ impl UnaryNodeKind {
8687
Self::Tanh => "tanh",
8788
Self::Transpose => "transpose",
8889
Self::Sign => "sign",
90+
Self::Size => "size",
8991
}
9092
}
9193
}
@@ -513,6 +515,11 @@ impl UnaryNode {
513515
let function = move |input| quote! { #input.sign()};
514516
Self::new(input, output, UnaryNodeKind::Sign, Rc::new(function))
515517
}
518+
519+
pub(crate) fn size(input: Type, output: Type) -> Self {
520+
let function = move |input| quote! { #input.shape.num_elements()};
521+
Self::new(input, output, UnaryNodeKind::Size, Rc::new(function))
522+
}
516523
}
517524

518525
#[cfg(test)]
@@ -1214,4 +1221,23 @@ mod tests {
12141221
vec!["tensor2".to_string()],
12151222
);
12161223
}
1224+
1225+
#[test]
1226+
fn test_unary_codegen_size() {
1227+
one_node_graph(
1228+
UnaryNode::size(
1229+
Type::Tensor(TensorType::new_float("tensor1", 4)),
1230+
Type::Scalar(ScalarType::new("scalar1", ScalarKind::Int64)),
1231+
),
1232+
quote! {
1233+
pub fn forward(&self, tensor1: Tensor<B, 4>) -> i64 {
1234+
let scalar1 = tensor1.shape.num_elements();
1235+
1236+
scalar1
1237+
}
1238+
},
1239+
vec!["tensor1".to_string()],
1240+
vec!["scalar1".to_string()],
1241+
);
1242+
}
12171243
}

crates/burn-import/src/onnx/to_burn.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ impl ParsedOnnxGraph {
368368
NodeType::Sigmoid => graph.register(Self::sigmoid_conversion(node)),
369369
NodeType::Sin => graph.register(Self::sin_conversion(node)),
370370
NodeType::Sinh => graph.register(Self::sinh_conversion(node)),
371+
NodeType::Size => graph.register(Self::size_conversion(node)),
371372
NodeType::Slice => graph.register(Self::slice_conversion(node)),
372373
NodeType::SpaceToDepth => graph.register(Self::space_to_depth_conversion(node)),
373374
NodeType::Sum => graph.register(Self::sum_conversion(node)),
@@ -908,6 +909,13 @@ impl ParsedOnnxGraph {
908909
UnaryNode::sinh(input, output)
909910
}
910911

912+
fn size_conversion(node: Node) -> UnaryNode {
913+
let input = Type::from(node.inputs.first().unwrap());
914+
let output = Type::from(node.outputs.first().unwrap());
915+
916+
UnaryNode::size(input, output)
917+
}
918+
911919
fn slice_conversion(node: Node) -> SliceNode {
912920
let input = Type::from(node.inputs.first().unwrap());
913921
let output = Type::from(node.outputs.first().unwrap());

crates/onnx-ir/src/node/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ pub mod reduce_sum;
5757
pub mod reshape;
5858
pub mod resize;
5959
pub mod shape;
60+
pub mod size;
6061
pub mod slice;
6162
pub mod softmax;
6263
pub mod space_to_depth;

crates/onnx-ir/src/node/size.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
use crate::ir::{ArgType, ElementType, Node};
2+
3+
/// Update output type for Size (always scalar).
4+
pub fn size_update_outputs(node: &mut Node) {
5+
log::debug!("Size rank inference for node {}", node.name);
6+
7+
assert_eq!(
8+
node.inputs.len(),
9+
1,
10+
"Size: expected 1 input, found {}",
11+
node.inputs.len()
12+
);
13+
14+
node.outputs[0].ty = ArgType::Scalar(ElementType::Int64);
15+
}
16+
17+
#[cfg(test)]
18+
mod tests {
19+
use super::*;
20+
use crate::ir::NodeType;
21+
use crate::node::test_utils::NodeBuilder;
22+
23+
fn create_test_node(rank: usize) -> Node {
24+
let builder = NodeBuilder::new(NodeType::Size, "test_size")
25+
.input_tensor_f32("data", rank, None)
26+
.output_scalar_i64("size");
27+
28+
builder.build()
29+
}
30+
31+
#[test]
32+
fn test_size_update_outputs() {
33+
let mut node = create_test_node(4);
34+
size_update_outputs(&mut node);
35+
assert!(matches!(
36+
&node.outputs[0].ty,
37+
ArgType::Scalar(ElementType::Int64)
38+
));
39+
}
40+
}

crates/onnx-ir/src/rank_inference.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::{
1212
reduce_max::reduce_max_update_outputs, reduce_mean::reduce_mean_update_outputs,
1313
reduce_min::reduce_min_update_outputs, reduce_prod::reduce_prod_update_outputs,
1414
reduce_sum::reduce_sum_update_outputs, reshape::reshape_update_outputs,
15-
shape::shape_update_outputs, slice::slice_update_output_rank,
15+
shape::shape_update_outputs, size::size_update_outputs, slice::slice_update_output_rank,
1616
space_to_depth::space_to_depth_update_outputs, split::split_update_outputs,
1717
squeeze::squeeze_update_output, topk::top_k_update_output,
1818
unsqueeze::unsqueeze_update_output, where_op::where_update_outputs,
@@ -104,6 +104,7 @@ pub fn rank_inference(node: &mut Node) {
104104
NodeType::Sign => same_as_input(node),
105105
NodeType::Sin => same_as_input(node),
106106
NodeType::Sinh => same_as_input(node),
107+
NodeType::Size => size_update_outputs(node),
107108
NodeType::Slice => slice_update_output_rank(node),
108109
NodeType::Softmax => same_as_input(node),
109110
NodeType::SpaceToDepth => space_to_depth_update_outputs(node),

0 commit comments

Comments
 (0)