Skip to content

Fix ONNX node name sanitization and add ai.onnx.ml domain support #3371

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 66 additions & 6 deletions crates/burn-import/src/burn/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,24 @@ pub enum Type {

impl Type {
/// Converts an arbitrary name into a string that can be used as a Rust identifier.
///
/// Names may be number literals or contain characters that are not valid in an
/// identifier (`/`, `:`, `-`, `.`, spaces, non-ASCII characters, ...), so:
///
/// * every character that is not an ASCII letter, an ASCII digit or `_` is replaced by `_`;
/// * a leading `_` is added when the sanitized name would otherwise start with a digit
///   or be empty (`"123"` becomes `"_123"`, `""` becomes `"_"`).
///
/// Distinct inputs can map to the same output: `"a/b"` and `"a:b"` both become `"a_b"`.
// TODO (antimora) push the name sanitization upstream to onnx-ir; this is simple for now
pub fn format_name(name: &str) -> String {
    // After sanitization the first character is either kept as-is or turned into `_`,
    // so the only invalid starts are an ASCII digit or an empty name.
    let needs_prefix = name.chars().next().map_or(true, |c| c.is_ascii_digit());

    // One extra byte for the optional leading underscore.
    let mut result = String::with_capacity(name.len() + 1);
    if needs_prefix {
        result.push('_');
    }
    result.extend(name.chars().map(|c| {
        if c.is_ascii_alphanumeric() || c == '_' {
            c
        } else {
            '_'
        }
    }));
    result
}
pub fn name(&self) -> &Ident {
match self {
Expand Down Expand Up @@ -253,3 +263,53 @@ impl OtherType {
self.ty.clone()
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds each `(raw, expected)` pair through `Type::format_name` and compares the result.
    fn assert_sanitized(cases: &[(&str, &str)]) {
        for &(raw, expected) in cases {
            assert_eq!(Type::format_name(raw), expected);
        }
    }

    #[test]
    fn test_format_name_with_problematic_characters() {
        // Node name taken from GitHub issue #2878: it contains slashes and a colon.
        assert_sanitized(&[(
            "jax2tf_rhs_/pjit_silu_/Const_2:0",
            "jax2tf_rhs__pjit_silu__Const_2_0",
        )]);
    }

    #[test]
    fn test_format_name_edge_cases() {
        // Purely numeric names, assorted separators, and invalid leading characters.
        assert_sanitized(&[
            ("normal_name", "normal_name"),
            ("123", "_123"),
            ("name:with:colons", "name_with_colons"),
            ("name/with/slashes", "name_with_slashes"),
            ("name-with-dashes", "name_with_dashes"),
            ("name.with.dots", "name_with_dots"),
            ("name with spaces", "name_with_spaces"),
            ("9starts_with_number", "_9starts_with_number"),
            (":starts_with_colon", "_starts_with_colon"),
        ]);
    }

    #[test]
    fn test_format_name_preserves_valid_identifiers() {
        // Names that are already valid identifiers must come back unchanged.
        let already_valid = [
            "valid_name",
            "_underscore_start",
            "CamelCase",
            "snake_case",
            "name123",
        ];
        for name in already_valid {
            assert_eq!(Type::format_name(name), name);
        }
    }

    #[test]
    fn test_tensor_type_creation_with_problematic_name() {
        // The name exposed by a TensorType built from a problematic name is expected
        // to be the sanitized form.
        let raw_name = "jax2tf_rhs_/pjit_silu_/Const_2:0";
        let tensor = TensorType::new(raw_name, 2, TensorKind::Float);
        assert_eq!(tensor.name.to_string(), "jax2tf_rhs__pjit_silu__Const_2_0");
    }
}
20 changes: 13 additions & 7 deletions crates/onnx-ir/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,21 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
///
/// # Panics
///
/// * If the domain is not supported (anything other than the standard ONNX domain `""` or `"ai.onnx.ml"`)
pub fn check_opset_version(opset: &OperatorSetIdProto, min_version: i64) -> bool {
// For now, only empty domain (standard ONNX operators) is supported
if !opset.domain.is_empty() {
panic!("Only the standard ONNX domain is supported");
match opset.domain.as_str() {
// Standard ONNX operators
"" => opset.version >= min_version,
// ONNX ML operators - commonly used for traditional ML operators
"ai.onnx.ml" => opset.version >= 1, // ML operators are generally stable from version 1
// Add support for other domains as needed
_ => {
panic!(
"Unsupported ONNX domain: '{}'. Only standard ONNX ('') and ML ('ai.onnx.ml') domains are supported",
opset.domain
);
}
}

// Return true if the opset version is greater than or equal to min_version
opset.version >= min_version
}

/// Verify that all operator sets in a model are supported.
Expand Down
Loading