Skip to content

Commit 800e224

Browse files
authored
Fix ONNX node name sanitization and allow ai.onnx.ml domain (#3371)
* Fix ONNX node name sanitization to handle special characters

  Sanitizes node names containing invalid identifier characters like ':' and '/' by replacing them with underscores, preventing a panic during ONNX import.

* Expand ONNX domain support in opset version check

  The check_opset_version function now supports both the standard ONNX ('') and ML ('ai.onnx.ml') domains, panicking for unsupported domains. This improves compatibility with models using ML operators.

* Format panic message for unsupported ONNX domain

* Refactor Type::format_name for identifier sanitization

  Simplifies and improves the logic for sanitizing names to valid Rust identifiers. The new implementation replaces invalid characters with underscores and ensures the name starts with a valid character, removing special handling for numeric names.

* Document unsupported ai.onnx.ml domain operators

  Added a section listing ONNX ML domain operators that are currently not supported for import or Burn support, along with reference links for each operator.

* Remove unsupported ai.onnx.ml operators section
1 parent e2b8dc5 commit 800e224

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

crates/burn-import/src/burn/ty.rs

Lines changed: 66 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -63,14 +63,24 @@ pub enum Type {
6363

6464
impl Type {
6565
// This is used, because types might have number literal name, which cannot be
66-
// used as a variable name.
66+
// used as a variable name, or contain invalid identifier characters.
67+
// TODO (antimora) push the name sanitization upstream to onnx-ir; this is simple for now
6768
pub fn format_name(name: &str) -> String {
68-
let name_is_number = name.bytes().all(|digit| digit.is_ascii_digit());
69-
if name_is_number {
70-
format!("_{name}")
71-
} else {
72-
name.to_string()
69+
let mut result = String::with_capacity(name.len());
70+
// Sanitize the name by replacing invalid identifier characters with underscores
71+
for c in name.chars() {
72+
if c.is_ascii_alphanumeric() || c == '_' {
73+
result.push(c);
74+
} else {
75+
result.push('_');
76+
}
77+
}
78+
79+
// Ensure the first character is valid to start an identifier
80+
if !result.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_') {
81+
result = format!("_{result}");
7382
}
83+
result
7484
}
7585
pub fn name(&self) -> &Ident {
7686
match self {
@@ -253,3 +263,53 @@ impl OtherType {
253263
self.ty.clone()
254264
}
255265
}
266+
267+
#[cfg(test)]
268+
mod tests {
269+
use super::*;
270+
271+
#[test]
272+
fn test_format_name_with_problematic_characters() {
273+
// Test the problematic node name from GitHub issue #2878
274+
let problematic_name = "jax2tf_rhs_/pjit_silu_/Const_2:0";
275+
let sanitized = Type::format_name(problematic_name);
276+
assert_eq!(sanitized, "jax2tf_rhs__pjit_silu__Const_2_0");
277+
}
278+
279+
#[test]
280+
fn test_format_name_edge_cases() {
281+
// Test various edge cases
282+
assert_eq!(Type::format_name("normal_name"), "normal_name");
283+
assert_eq!(Type::format_name("123"), "_123");
284+
assert_eq!(Type::format_name("name:with:colons"), "name_with_colons");
285+
assert_eq!(Type::format_name("name/with/slashes"), "name_with_slashes");
286+
assert_eq!(Type::format_name("name-with-dashes"), "name_with_dashes");
287+
assert_eq!(Type::format_name("name.with.dots"), "name_with_dots");
288+
assert_eq!(Type::format_name("name with spaces"), "name_with_spaces");
289+
assert_eq!(
290+
Type::format_name("9starts_with_number"),
291+
"_9starts_with_number"
292+
);
293+
assert_eq!(
294+
Type::format_name(":starts_with_colon"),
295+
"_starts_with_colon"
296+
);
297+
}
298+
299+
#[test]
300+
fn test_format_name_preserves_valid_identifiers() {
301+
// Test that valid identifiers are preserved
302+
assert_eq!(Type::format_name("valid_name"), "valid_name");
303+
assert_eq!(Type::format_name("_underscore_start"), "_underscore_start");
304+
assert_eq!(Type::format_name("CamelCase"), "CamelCase");
305+
assert_eq!(Type::format_name("snake_case"), "snake_case");
306+
assert_eq!(Type::format_name("name123"), "name123");
307+
}
308+
309+
#[test]
310+
fn test_tensor_type_creation_with_problematic_name() {
311+
// Test that TensorType can be created with problematic names
312+
let tensor = TensorType::new("jax2tf_rhs_/pjit_silu_/Const_2:0", 2, TensorKind::Float);
313+
assert_eq!(tensor.name.to_string(), "jax2tf_rhs__pjit_silu__Const_2_0");
314+
}
315+
}

crates/onnx-ir/src/util.rs

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,21 @@ pub fn shape_config(curr: &Node) -> (usize, usize) {
5353
///
5454
/// # Panics
5555
///
56-
/// * If the domain is not the empty ONNX domain
56+
/// * If the domain is not supported
5757
pub fn check_opset_version(opset: &OperatorSetIdProto, min_version: i64) -> bool {
58-
// For now, only empty domain (standard ONNX operators) is supported
59-
if !opset.domain.is_empty() {
60-
panic!("Only the standard ONNX domain is supported");
58+
match opset.domain.as_str() {
59+
// Standard ONNX operators
60+
"" => opset.version >= min_version,
61+
// ONNX ML operators - commonly used for traditional ML operators
62+
"ai.onnx.ml" => opset.version >= 1, // ML operators are generally stable from version 1
63+
// Add support for other domains as needed
64+
_ => {
65+
panic!(
66+
"Unsupported ONNX domain: '{}'. Only standard ONNX ('') and ML ('ai.onnx.ml') domains are supported",
67+
opset.domain
68+
);
69+
}
6170
}
62-
63-
// Return true if the opset version is greater than or equal to min_version
64-
opset.version >= min_version
6571
}
6672

6773
/// Verify that all operator sets in a model are supported.

0 commit comments

Comments
 (0)