Commit 064f5d7

style: Apply black formatting

1 parent e5710b7 commit 064f5d7

2 files changed: +113 -102 lines

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 59 additions & 45 deletions
@@ -25,8 +25,7 @@
 import torch
 import tvm
 from tvm import relax
-import sympy
-
+import tvm.tir as tir  # pylint: disable=unused-import
 from .base_fx_graph_translator import BaseFXGraphImporter
 
 
@@ -498,7 +497,11 @@ def create_convert_map(
 
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]]]:
+    ) -> Tuple[
+        Dict[str, relax.Var],
+        Dict[str, relax.Var],
+        Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]],
+    ]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
@@ -521,18 +524,17 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            # UPDATED: Create SizeVars and map SymInts (removed original shape creation)
+            # Create SizeVars and map SymInts
             relax_shape = []
             for s in torch_shape:
                 if isinstance(s, torch.SymInt):
                     s_str = str(s)
-                    # Ensure SizeVar is created if not already present
                     if s_str not in torch_symbol_to_relax_var:
                         torch_symbol_to_relax_var[s_str] = tvm.tir.SizeVar(s_str, "int64")
                     relax_shape.append(torch_symbol_to_relax_var[s_str])
                 else:
                     relax_shape.append(s)
-
+
             dtype = self._convert_data_type(torch_dtype)
 
             relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
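
Note: the hunk above deduplicates symbolic dimensions by name, so every occurrence of a given torch.SymInt maps to one shared tir.SizeVar. A minimal standalone sketch of that idea (not part of the commit; plain strings stand in for SymInt objects):

import tvm

symbol_to_var = {}  # hypothetical cache, mirrors torch_symbol_to_relax_var

def to_relax_shape(torch_shape):
    # Symbolic dims (here: strings) become shared SizeVars; ints pass through.
    relax_shape = []
    for s in torch_shape:
        if isinstance(s, str):  # stand-in for isinstance(s, torch.SymInt)
            if s not in symbol_to_var:
                symbol_to_var[s] = tvm.tir.SizeVar(s, "int64")
            relax_shape.append(symbol_to_var[s])
        else:
            relax_shape.append(s)
    return relax_shape

a = to_relax_shape(["s0", 4])
b = to_relax_shape(["s0", "s1"])
assert a[0].same_as(b[0])  # "s0" resolves to the same SizeVar both times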
@@ -541,48 +543,56 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var
 
-        # NEW: Process range constraints (basic support for simple SymInt keys)
-        if hasattr(exported_program, "range_constraints"):
-            for torch_sym_expr, value_range in exported_program.range_constraints.items():
-                # Basic support: Only handle constraints where the key is a simple SymInt
-                if isinstance(torch_sym_expr, torch.SymInt):
-                    s_str = str(torch_sym_expr)
-                    if s_str in torch_symbol_to_relax_var:
-                        relax_tir_var = torch_symbol_to_relax_var[s_str]
-
-                        # Extract bounds, using None for infinity
-                        min_val = int(value_range.lower) if value_range.lower != -sympy.oo else None
-                        max_val = int(value_range.upper) if value_range.upper != sympy.oo else None
-
-                        if relax_tir_var not in relax_range_constraints:
-                            relax_range_constraints[relax_tir_var] = (min_val, max_val)
-                        else:
-                            # Refine existing constraints if the new one is tighter
-                            existing_min, existing_max = relax_range_constraints[relax_tir_var]
-
-                            # Update min: take the max of lower bounds (None means -inf)
-                            if existing_min is None:
-                                new_min = min_val
-                            elif min_val is None:
-                                new_min = existing_min
-                            else:
-                                new_min = max(existing_min, min_val)
-
-                            # Update max: take the min of upper bounds (None means +inf)
-                            if existing_max is None:
-                                new_max = max_val
-                            elif max_val is None:
-                                new_max = existing_max
-                            else:
-                                new_max = min(existing_max, max_val)
-
-                            relax_range_constraints[relax_tir_var] = (new_min, new_max)
+        # Extract range constraints for TIR vars
+        if hasattr(exported_program, "range_constraints") and exported_program.range_constraints:
+            for torch_sym_expr, constraint in exported_program.range_constraints.items():
+                # Convert sympy expression to string for mapping
+                torch_sym_expr_str = str(torch_sym_expr)
+
+                if torch_sym_expr_str in torch_symbol_to_relax_var:
+                    relax_tir_var = torch_symbol_to_relax_var[torch_sym_expr_str]
+                    # TODO(sjt): Handle SymFloat, SymBool cases as well.
+                    # Note: min / max could be int or SymInt objects.
+                    # Need to handle symbolic shapes as well.
+                    min_val = constraint.min
+                    max_val = constraint.max
+                    # Call helper to add/refine constraint
+                    self._add_range_constraint(
+                        relax_range_constraints, relax_tir_var, min_val, max_val
+                    )
                 # else:
-                #     TODO: Handle complex expressions (e.g., s0 + 1) for advanced support
-                #     print(f"Skipping complex constraint expression: {torch_sym_expr}")
+                # FIXED Indentation for Black:
+                #     TODO: Handle complex expressions (e.g., s0 + 1) for advanced support
+                #     print(f"Skipping complex constraint expression: {torch_sym_expr}")
 
         return parameters_buffers_constants, user_inputs, relax_range_constraints
 
+    # NEW HELPER METHOD
+    def _add_range_constraint(self, constraints_dict, relax_tir_var, min_val, max_val):
+        """Adds or refines a range constraint for a TIR variable."""
+        if relax_tir_var not in constraints_dict:
+            constraints_dict[relax_tir_var] = (min_val, max_val)
+        else:
+            # Refine existing constraints if the new one is tighter
+            existing_min, existing_max = constraints_dict[relax_tir_var]
+            # Merge lower bounds (take the max)
+            if existing_min is None:
+                new_min = min_val
+            elif min_val is None:
+                new_min = existing_min
+            else:
+                new_min = max(existing_min, min_val)
+
+            # Merge upper bounds (take the min)
+            if existing_max is None:
+                new_max = max_val
+            elif max_val is None:
+                new_max = existing_max
+            else:
+                new_max = min(existing_max, max_val)
+
+            constraints_dict[relax_tir_var] = (new_min, new_max)
+
     def from_exported_program(
         self,
         exported_program: torch.export.ExportedProgram,
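
Note: _add_range_constraint implements interval intersection with None meaning "unbounded on that side": lower bounds merge via max, upper bounds via min. A minimal sketch of the same merge rule outside the class (not from the commit):

def merge_bounds(existing, new):
    # Intersect two (min, max) ranges; None means unbounded on that side.
    (lo1, hi1), (lo2, hi2) = existing, new
    lo = lo2 if lo1 is None else (lo1 if lo2 is None else max(lo1, lo2))
    hi = hi2 if hi1 is None else (hi1 if hi2 is None else min(hi1, hi2))
    return (lo, hi)

# The case the new test exercises: Dim("B", 2, 10) merged with Dim("B", 3, 15)
assert merge_bounds((2, 10), (3, 15)) == (3, 10)
assert merge_bounds((1, None), (None, 8)) == (1, 8)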
@@ -594,7 +604,11 @@ def from_exported_program(
         from torch import fx  # type: ignore
 
         # Create input variables and get range constraints.
-        parameter_buffer_constant_vars, user_input_vars, relax_range_constraints = self.create_input_vars(exported_program)
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            relax_range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)
 
tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 54 additions & 57 deletions
@@ -4626,64 +4626,61 @@ def main(
 
     verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
 
-#ADDED blank line
+
 def test_dynamic_shape_with_constraints():
-    # Define SymInts with constraints
-    B = torch.export.Dim("B", min=2, max=10)
-    # Use B again for another dimension to test refinement (max(10, 15) -> 15)
-    B_refined = torch.export.Dim("B", min=3, max=15)
-    S = torch.export.Dim("S", min=1)  # Test min constraint only (-> (1, None))
-
-    # Example args matching initial B dim (max=10)
-    example_args = (torch.randn(3, 4, dtype=torch.float32), torch.randn(5, 2, dtype=torch.float32))
-
-    # Dynamic shapes using the Dim objects
-    # Input 0: Dim 0 uses B (min=2, max=10), Dim 1 uses S (min=1)
-    # Input 1: Dim 0 uses B_refined (min=3, max=15)
-    # The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
-    dynamic_shapes = {0: {0: B, 1: S}, 1: {0: B_refined}}
-
-    class SimpleDynamic(torch.nn.Module):
-        # Simple op, the main thing is testing the input signature and constraints
-        def forward(self, x, y):
-            # Add tensors with different shapes requires broadcasting,
-            # but we only care about the input signature here.
-            # Use an op that doesn't depend on exact shapes matching.
-            return torch.relu(x)  # Return just one to simplify output signature
-
-    # NEW: Define TIR Vars for TVMScript parsing
-    B = tir.Var("B", "int64")
-    S = tir.Var("S", "int64")
-
-    # Define the expected Relax IRModule
-    @tvm.script.ir_module
-    class Expected:
-        @R.function
-        def main(
-            # Note: B has refined constraints: min=3, max=10
-            # Note: S has constraints: min=1
-            x: R.Tensor((B, S), dtype="float32"),
-            y: R.Tensor((B, 2), dtype="float32")
-        ) -> R.Tuple(R.Tensor((B, S), dtype="float32")):
-            B = T.int64()
-            S = T.int64()
-            # tell TIR about the constraints via function attributes
-            T.func_attr({
-                "tir_var_lower_bound": {B: 3, S: 1},
-                "tir_var_upper_bound": {B: 10}
-            })
-            with R.dataflow():
-                # The actual body isn't the focus, just the signature
-                lv: R.Tensor((B, S), dtype="float32") = R.relu(x)
-                # Output must be a tuple
-                gv: R.Tuple(R.Tensor((B, S), dtype="float32")) = (lv,)
-                R.output(gv)
-            return gv
-
-    # Use verify_model utility
-    verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
-
-#ADDED blank line
+    # Define SymInts with constraints
+    B = torch.export.Dim("B", min=2, max=10)
+    # Use B again for another dimension to test refinement (max(10, 15) -> 15)
+    B_refined = torch.export.Dim("B", min=3, max=15)
+    S = torch.export.Dim("S", min=1)  # Test min constraint only (-> (1, None))
+
+    # Example args matching initial B dim (max=10)
+    example_args = (torch.randn(3, 4, dtype=torch.float32), torch.randn(5, 2, dtype=torch.float32))
+
+    # Dynamic shapes using the Dim objects
+    # Input 0: Dim 0 uses B (min=2, max=10), Dim 1 uses S (min=1)
+    # Input 1: Dim 0 uses B_refined (min=3, max=15)
+    # The final constraint for tir.Var("B") should be max(2,3) to min(10,15) => min=3, max=10
+    dynamic_shapes = {0: {0: B, 1: S}, 1: {0: B_refined}}
+
+    class SimpleDynamic(torch.nn.Module):
+        # Simple op, the main thing is testing the input signature and constraints
+        def forward(self, x, y):
+            # Add tensors with different shapes requires broadcasting,
+            # but we only care about the input signature here.
+            # Use an op that doesn't depend on exact shapes matching.
+            return torch.relu(x)  # Return just one to simplify output signature
+
+    # NEW: Define TIR Vars for TVMScript parsing
+    B = tir.Var("B", "int64")
+    S = tir.Var("S", "int64")
+
+    # Define the expected Relax IRModule
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            # Note: B has refined constraints: min=3, max=10
+            # Note: S has constraints: min=1
+            x: R.Tensor((B, S), dtype="float32"),
+            y: R.Tensor((B, 2), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((B, S), dtype="float32")):
+            B = T.int64()
+            S = T.int64()
+            # tell TIR about the constraints via function attributes
+            T.func_attr({"tir_var_lower_bound": {B: 3, S: 1}, "tir_var_upper_bound": {B: 10}})
+            with R.dataflow():
+                # The actual body isn't the focus, just the signature
+                lv: R.Tensor((B, S), dtype="float32") = R.relu(x)
+                # Output must be a tuple
+                gv: R.Tuple(R.Tensor((B, S), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    # Use verify_model utility
+    verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+
+
 def test_broadcast_to():
     class BroadcastTo(Module):
         def forward(self, x):
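
Note: the new test drives the importer through torch.export's dynamic-shape API. A hedged sketch of how such range constraints reach the frontend (the printed repr is illustrative and varies by torch version):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Constrain the batch dimension to [2, 10] at export time.
B = torch.export.Dim("B", min=2, max=10)
ep = torch.export.export(M(), (torch.randn(3, 4),), dynamic_shapes={"x": {0: B}})

# The importer reads these bounds from ep.range_constraints,
# e.g. {s0: ValueRanges(lower=2, upper=10)}.
print(ep.range_constraints)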
