Skip to content

Commit c27b0eb

Browse files
committed
feat: add support for SpaceToDepth onnx node
1 parent e3afafd commit c27b0eb

File tree

13 files changed

+453
-21
lines changed

13 files changed

+453
-21
lines changed

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ fn main() {
122122
.input("tests/slice/slice.onnx")
123123
.input("tests/slice/slice_shape.onnx")
124124
.input("tests/softmax/softmax.onnx")
125+
.input("tests/space_to_depth/space_to_depth.onnx")
125126
.input("tests/sqrt/sqrt.onnx")
126127
.input("tests/squeeze/squeeze_multiple.onnx")
127128
.input("tests/squeeze/squeeze.onnx")
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Import the shared macro
2+
use crate::include_models;
3+
include_models!(space_to_depth);
4+
5+
#[cfg(test)]
6+
mod tests {
7+
use super::*;
8+
use burn::tensor::{Tensor, TensorData, Tolerance, ops::FloatElem};
9+
10+
type Backend = burn_ndarray::NdArray<f32>;
11+
type FT = FloatElem<Backend>;
12+
13+
#[test]
14+
fn space_to_depth() {
15+
let device = Default::default();
16+
let model: space_to_depth::Model<Backend> = space_to_depth::Model::new(&device);
17+
18+
let input = Tensor::<Backend, 4>::from_floats(
19+
[
20+
[[
21+
[0.5, -0.14, 0.65, 1.52, -0.23, -0.23],
22+
[1.58, 0.77, -0.47, 0.54, -0.46, -0.47],
23+
[0.24, -1.91, -1.72, -0.56, -1.01, 0.31],
24+
[-0.91, -1.41, 1.47, -0.23, 0.07, -1.42],
25+
]],
26+
[[
27+
[-0.54, 0.11, -1.15, 0.38, -0.6, -0.29],
28+
[-0.6, 1.85, -0.01, -1.06, 0.82, -1.22],
29+
[0.21, -1.96, -1.33, 0.2, 0.74, 0.17],
30+
[-0.12, -0.3, -1.48, -0.72, -0.46, 1.06],
31+
]],
32+
],
33+
&device,
34+
);
35+
let output = model.forward(input);
36+
let expected = TensorData::from([
37+
[
38+
[[0.5, 0.65, -0.23], [0.24, -1.72, -1.01]],
39+
[[-0.14, 1.52, -0.23], [-1.91, -0.56, 0.31]],
40+
[[1.58, -0.47, -0.46], [-0.91, 1.47, 0.07]],
41+
[[0.77, 0.54, -0.47], [-1.41, -0.23, -1.42]],
42+
],
43+
[
44+
[[-0.54, -1.15, -0.6], [0.21, -1.33, 0.74]],
45+
[[0.11, 0.38, -0.29], [-1.96, 0.2, 0.17]],
46+
[[-0.6, -0.01, 0.82], [-0.12, -1.48, -0.46]],
47+
[[1.85, -1.06, -1.22], [-0.3, -0.72, 1.06]],
48+
],
49+
]);
50+
51+
output
52+
.to_data()
53+
.assert_approx_eq::<FT>(&expected, Tolerance::default());
54+
}
55+
}
Binary file not shown.
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate models: space_to_depth_*.onnx
4+
5+
import numpy as np
6+
import onnx
7+
import onnx.helper
8+
from onnx import TensorProto
9+
from onnx.reference import ReferenceEvaluator
10+
11+
12+
def build_model():
13+
# Define the graph inputs and outputs
14+
input = onnx.helper.make_tensor_value_info(
15+
'input', TensorProto.FLOAT, [2, 1, 4, 6])
16+
output = onnx.helper.make_tensor_value_info(
17+
'output', TensorProto.FLOAT, [2, 4, 2, 3])
18+
19+
# Create the SpaceToDepth node
20+
space_to_depth = onnx.helper.make_node(
21+
"SpaceToDepth",
22+
inputs=["input"],
23+
outputs=["output"],
24+
name="SpaceToDepthNode",
25+
blocksize=2,
26+
)
27+
28+
# Create the graph
29+
graph = onnx.helper.make_graph(
30+
[space_to_depth],
31+
'SpaceToDepthModel',
32+
[input],
33+
[output],
34+
)
35+
36+
# Create the model
37+
model = onnx.helper.make_model(
38+
opset_imports=[onnx.helper.make_operatorsetid("", 21)],
39+
graph=graph,
40+
producer_name='ONNX_Generator',
41+
)
42+
43+
return model
44+
45+
46+
if __name__ == "__main__":
47+
# Set seed and precision
48+
np.random.seed(42)
49+
np.set_printoptions(precision=8)
50+
51+
# Build model
52+
test_input = np.random.randn(2, 1, 4, 6).round(2)
53+
onnx_model = build_model()
54+
file_name = "space_to_depth.onnx"
55+
56+
# Ensure valid ONNX and save
57+
onnx.checker.check_model(onnx_model)
58+
onnx.save(onnx_model, file_name)
59+
print(f"Finished exporting model to {file_name}")
60+
61+
# Output some test data for use in the test
62+
print(f"Test input data:\n{repr(test_input)}")
63+
print(f"Test input data shape: {test_input.shape}")
64+
session = ReferenceEvaluator(file_name, verbose=1)
65+
test_output, = session.run(None, {"input": test_input})
66+
print(f"Test output:\n{repr(test_output)}")
67+
print(f"Test output shape: {test_output.shape}")

crates/burn-import/onnx-tests/tests/test_mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ pub mod conv;
1818
pub mod conv_transpose;
1919
pub mod cos;
2020
pub mod cosh;
21+
pub mod depth_to_space;
2122
pub mod div;
2223
pub mod dropout;
2324
pub mod equal;
2425
pub mod erf;
2526
pub mod exp;
26-
pub mod depth_to_space;
2727
pub mod expand;
2828
pub mod flatten;
2929
pub mod floor;
@@ -80,6 +80,7 @@ pub mod sin;
8080
pub mod sinh;
8181
pub mod slice;
8282
pub mod softmax;
83+
pub mod space_to_depth;
8384
pub mod split;
8485
pub mod sqrt;
8586
pub mod squeeze;

crates/burn-import/src/burn/node/base.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ use super::{
66
concat::ConcatNode, constant::ConstantNode, constant_of_shape::ConstantOfShapeNode,
77
conv_transpose_1d::ConvTranspose1dNode, conv_transpose_2d::ConvTranspose2dNode,
88
conv_transpose_3d::ConvTranspose3dNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
9-
conv3d::Conv3dNode, dropout::DropoutNode, expand::ExpandNode, floor::FloorNode,
10-
gather::GatherNode, gather_elements::GatherElementsNode, gemm::GemmNode,
9+
conv3d::Conv3dNode, depth_to_space::DepthToSpaceNode, dropout::DropoutNode, expand::ExpandNode,
10+
floor::FloorNode, gather::GatherNode, gather_elements::GatherElementsNode, gemm::GemmNode,
1111
global_avg_pool::GlobalAvgPoolNode, group_norm::GroupNormNode, instance_norm::InstanceNormNode,
1212
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
1313
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, mean::MeanNode, one_hot::OneHotNode,
@@ -17,9 +17,8 @@ use super::{
1717
resize::ResizeNode, round::RoundNode, slice::SliceNode, split::SplitNode, squeeze::SqueezeNode,
1818
sum::SumNode, tile::TileNode, top_k::TopKNode, trilu::TriluNode, unary::UnaryNode,
1919
unsqueeze::UnsqueezeNode,
20-
depth_to_space::DepthToSpaceNode
2120
};
22-
use crate::burn::{BurnImports, Scope, Type};
21+
use crate::burn::{BurnImports, Scope, Type, node::space_to_depth::SpaceToDepthNode};
2322
use burn::record::PrecisionSettings;
2423
use proc_macro2::TokenStream;
2524
use serde::Serialize;
@@ -126,6 +125,7 @@ pub enum Node<PS: PrecisionSettings> {
126125
Round(RoundNode),
127126
Slice(SliceNode),
128127
Squeeze(SqueezeNode),
128+
SpaceToDepth(SpaceToDepthNode),
129129
Split(SplitNode),
130130
Sum(SumNode),
131131
Tile(TileNode),
@@ -187,6 +187,7 @@ macro_rules! match_all {
187187
Node::Resize(node) => $func(node),
188188
Node::Round(node) => $func(node),
189189
Node::Slice(node) => $func(node),
190+
Node::SpaceToDepth(node) => $func(node),
190191
Node::Squeeze(node) => $func(node),
191192
Node::Sum(node) => $func(node),
192193
Node::Tile(node) => $func(node),
@@ -257,6 +258,7 @@ impl<PS: PrecisionSettings> Node<PS> {
257258
Node::Resize(_) => "resize",
258259
Node::Round(_) => "round",
259260
Node::Slice(_) => "slice",
261+
Node::SpaceToDepth(_) => "space_to_depth",
260262
Node::Squeeze(_) => "squeeze",
261263
Node::Sum(_) => "add",
262264
Node::Tile(_) => "tile",
@@ -399,10 +401,9 @@ pub(crate) mod tests {
399401
),
400402
));
401403

402-
graph.register_input_output(
403-
vec!["tensor1".to_string(), "tensor2".to_string()],
404-
vec!["tensor4".to_string()],
405-
);
404+
graph.register_input_output(vec!["tensor1".to_string(), "tensor2".to_string()], vec![
405+
"tensor4".to_string(),
406+
]);
406407

407408
let expected = quote! {
408409
use burn::{
@@ -485,10 +486,9 @@ pub(crate) mod tests {
485486
TensorType::new_float("output", 4),
486487
));
487488

488-
graph.register_input_output(
489-
vec!["tensor1".to_string(), "tensor2".to_string()],
490-
vec!["output".to_string()],
491-
);
489+
graph.register_input_output(vec!["tensor1".to_string(), "tensor2".to_string()], vec![
490+
"output".to_string(),
491+
]);
492492

493493
let expected = quote! {
494494
use burn::{

crates/burn-import/src/burn/node/depth_to_space.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ mod tests {
8080

8181
graph.register(DepthToSpaceNode::new(
8282
TensorType::new_float("input", 4),
83-
TensorType::new_float("output", 5),
83+
TensorType::new_float("output", 4),
8484
DepthToSpaceConfig::new(DepthToSpaceMode::DCR, 2),
8585
));
8686

@@ -105,7 +105,7 @@ mod tests {
105105
}
106106
}
107107
#[allow(clippy::let_and_return, clippy::approx_constant)]
108-
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 5> {
108+
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
109109
let output = {
110110
let [b, c, h, w] = input.shape().dims();
111111
input
@@ -152,7 +152,7 @@ mod tests {
152152
}
153153
}
154154
#[allow(clippy::let_and_return, clippy::approx_constant)]
155-
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 5> {
155+
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
156156
let output = {
157157
let [b, c, h, w] = input.shape().dims();
158158
input

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub(crate) mod reshape;
4545
pub(crate) mod resize;
4646
pub(crate) mod round;
4747
pub(crate) mod slice;
48+
pub(crate) mod space_to_depth;
4849
pub(crate) mod split;
4950
pub(crate) mod squeeze;
5051
pub(crate) mod sum;
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
use super::{Node, NodeCodegen};
2+
use crate::burn::{Scope, TensorType, Type};
3+
use burn::record::PrecisionSettings;
4+
use proc_macro2::TokenStream;
5+
use quote::quote;
6+
7+
#[derive(Debug, Clone)]
8+
pub struct SpaceToDepthNode {
9+
pub input: TensorType,
10+
pub output: TensorType,
11+
pub block_size: usize,
12+
}
13+
14+
impl SpaceToDepthNode {
15+
pub fn new(input: TensorType, output: TensorType, block_size: usize) -> Self {
16+
Self {
17+
input,
18+
output,
19+
block_size,
20+
}
21+
}
22+
}
23+
24+
impl<PS: PrecisionSettings> NodeCodegen<PS> for SpaceToDepthNode {
25+
fn input_types(&self) -> Vec<Type> {
26+
vec![Type::Tensor(self.input.clone())]
27+
}
28+
29+
fn output_types(&self) -> Vec<Type> {
30+
vec![Type::Tensor(self.output.clone())]
31+
}
32+
33+
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
34+
let input = scope.tensor_use_owned(&self.input, node_position);
35+
let output = &self.output.name;
36+
let block_size = self.block_size;
37+
38+
quote! {
39+
let #output = {
40+
let [b, c, h, w] = #input.shape().dims();
41+
#input
42+
.reshape([b, c, h / #block_size, #block_size, w / #block_size, #block_size])
43+
.permute([0, 3, 5, 1, 2, 4])
44+
.reshape([b, c * #block_size * #block_size, h / #block_size, w / #block_size])
45+
};
46+
}
47+
}
48+
49+
fn into_node(self) -> Node<PS> {
50+
Node::SpaceToDepth(self)
51+
}
52+
}
53+
54+
#[cfg(test)]
55+
mod tests {
56+
use super::*;
57+
use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens};
58+
use burn::record::FullPrecisionSettings;
59+
60+
#[test]
61+
fn test_codegen() {
62+
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
63+
64+
graph.register(SpaceToDepthNode::new(
65+
TensorType::new_float("input", 4),
66+
TensorType::new_float("output", 4),
67+
2,
68+
));
69+
70+
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
71+
72+
let expected = quote! {
73+
use burn::{
74+
module::Module,
75+
tensor::{backend::Backend, Tensor},
76+
};
77+
#[derive(Module, Debug)]
78+
pub struct Model<B: Backend> {
79+
phantom: core::marker::PhantomData<B>,
80+
device: burn::module::Ignored<B::Device>,
81+
}
82+
impl<B: Backend> Model<B> {
83+
#[allow(unused_variables)]
84+
pub fn new(device: &B::Device) -> Self {
85+
Self {
86+
phantom: core::marker::PhantomData,
87+
device: burn::module::Ignored(device.clone()),
88+
}
89+
}
90+
#[allow(clippy::let_and_return, clippy::approx_constant)]
91+
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
92+
let output = {
93+
let [b, c, h, w] = input.shape().dims();
94+
input
95+
.reshape([b, c, h / 2usize, 2usize, w / 2usize, 2usize])
96+
.permute([0, 3, 5, 1, 2, 4])
97+
.reshape([b, c * 2usize * 2usize, h / 2usize, w / 2usize])
98+
};
99+
output
100+
}
101+
}
102+
};
103+
104+
assert_tokens(graph.codegen(), expected);
105+
}
106+
}

0 commit comments

Comments
 (0)