feat(relax/frontend/torch): Add basic range constraint support #17898

Status: Open. demoncoder-crypto wants to merge 16 commits into main.
The diff below shows changes from 3 of the 16 commits.

Commits (16)
a4ab52f
feat(relax/frontend/torch): Add basic range constraint support from E…
demoncoder-crypto Apr 26, 2025
71f8a98
Fix: Insert test_dynamic_shape_with_constraints
demoncoder-crypto Apr 28, 2025
f8ad0d7
refactor(test): Refactor constraint test to use verify_model and add …
demoncoder-crypto May 2, 2025
e5710b7
fix(test): Define tir.Var for TVMScript parsing in constraint test
demoncoder-crypto May 3, 2025
5c7758c
style: Apply black formatting
demoncoder-crypto May 3, 2025
9ca05a6
fix(relax/torch): Handle ExportedProgram range constraints and add tests
demoncoder-crypto May 4, 2025
8ab98aa
Merge branch 'main' into fix/relax-pytorch-constraints-v2
demoncoder-crypto May 4, 2025
7201b72
style: Apply formatting fixes to test_frontend_from_exported_program.py
demoncoder-crypto May 4, 2025
f7e23f4
style: Fix trailing whitespace in test file
demoncoder-crypto May 4, 2025
bcab702
feat(relax): Enhance PyTorch ExportedProgram range constraints support
demoncoder-crypto May 4, 2025
70bff93
feat: Enhance PyTorch range constraints support
demoncoder-crypto May 4, 2025
54885dd
style: Fix lint errors reported by CI
demoncoder-crypto May 4, 2025
4717288
style: Apply final lint fixes for translator and test files
demoncoder-crypto May 4, 2025
073ec93
Apply Black code formatting to exported_program_translator.py
demoncoder-crypto May 4, 2025
6162a27
Add logging module for PyTorch frontend
demoncoder-crypto May 4, 2025
249c808
fix: coerce bounds to int and update R.relu to R.nn.relu
demoncoder-crypto May 4, 2025
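
For context, a minimal usage sketch of what this change targets. The model, dimension names, and bounds below are illustrative, not taken from the PR; the attribute names follow the translator changes in the diff further down.

import torch
from tvm.relax.frontend.torch import from_exported_program

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

# Constrained dynamic batch dimension
B = torch.export.Dim("B", min=2, max=10)
exported = torch.export.export(
    Model(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: B}}
)
mod = from_exported_program(exported)
# The imported "main" function is expected to carry the
# "tir_var_lower_bound" / "tir_var_upper_bound" attributes added by this PR.
print(mod["main"].attrs)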
95 changes: 81 additions & 14 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -20,11 +20,12 @@
"""PyTorch ExportedProgram of Relax."""
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Callable, Dict, List, Tuple
from typing import Callable, Dict, List, Tuple, Optional

import torch
import tvm
from tvm import relax
import sympy

from .base_fx_graph_translator import BaseFXGraphImporter

@@ -497,11 +498,12 @@ def create_convert_map(

def create_input_vars(
self, exported_program: torch.export.ExportedProgram
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
relax_range_constraints: Dict[tvm.tir.Var, Tuple[Optional[int], Optional[int]]] = {}

for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
@@ -519,13 +521,18 @@ def create_input_vars(
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype

# TODO(mshr-h): Support range constraints
relax_shape = [
torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
if isinstance(s, torch.SymInt)
else s
for s in torch_shape
]
# UPDATED: Create SizeVars and map SymInts (removed original shape creation)
relax_shape = []
for s in torch_shape:
if isinstance(s, torch.SymInt):
s_str = str(s)
# Ensure SizeVar is created if not already present
if s_str not in torch_symbol_to_relax_var:
torch_symbol_to_relax_var[s_str] = tvm.tir.SizeVar(s_str, "int64")
relax_shape.append(torch_symbol_to_relax_var[s_str])
else:
relax_shape.append(s)

dtype = self._convert_data_type(torch_dtype)

relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
@@ -534,7 +541,47 @@ def create_input_vars(
else:
parameters_buffers_constants[name_hint] = relax_var

return parameters_buffers_constants, user_inputs
# NEW: Process range constraints (basic support for simple symbolic keys)
if hasattr(exported_program, "range_constraints"):
for torch_sym_expr, value_range in exported_program.range_constraints.items():
# Basic support: only handle constraints keyed by a single plain symbol
# (range_constraints maps sympy expressions to value ranges)
if isinstance(torch_sym_expr, sympy.Symbol):
s_str = str(torch_sym_expr)
if s_str in torch_symbol_to_relax_var:
relax_tir_var = torch_symbol_to_relax_var[s_str]

# Extract bounds, using None for infinity
min_val = int(value_range.lower) if value_range.lower != -sympy.oo else None
max_val = int(value_range.upper) if value_range.upper != sympy.oo else None

if relax_tir_var not in relax_range_constraints:
relax_range_constraints[relax_tir_var] = (min_val, max_val)
else:
# Refine existing constraints if the new one is tighter
existing_min, existing_max = relax_range_constraints[relax_tir_var]

# Update min: take the max of lower bounds (None means -inf)
if existing_min is None:
new_min = min_val
elif min_val is None:
new_min = existing_min
else:
new_min = max(existing_min, min_val)

# Update max: take the min of upper bounds (None means +inf)
if existing_max is None:
new_max = max_val
elif max_val is None:
new_max = existing_max
else:
new_max = min(existing_max, max_val)

relax_range_constraints[relax_tir_var] = (new_min, new_max)
# else:
# TODO: Handle complex expressions (e.g., s0 + 1) for advanced support
# print(f"Skipping complex constraint expression: {torch_sym_expr}")

return parameters_buffers_constants, user_inputs, relax_range_constraints

def from_exported_program(
self,
@@ -546,23 +593,43 @@ def from_exported_program(
"""Convert a PyTorch ExportedProgram to a Relax program."""
from torch import fx # type: ignore

# Create input variables.
parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
# Create input variables and get range constraints.
parameter_buffer_constant_vars, user_input_vars, relax_range_constraints = self.create_input_vars(exported_program)
inputs_vars = user_input_vars.copy()
inputs_vars.update(parameter_buffer_constant_vars)

# Initialize the block builder with a function and a dataflow block.
self.block_builder = relax.BlockBuilder()
func_name = "main"
func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None

# Prepare function attributes
func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else {}

# NEW: Add range constraints to function attributes if they exist
if relax_range_constraints:
lower_bounds = {}
upper_bounds = {}
for tir_var, (min_val, max_val) in relax_range_constraints.items():
if min_val is not None:
lower_bounds[tir_var] = tvm.tir.IntImm("int64", min_val)
if max_val is not None:
upper_bounds[tir_var] = tvm.tir.IntImm("int64", max_val)

if lower_bounds:
func_attrs["tir_var_lower_bound"] = lower_bounds
if upper_bounds:
func_attrs["tir_var_upper_bound"] = upper_bounds

# Use None if func_attrs is empty, otherwise use the dictionary
final_func_attrs = func_attrs if func_attrs else None

nodes: List[fx.Node] = exported_program.graph.nodes

# Find all the missing function types
self._check_unsupported_func_type(nodes)

with self.block_builder.function(
name=func_name, params=list(inputs_vars.values()).copy(), attrs=func_attrs
name=func_name, params=list(inputs_vars.values()).copy(), attrs=final_func_attrs
):
output = None
with self.block_builder.dataflow():
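Aside (not part of the diff): a standalone sketch of the bound-refinement rule that create_input_vars applies when the same symbol carries several ranges. None stands for an unbounded side; lower bounds combine via max, upper bounds via min.

from typing import Optional, Tuple

Bound = Tuple[Optional[int], Optional[int]]

def merge_bounds(existing: Bound, new: Bound) -> Bound:
    """Intersect two ranges, keeping the tightest interval."""
    (lo1, hi1), (lo2, hi2) = existing, new
    lo = lo2 if lo1 is None else lo1 if lo2 is None else max(lo1, lo2)
    hi = hi2 if hi1 is None else hi1 if hi2 is None else min(hi1, hi2)
    return lo, hi

assert merge_bounds((2, 10), (3, 15)) == (3, 10)            # the "B" case in the new test
assert merge_bounds((1, None), (None, None)) == (1, None)   # min-only constraint survives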
52 changes: 52 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
@@ -4625,6 +4625,58 @@ def main(
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)

def test_dynamic_shape_with_constraints():
# Define SymInts with constraints
B = torch.export.Dim("B", min=2, max=10)
# Reuse the name "B" with a different range to test refinement
# (lower bounds combine via max, upper bounds via min => min=3, max=10)
B_refined = torch.export.Dim("B", min=3, max=15)
S = torch.export.Dim("S", min=1)  # Test a min-only constraint (-> (1, None))

# Example args: both inputs share the symbolic dim "B", so their first dims must match
example_args = (torch.randn(5, 4, dtype=torch.float32), torch.randn(5, 2, dtype=torch.float32))

# Dynamic shapes using the Dim objects, keyed by the forward() argument names
# x: dim 0 uses B (min=2, max=10), dim 1 uses S (min=1)
# y: dim 0 uses B_refined (min=3, max=15)
# The final constraint for tir.Var("B") should be max(2,3)..min(10,15) => min=3, max=10
dynamic_shapes = {"x": {0: B, 1: S}, "y": {0: B_refined}}

class SimpleDynamic(torch.nn.Module):
# The op itself is trivial; the test targets the input signature and constraints
def forward(self, x, y):
# Only x is used: the second input exists solely to contribute its shape
# constraints, and returning a single tensor keeps the output signature simple.
return torch.relu(x)

# Define the expected Relax IRModule
@tvm.script.ir_module
class Expected:
@R.function
def main(
# Note: B has refined constraints: min=3, max=10
# Note: S has constraints: min=1
# Use string names so the symbolic dims do not capture the outer Dim objects
x: R.Tensor(("B", "S"), dtype="float32"),
y: R.Tensor(("B", 2), dtype="float32"),
) -> R.Tuple(R.Tensor(("B", "S"), dtype="float32")):
B = T.int64()
S = T.int64()
# Expose the constraints to the compiler via Relax function attributes
R.func_attr({
"tir_var_lower_bound": {B: 3, S: 1},
"tir_var_upper_bound": {B: 10},
})
with R.dataflow():
# The actual body isn't the focus, just the signature
lv: R.Tensor((B, S), dtype="float32") = R.nn.relu(x)
# Output must be a tuple
gv: R.Tuple(R.Tensor((B, S), dtype="float32")) = (lv,)
R.output(gv)
return gv

# Use verify_model utility
verify_model(SimpleDynamic(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)


def test_broadcast_to():