diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index 987fade40a..4c7e1851cd 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -296,7 +296,7 @@ impl TensorCheck { check = check.register( "Flatten", TensorError::new(format!( - "The start dim ({start_dim}) must be smaller than the end dim ({end_dim})" + "The start dim ({start_dim}) must be smaller than or equal to the end dim ({end_dim})" )), ); } @@ -304,7 +304,9 @@ impl TensorCheck { if D2 > D1 { check = check.register( "Flatten", - TensorError::new(format!("Result dim ({D2}) must be smaller than ({D1})")), + TensorError::new(format!( + "Result dim ({D2}) must be smaller than or equal to ({D1})" + )), ); } @@ -312,7 +314,7 @@ impl TensorCheck { check = check.register( "Flatten", TensorError::new(format!( - "The end dim ({end_dim}) must be greater than the tensor dim ({D2})" + "The end dim ({end_dim}) must be smaller than the tensor dim ({D1})" )), ); }