2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -87,7 +87,7 @@ represent the corresponding Burn Op.
| [InstanceNormalization][79] | ❌ | ✅ |
| [IsInf][80] | ❌ | ❌ |
| [IsNaN][81] | ❌ | ❌ |
| [LayerNormalization][82] | ❌ | ✅ |
| [LayerNormalization][82] | ✅ | ✅ |
| [LeakyRelu][83] | ✅ | ✅ |
| [Less][84] | ❌ | ✅ |
| [LessOrEqual][85] | ❌ | ✅ |
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
        .input("tests/gather/gather.onnx")
        .input("tests/gelu/gelu.onnx")
        .input("tests/global_avr_pool/global_avr_pool.onnx")
        .input("tests/layer_norm/layer_norm.onnx")
        .input("tests/linear/linear.onnx")
        .input("tests/log_softmax/log_softmax.onnx")
        .input("tests/log/log.onnx")
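For context, the `.input("tests/layer_norm/layer_norm.onnx")` registration added above feeds burn-import's `ModelGen` builder, which the onnx-tests build script uses to convert each listed ONNX file into generated Rust code. A minimal standalone build.rs following the same documented pattern looks roughly like this (the input path is a placeholder, not part of this PR):

// build.rs (sketch): convert an ONNX file into Burn source code at build time.
use burn_import::onnx::ModelGen;

fn main() {
    ModelGen::new()
        // Placeholder path to the ONNX file to convert.
        .input("src/model/layer_norm.onnx")
        // The generated Rust module is written under OUT_DIR in this directory.
        .out_dir("model/")
        .run_from_script();
}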
Binary file added: crates/burn-import/onnx-tests/tests/layer_norm/layer_norm.onnx (not shown).
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/layer_norm/layer_norm.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/layer_norm/layer_norm.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.norm = nn.LayerNorm(4)

    def forward(self, x):
        return self.norm(x)


def main():
    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Export to onnx
    model = Model()
    model.eval()
    device = torch.device("cpu")

    onnx_name = "layer_norm.onnx"
    test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4)
    # LayerNormalization only appeared in opset 17
    torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=17)

    print(f"Finished exporting model to {onnx_name}")

    # Output some test data for use in the test
    print(f"Test input data: {test_input}")
    output = model.forward(test_input)
    print(f"Test output data: {output}")


if __name__ == "__main__":
    main()
35 changes: 35 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -36,6 +36,7 @@ include_models!(
    gather,
    gelu,
    global_avr_pool,
    layer_norm,
    leaky_relu,
    linear,
    log_softmax,
@@ -600,6 +601,40 @@ mod tests {
        assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
    }

    #[test]
    fn layer_norm() {
        let device = Default::default();
        let model: layer_norm::Model<Backend> = layer_norm::Model::default();

        // Run the model with the same 2x3x4 arange input used when exporting the ONNX file
        let input = Tensor::<Backend, 3>::from_floats(
            [
                [[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],
                [
                    [12., 13., 14., 15.],
                    [16., 17., 18., 19.],
                    [20., 21., 22., 23.],
                ],
            ],
            &device,
        );
        let output = model.forward(input);
        let expected = Data::from([
            [
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
            ],
            [
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
                [-1.3416, -0.4472, 0.4472, 1.3416],
            ],
        ]);

        output.to_data().assert_approx_eq(&expected, 4);
    }
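For reference, the expected values follow directly from layer normalization over the last dimension: with LayerNorm's default affine initialization (weight of ones, bias of zeros), every length-4 row of the arange input has the same deviations from its mean, so every normalized row is identical. For the first row [0, 1, 2, 3], ignoring the small epsilon:

$$\mu = 1.5, \qquad \sigma^2 = \frac{1.5^2 + 0.5^2 + 0.5^2 + 1.5^2}{4} = 1.25, \qquad \frac{x - \mu}{\sqrt{\sigma^2}} \approx [-1.3416,\ -0.4472,\ 0.4472,\ 1.3416].$$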

    #[test]
    fn leaky_relu() {
        // Initialize the model without weights (because the exported file does not contain them)
4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,3 +1,4 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
Expand Down Expand Up @@ -87,6 +88,7 @@ pub enum Node<PS: PrecisionSettings> {
    Dropout(DropoutNode),
    Gather(GatherNode),
    GlobalAvgPool(GlobalAvgPoolNode),
    LayerNorm(LayerNormNode<PS>),
    Linear(LinearNode<PS>),
    Matmul(MatmulNode),
    MaxPool2d(MaxPool2dNode),
@@ -112,6 +114,7 @@ macro_rules! match_all {
            Node::Dropout(node) => $func(node),
            Node::Gather(node) => $func(node),
            Node::GlobalAvgPool(node) => $func(node),
            Node::LayerNorm(node) => $func(node),
            Node::Linear(node) => $func(node),
            Node::Matmul(node) => $func(node),
            Node::MaxPool2d(node) => $func(node),
@@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
            Node::Dropout(_) => "dropout",
            Node::Gather(_) => "gather",
            Node::GlobalAvgPool(_) => "global_avg_pool",
            Node::LayerNorm(_) => "layer_norm",
            Node::Linear(_) => "linear",
            Node::Matmul(_) => "matmul",
            Node::MaxPool2d(_) => "max_pool2d",
177 changes: 177 additions & 0 deletions crates/burn-import/src/burn/node/layer_norm.rs
@@ -0,0 +1,177 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
    module::{ConstantRecord, Param, ParamId},
    nn::{LayerNormConfig, LayerNormRecord},
    record::{PrecisionSettings, Record},
    tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Debug, Clone)]
pub struct LayerNormNode<PS: PrecisionSettings> {
    pub field: OtherType,
    pub input: TensorType,
    pub output: TensorType,
    pub gamma: DataSerialize<PS::FloatElem>, // Scale
    pub beta: Option<DataSerialize<PS::FloatElem>>, // Bias (B)
    pub config: LayerNormConfig,
    pub full_precision: bool,
}

impl<PS: PrecisionSettings> LayerNormNode<PS> {
    pub fn new<S: AsRef<str>>(
        name: S,
        input: TensorType,
        output: TensorType,
        gamma: DataSerialize<PS::FloatElem>,
        beta: Option<DataSerialize<PS::FloatElem>>,
        config: LayerNormConfig,
        full_precision: bool,
    ) -> Self {
        Self {
            field: OtherType::new(
                name,
                quote! {
                    LayerNorm<B>
                },
            ),
            input,
            output,
            gamma,
            beta,
            config,
            full_precision,
        }
    }
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for LayerNormNode<PS> {
    fn input_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.input.clone())]
    }
    fn output_types(&self) -> Vec<Type> {
        vec![Type::Tensor(self.output.clone())]
    }
    fn field_type(&self) -> Option<Type> {
        Some(Type::Other(self.field.clone()))
    }

    fn field_init(&self) -> Option<TokenStream> {
        let name = &self.field.name;
        let num_features = self.config.d_model.to_tokens();
        let epsilon = self.config.epsilon;

        let tokens = quote! {
            let #name = LayerNormConfig::new(#num_features)
                .with_epsilon(#epsilon)
                .init(device);
        };

        Some(tokens)
    }

    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let device = Default::default();
        let record = LayerNormRecord::<SerializationBackend> {
            gamma: Param::initialized(
                ParamId::new(),
                Tensor::from_data(self.gamma.clone().convert(), &device),
            ),
            beta: Param::initialized(
                ParamId::new(),
                if let Some(beta) = self.beta.clone() {
                    Tensor::from_data(beta.convert(), &device)
                } else {
                    Tensor::zeros([self.config.d_model], &device)
                },
            ),
            epsilon: ConstantRecord::new(),
        };

        let item = Record::into_item::<PS>(record);
        item.serialize(serializer)
    }

    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
        let input = scope.tensor_use_owned(&self.input, node_position);
        let output = &self.output.name;
        let field = &self.field.name;

        // TODO: handle self.full_precision
        quote! {
            let #output = self.#field.forward(#input);
        }
Review comment from the PR author on the `full_precision` handling above: was thinking we could match the elem precision and call `into_full_precision()`, but the BackendBridge requires the backend to be specified, and here the backend is generic, so that can't be done right now.
    }
    fn register_imports(&self, imports: &mut BurnImports) {
        imports.register("burn::nn::LayerNorm");
        imports.register("burn::nn::LayerNormConfig");
    }

    fn into_node(self) -> Node<PS> {
        Node::LayerNorm(self)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
    use burn::{record::FullPrecisionSettings, tensor::Data};

    #[test]
    fn test_codegen() {
        let mut graph = BurnGraph::<FullPrecisionSettings>::default();

        graph.register(LayerNormNode::new(
            "norm",
            TensorType::new_float("input", 4),
            TensorType::new_float("output", 4),
            Data::from([2.]).serialize(),
            Some(Data::from([2.]).serialize()),
            LayerNormConfig::new(128),
            true, // full_precision isn't taken into account
        ));

        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

        let expected = quote! {
            use burn::{
                module::Module,
                tensor::{backend::Backend, Tensor},
            };
            use burn::nn::LayerNorm;
            use burn::nn::LayerNormConfig;

            #[derive(Module, Debug)]
            pub struct Model <B: Backend> {
                norm: LayerNorm<B>,
                phantom: core::marker::PhantomData<B>,
            }

            impl<B: Backend> Model <B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
                    let norm = LayerNormConfig::new(128)
                        .with_epsilon(0.00001f64)
                        .init(device);

                    Self {
                        norm,
                        phantom: core::marker::PhantomData,
                    }
                }
                #[allow(clippy::let_and_return, clippy::approx_constant)]
                pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
                    let output = self.norm.forward(input);

                    output
                }
            }
        };

        assert_tokens(graph.codegen(), expected);
    }
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -12,6 +12,7 @@ pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -33,6 +33,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
        NodeType::GatherElements => same_as_input(node),
        NodeType::GlobalAveragePool => same_as_input(node),
        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
        NodeType::LayerNormalization => same_as_input(node),
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Log => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),
42 changes: 39 additions & 3 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -1,8 +1,8 @@
use burn::nn::{
conv::Conv1dConfig,
conv::{Conv2dConfig, ConvTranspose2dConfig},
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
pool::{AvgPool2dConfig, MaxPool2dConfig},
BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d,
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
PaddingConfig2d,
};

use super::ir::{ArgType, AttributeValue, Data, Node};
Expand Down Expand Up @@ -465,6 +465,42 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig {
        .with_momentum(momentum as f64)
}

/// Create a LayerNormConfig from the attributes of the node
pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
    // Extract the shape of the weight tensor
    let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty {
        tensor_type
    } else {
        panic!("LayerNorm: weight tensor must be present");
    };

    let num_features: usize = tensor_type.shape.clone().unwrap()[0];

    // When `stash_type` is `1` (default), perform operations in 32-bit float and
    // cast the results back to original dtype
    let mut stash_type = 1;
    let mut axis = -1;
    let mut epsilon = 1e-5;

    for (key, value) in node.attrs.iter() {
        match key.as_str() {
            "axis" => axis = value.clone().into_i64(),
            "epsilon" => epsilon = value.clone().into_f32(),
            "stash_type" => stash_type = value.clone().into_i64(),
            _ => {}
        }
    }

    if axis != -1 && axis != tensor_type.dim as i64 - 1 {
        panic!("LayerNorm: normalization is only supported on the last axis right now")
    }

    (
        LayerNormConfig::new(num_features).with_epsilon(epsilon as f64),
        stash_type == 1,
    )
}
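As a quick illustration of this mapping (hypothetical attribute values, not taken from the test model above): an ONNX LayerNormalization node whose Scale input has shape [4] and carries epsilon = 1e-3, with the default axis = -1 and stash_type = 1, would produce a result equivalent to the sketch below.

// Sketch only: the (config, full_precision) pair layer_norm_config would return
// for the hypothetical node described above.
use burn::nn::LayerNormConfig;

fn expected_layer_norm_config() -> (LayerNormConfig, bool) {
    (
        // num_features comes from the Scale tensor's shape, epsilon from the attribute.
        LayerNormConfig::new(4).with_epsilon(1e-3),
        // stash_type == 1 (the default), so the full-precision flag is true.
        true,
    )
}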

/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling.
///
/// # Arguments