@@ -4626,64 +4626,61 @@ def main(
4626
4626
4627
4627
verify_model (DynamicModel (), example_args , {}, Expected , dynamic_shapes = dynamic_shapes )
4628
4628
4629
- #ADDED blank line
4629
+
4630
4630
def test_dynamic_shape_with_constraints ():
4631
- # Define SymInts with constraints
4632
- B = torch .export .Dim ("B" , min = 2 , max = 10 )
4633
- # Use B again for another dimension to test refinement (max(10, 15) -> 15)
4634
- B_refined = torch .export .Dim ("B" , min = 3 , max = 15 )
4635
- S = torch .export .Dim ("S" , min = 1 ) # Test min constraint only (-> (1, None))
4636
-
4637
- # Example args matching initial B dim (max=10)
4638
- example_args = (torch .randn (3 , 4 , dtype = torch .float32 ), torch .randn (5 , 2 , dtype = torch .float32 ))
4639
-
4640
- # Dynamic shapes using the Dim objects
4641
- # Input 0: Dim 0 uses B (min=2, max=10), Dim 1 uses S (min=1)
4642
- # Input 1: Dim 0 uses B_refined (min=3, max=15)
4643
- # The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
4644
- dynamic_shapes = {0 : {0 : B , 1 : S }, 1 : {0 : B_refined }}
4645
-
4646
- class SimpleDynamic (torch .nn .Module ):
4647
- # Simple op, the main thing is testing the input signature and constraints
4648
- def forward (self , x , y ):
4649
- # Add tensors with different shapes requires broadcasting,
4650
- # but we only care about the input signature here.
4651
- # Use an op that doesn't depend on exact shapes matching.
4652
- return torch .relu (x ) # Return just one to simplify output signature
4653
-
4654
- # NEW: Define TIR Vars for TVMScript parsing
4655
- B = tir .Var ("B" , "int64" )
4656
- S = tir .Var ("S" , "int64" )
4657
-
4658
- # Define the expected Relax IRModule
4659
- @tvm .script .ir_module
4660
- class Expected :
4661
- @R .function
4662
- def main (
4663
- # Note: B has refined constraints: min=3, max=10
4664
- # Note: S has constraints: min=1
4665
- x : R .Tensor ((B , S ), dtype = "float32" ),
4666
- y : R .Tensor ((B , 2 ), dtype = "float32" )
4667
- ) -> R .Tuple (R .Tensor ((B , S ), dtype = "float32" )):
4668
- B = T .int64 ()
4669
- S = T .int64 ()
4670
- # tell TIR about the constraints via function attributes
4671
- T .func_attr ({
4672
- "tir_var_lower_bound" : {B : 3 , S : 1 },
4673
- "tir_var_upper_bound" : {B : 10 }
4674
- })
4675
- with R .dataflow ():
4676
- # The actual body isn't the focus, just the signature
4677
- lv : R .Tensor ((B , S ), dtype = "float32" ) = R .relu (x )
4678
- # Output must be a tuple
4679
- gv : R .Tuple (R .Tensor ((B , S ), dtype = "float32" )) = (lv ,)
4680
- R .output (gv )
4681
- return gv
4682
-
4683
- # Use verify_model utility
4684
- verify_model (SimpleDynamic (), example_args , {}, Expected , dynamic_shapes = dynamic_shapes )
4685
-
4686
- #ADDED blank line
4631
+ # Define SymInts with constraints
4632
+ B = torch .export .Dim ("B" , min = 2 , max = 10 )
4633
+ # Use B again for another dimension to test refinement (max(10, 15) -> 15)
4634
+ B_refined = torch .export .Dim ("B" , min = 3 , max = 15 )
4635
+ S = torch .export .Dim ("S" , min = 1 ) # Test min constraint only (-> (1, None))
4636
+
4637
+ # Example args matching initial B dim (max=10)
4638
+ example_args = (torch .randn (3 , 4 , dtype = torch .float32 ), torch .randn (5 , 2 , dtype = torch .float32 ))
4639
+
4640
+ # Dynamic shapes using the Dim objects
4641
+ # Input 0: Dim 0 uses B (min=2, max=10), Dim 1 uses S (min=1)
4642
+ # Input 1: Dim 0 uses B_refined (min=3, max=15)
4643
+ # The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
4644
+ dynamic_shapes = {0 : {0 : B , 1 : S }, 1 : {0 : B_refined }}
4645
+
4646
+ class SimpleDynamic (torch .nn .Module ):
4647
+ # Simple op, the main thing is testing the input signature and constraints
4648
+ def forward (self , x , y ):
4649
+ # Add tensors with different shapes requires broadcasting,
4650
+ # but we only care about the input signature here.
4651
+ # Use an op that doesn't depend on exact shapes matching.
4652
+ return torch .relu (x ) # Return just one to simplify output signature
4653
+
4654
+ # NEW: Define TIR Vars for TVMScript parsing
4655
+ B = tir .Var ("B" , "int64" )
4656
+ S = tir .Var ("S" , "int64" )
4657
+
4658
+ # Define the expected Relax IRModule
4659
+ @tvm .script .ir_module
4660
+ class Expected :
4661
+ @R .function
4662
+ def main (
4663
+ # Note: B has refined constraints: min=3, max=10
4664
+ # Note: S has constraints: min=1
4665
+ x : R .Tensor ((B , S ), dtype = "float32" ),
4666
+ y : R .Tensor ((B , 2 ), dtype = "float32" ),
4667
+ ) -> R .Tuple (R .Tensor ((B , S ), dtype = "float32" )):
4668
+ B = T .int64 ()
4669
+ S = T .int64 ()
4670
+ # tell TIR about the constraints via function attributes
4671
+ T .func_attr ({"tir_var_lower_bound" : {B : 3 , S : 1 }, "tir_var_upper_bound" : {B : 10 }})
4672
+ with R .dataflow ():
4673
+ # The actual body isn't the focus, just the signature
4674
+ lv : R .Tensor ((B , S ), dtype = "float32" ) = R .relu (x )
4675
+ # Output must be a tuple
4676
+ gv : R .Tuple (R .Tensor ((B , S ), dtype = "float32" )) = (lv ,)
4677
+ R .output (gv )
4678
+ return gv
4679
+
4680
+ # Use verify_model utility
4681
+ verify_model (SimpleDynamic (), example_args , {}, Expected , dynamic_shapes = dynamic_shapes )
4682
+
4683
+
4687
4684
def test_broadcast_to ():
4688
4685
class BroadcastTo (Module ):
4689
4686
def forward (self , x ):
0 commit comments