
Commit 9dd5cf2

feat: add onnx ArgMin node
1 parent: 1c1042d

12 files changed (+338, -13 lines)

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ fn main() {
         .input("tests/and/and.onnx")
         .input("tests/add/add_int.onnx")
         .input("tests/argmax/argmax.onnx")
+        .input("tests/argmin/argmin.onnx")
         .input("tests/avg_pool1d/avg_pool1d.onnx")
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
crates/burn-import/onnx-tests/tests/argmin/argmin.onnx

Binary file not shown.
crates/burn-import/onnx-tests/tests/argmin/argmin.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/argmin/argmin.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, argmin_dim: int = 0):
        super(Model, self).__init__()
        self._argmin_dim = argmin_dim

    def forward(self, x):
        # Note: only keepdim=True is supported in burn
        y = torch.argmin(input=x, dim=self._argmin_dim, keepdim=True)
        return y


def main():
    # Export to onnx
    model = Model(1)
    model.eval()
    device = torch.device("cpu")
    onnx_name = "argmin.onnx"
    dummy_input = torch.randn((3, 4), device=device)
    torch.onnx.export(model, dummy_input, onnx_name,
                      verbose=False, opset_version=16)

    print("Finished exporting model to {}".format(onnx_name))

    # Output some test data for use in the test
    test_input = torch.randn((2, 3), device=device)
    print("Test input data shape: {}".format(test_input.shape))
    print("Test input data:\n{}".format(test_input))
    output = model.forward(test_input)

    print("Test output data shape: {}".format(output.shape))
    print("Test output data:\n{}".format(output))


if __name__ == '__main__':
    main()
crates/burn-import/onnx-tests/tests/argmin/mod.rs

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
// Import the shared macro
use crate::include_models;
include_models!(argmin);

#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Tensor, TensorData};

    type Backend = burn_ndarray::NdArray<f32>;

    #[test]
    fn argmin() {
        // Initialize the model with weights (loaded from the exported file)
        let model: argmin::Model<Backend> = argmin::Model::default();

        let device = Default::default();
        // Run the model
        let input = Tensor::<Backend, 2>::from_floats(
            [[1.6124, 1.0463, -1.3808], [-0.3852, 0.1301, 0.9780]],
            &device,
        );
        let output = model.forward(input);
        let expected = TensorData::from([[2i64], [0]]);

        output.to_data().assert_eq(&expected, true);
    }
}

crates/burn-import/onnx-tests/tests/test_mod.rs

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ extern crate alloc;
 pub mod add;
 pub mod and;
 pub mod argmax;
+pub mod argmin;
 pub mod avg_pool;
 pub mod batch_norm;
 pub mod cast;
crates/burn-import/src/burn/node/argmin.rs

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
use super::{Node, NodeCodegen};
use crate::burn::{TensorKind, TensorType, ToTokens, Type};

use burn::record::PrecisionSettings;
use quote::quote;

#[derive(Debug, Clone, new)]
pub struct ArgMinNode {
    pub input: TensorType,
    pub output: TensorType,
    pub axis: usize,
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for ArgMinNode {
    fn output_types(&self) -> Vec<Type> {
        let mut output = self.output.clone();
        output.kind = TensorKind::Int;
        vec![Type::Tensor(output)]
    }

    fn input_types(&self) -> Vec<crate::burn::Type> {
        vec![Type::Tensor(self.input.clone())]
    }

    fn forward(
        &self,
        scope: &mut crate::burn::Scope,
        node_position: usize,
    ) -> proc_macro2::TokenStream {
        // NOTE: select_last_index and keep_dims are not supported
        let axis = self.axis.to_tokens();

        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;

        quote! {
            let #output = #input.argmin(#axis);
        }
    }

    fn into_node(self) -> super::Node<PS> {
        Node::ArgMin(self)
    }
}

#[cfg(test)]
mod tests {

    use burn::record::FullPrecisionSettings;

    use super::*;
    use crate::burn::{TensorType, graph::BurnGraph, node::test::assert_tokens};

    #[test]
    fn test_codegen_argmin() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(ArgMinNode::new(
            TensorType::new_float("tensor1", 2),
            TensorType::new_int("tensor2", 2),
            1,
        ));

        graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);

        let expected = quote! {
            use burn::tensor::Int;
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };

            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }

            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    Self {
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }

                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(
                    &self,
                    tensor1: Tensor<B, 2>
                ) -> Tensor<B, 2, Int> {
                    let tensor2 = tensor1.argmin(1);

                    tensor2
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
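
As the NOTE in forward() says, select_last_index and keepdims are not handled; the generated code simply maps ONNX ArgMin onto Burn's Tensor::argmin(dim), which returns an Int tensor and keeps the reduced dimension with size 1 (matching ONNX's default keepdims=1). A small usage sketch against the burn-ndarray backend used by this commit's tests:

use burn::tensor::Tensor;

type Backend = burn_ndarray::NdArray<f32>;

fn main() {
    let device = Default::default();
    // 2x3 input; argmin over dim 1 keeps that dim, so the output shape is [2, 1].
    let x = Tensor::<Backend, 2>::from_floats(
        [[3.0, 1.0, 2.0], [0.5, 4.0, -1.0]],
        &device,
    );
    let indices = x.argmin(1); // Tensor<Backend, 2, Int> holding [[1], [2]]
    println!("{indices}");
}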

crates/burn-import/src/burn/node/base.rs

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,7 @@
 use std::marker::PhantomData;

 use super::{
-    argmax::ArgMaxNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
+    argmax::ArgMaxNode, argmin::ArgMinNode, avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode,
     batch_norm::BatchNormNode, binary::BinaryNode, ceil::CeilNode, clip::ClipNode,
     concat::ConcatNode, constant::ConstantNode, constant_of_shape::ConstantOfShapeNode,
     conv_transpose_1d::ConvTranspose1dNode, conv_transpose_2d::ConvTranspose2dNode,

@@ -86,6 +86,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {
 #[derive(Debug, Clone)]
 pub enum Node<PS: PrecisionSettings> {
     ArgMax(ArgMaxNode),
+    ArgMin(ArgMinNode),
     AvgPool1d(AvgPool1dNode),
     AvgPool2d(AvgPool2dNode),
     BatchNorm(BatchNormNode),

@@ -147,6 +148,7 @@ macro_rules! match_all {
         #[allow(clippy::redundant_closure_call)]
         match $self {
             Node::ArgMax(node) => $func(node),
+            Node::ArgMin(node) => $func(node),
             Node::AvgPool1d(node) => $func(node),
             Node::AvgPool2d(node) => $func(node),
             Node::BatchNorm(node) => $func(node),

@@ -216,6 +218,7 @@ impl<PS: PrecisionSettings> Node<PS> {
     pub fn name(&self) -> &str {
         match self {
             Node::ArgMax(_) => "argmax",
+            Node::ArgMin(_) => "argmin",
             Node::AvgPool1d(_) => "avg_pool1d",
             Node::AvgPool2d(_) => "avg_pool2d",
             Node::BatchNorm(_) => "batch_norm",
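
A new node variant has to be threaded through three places in base.rs: the Node enum, the match_all! dispatch macro, and name(). The macro lets the blanket NodeCodegen impl for Node forward each trait method to the concrete variant without repeating the match arms; roughly like this (a sketch with the method list abridged, so the actual forwarding code in base.rs may differ):

impl<PS: PrecisionSettings> NodeCodegen<PS> for Node<PS> {
    fn output_types(&self) -> Vec<Type> {
        // Expands to a match over every variant, calling the closure on each.
        match_all!(self, |node| NodeCodegen::<PS>::output_types(node))
    }

    fn input_types(&self) -> Vec<Type> {
        match_all!(self, |node| NodeCodegen::<PS>::input_types(node))
    }
}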

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 mod base;

 pub(crate) mod argmax;
+pub(crate) mod argmin;
 pub(crate) mod avg_pool1d;
 pub(crate) mod avg_pool2d;
 pub(crate) mod batch_norm;

crates/burn-import/src/onnx/to_burn.rs

Lines changed: 22 additions & 11 deletions

@@ -17,6 +17,7 @@ use crate::{
     graph::BurnGraph,
     node::{
         argmax::ArgMaxNode,
+        argmin::ArgMinNode,
         avg_pool1d::AvgPool1dNode,
         avg_pool2d::AvgPool2dNode,
         batch_norm::BatchNormNode,

@@ -80,17 +81,18 @@ use onnx_ir::{
         TensorType as OnnxTensorType,
     },
     node::{
-        argmax::argmax_config, avg_pool1d::avg_pool1d_config, avg_pool2d::avg_pool2d_config,
-        batch_norm::batch_norm_config, clip::clip_config, concat::concat_config,
-        conv_transpose1d::conv_transpose1d_config, conv_transpose2d::conv_transpose2d_config,
-        conv_transpose3d::conv_transpose3d_config, conv1d::conv1d_config, conv2d::conv2d_config,
-        conv3d::conv3d_config, dropout::dropout_config, expand::expand_config,
-        flatten::flatten_config, gather::gather_config, gemm::gemm_config,
-        group_norm::group_norm_config, hard_sigmoid::hard_sigmoid_config,
-        instance_norm::instance_norm_config, layer_norm::layer_norm_config,
-        leaky_relu::leaky_relu_config, linear::linear_config, log_softmax::log_softmax_config,
-        max_pool1d::max_pool1d_config, max_pool2d::max_pool2d_config, one_hot::one_hot_config,
-        pad::pad_config, reduce_max::reduce_max_config, reduce_mean::reduce_mean_config,
+        argmax::argmax_config, argmin::argmin_config, avg_pool1d::avg_pool1d_config,
+        avg_pool2d::avg_pool2d_config, batch_norm::batch_norm_config, clip::clip_config,
+        concat::concat_config, conv_transpose1d::conv_transpose1d_config,
+        conv_transpose2d::conv_transpose2d_config, conv_transpose3d::conv_transpose3d_config,
+        conv1d::conv1d_config, conv2d::conv2d_config, conv3d::conv3d_config,
+        dropout::dropout_config, expand::expand_config, flatten::flatten_config,
+        gather::gather_config, gemm::gemm_config, group_norm::group_norm_config,
+        hard_sigmoid::hard_sigmoid_config, instance_norm::instance_norm_config,
+        layer_norm::layer_norm_config, leaky_relu::leaky_relu_config, linear::linear_config,
+        log_softmax::log_softmax_config, max_pool1d::max_pool1d_config,
+        max_pool2d::max_pool2d_config, one_hot::one_hot_config, pad::pad_config,
+        reduce_max::reduce_max_config, reduce_mean::reduce_mean_config,
         reduce_min::reduce_min_config, reduce_prod::reduce_prod_config,
         reduce_sum::reduce_sum_config, reshape::reshape_config, resize::resize_config,
         slice::slice_config, softmax::softmax_config, split::split_config, squeeze::squeeze_config,

@@ -287,6 +289,7 @@ impl ParsedOnnxGraph {
         match node.node_type {
             NodeType::Add => graph.register(Self::add_conversion(node)),
             NodeType::ArgMax => graph.register(Self::argmax_conversion(node)),
+            NodeType::ArgMin => graph.register(Self::argmin_conversion(node)),
             NodeType::Sub => graph.register(Self::sub_conversion(node)),
             NodeType::Mul => graph.register(Self::mul_conversion(node)),
             NodeType::Div => graph.register(Self::div_conversion(node)),

@@ -968,6 +971,14 @@ impl ParsedOnnxGraph {
         ArgMaxNode::new(input, output, axis)
     }

+    fn argmin_conversion(node: Node) -> ArgMinNode {
+        let input = TensorType::from(node.inputs.first().unwrap());
+        let output = TensorType::from(node.outputs.first().unwrap());
+        let axis = argmin_config(&node);
+
+        ArgMinNode::new(input, output, axis)
+    }
+
     fn concat_conversion(node: Node) -> ConcatNode {
         let inputs = node.inputs.iter().map(TensorType::from).collect();
