Commit cd6e5db

Consider GraphInferenceContext in inference function's InferenceContext (#4632)
* Expose GraphInferenceContext in Python interface for inference functions
* use the same map
* add opset_imports and handle input_types
* graph_opset_import to clarify
* fix lint
* fix black
* add a test
* make_opsetid
* replace test with Add in subgraph
* remove unused
* use shorter name for opset_imports and ir_version

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
1 parent 466edb7 commit cd6e5db

File tree

3 files changed: +126 -57 lines changed

onnx/cpp2py_export.cc
onnx/shape_inference.py
onnx/test/inference_function_test.py
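
Taken together, the three files below thread a model's opset imports and IR version from the Python shape-inference entry point down to single-node inference, so operators that carry subgraph attributes (such as Scan) can resolve the nodes inside their bodies. Before the per-file diffs, a minimal usage sketch of the extended API — it assumes a build of onnx that includes this commit and uses only helpers that appear in the diffs below:

    import onnx
    from onnx.defs import get_schema
    from onnx.helper import make_node, make_opsetid, make_tensor_type_proto
    from onnx.shape_inference import infer_node_outputs

    # Infer the output type of a single Add node, explicitly pinning the
    # opset (default domain "" at version 14) and the IR version instead
    # of relying on the new default arguments.
    inferred = infer_node_outputs(
        get_schema("Add", 14),
        make_node("Add", ["A", "B"], ["C"]),
        {
            "A": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2, 3)),
            "B": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2, 3)),
        },
        opset_imports=[make_opsetid("", 14)],
        ir_version=8,
    )
    print(inferred["C"])  # tensor_type: FLOAT, shape (2, 3)

Omitting the two new arguments keeps the previous behavior; see the fallback in the C++ hunk below.
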

onnx/cpp2py_export.cc

Lines changed: 13 additions & 3 deletions

@@ -58,7 +58,9 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
     const py::bytes& nodeBytes,
     std::unordered_map<std::string, py::bytes> valueTypesByNameBytes,
     std::unordered_map<std::string, py::bytes> inputDataByNameBytes,
-    std::unordered_map<std::string, py::bytes> inputSparseDataByNameBytes) {
+    std::unordered_map<std::string, py::bytes> inputSparseDataByNameBytes,
+    std::unordered_map<std::string, int> opsetImports,
+    const int irVersion) {
   NodeProto node{};
   ParseProtoFromPyBytes(&node, nodeBytes);
   // Early fail if node is badly defined - may throw ValidationError
@@ -68,9 +70,15 @@ std::unordered_map<std::string, py::bytes> CallNodeInferenceFunction(
   const auto& valueTypes = ParseProtoFromBytesMap<TypeProto>(valueTypesByNameBytes);
   const auto& inputData = ParseProtoFromBytesMap<const TensorProto>(inputDataByNameBytes);
   const auto& inputSparseData = ParseProtoFromBytesMap<const SparseTensorProto>(inputSparseDataByNameBytes);
+  if (opsetImports.empty()) {
+    opsetImports[schema->domain()] = schema->SinceVersion();
+  }
 
+  shape_inference::GraphInferenceContext graphInferenceContext(
+      valueTypes.second, opsetImports, nullptr, {}, OpSchemaRegistry::Instance(), nullptr, irVersion);
   // Construct inference context and get results - may throw InferenceError
-  shape_inference::InferenceContextImpl ctx(node, valueTypes.second, inputData.second, inputSparseData.second);
+  shape_inference::InferenceContextImpl ctx(
+      node, valueTypes.second, inputData.second, inputSparseData.second, nullptr, &graphInferenceContext);
   schema->GetTypeAndShapeInferenceFunction()(ctx);
   // Verify the inference succeeded - may also throw ValidationError
   // Note that input types were not validated until now (except that their count was correct)
@@ -142,7 +150,9 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
       py::arg("nodeBytes"),
       py::arg("valueTypesByNameBytes"),
       py::arg("inputDataByNameBytes") = std::unordered_map<std::string, py::bytes>{},
-      py::arg("inputSparseDataByNameBytes") = std::unordered_map<std::string, py::bytes>{})
+      py::arg("inputSparseDataByNameBytes") = std::unordered_map<std::string, py::bytes>{},
+      py::arg("opsetImports") = std::unordered_map<std::string, int>{},
+      py::arg("irVersion") = int(IR_VERSION))
       .def(
           "get_context_dependent_function",
           [](OpSchema* op, const py::bytes& bytes, const std::vector<py::bytes>& input_types_bytes) -> py::bytes {
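
The fallback in the middle hunk keeps existing callers working: when opsetImports arrives empty, it is seeded with the schema's own domain and SinceVersion(). Across the pybind11 boundary the imports travel as a plain domain-to-version map; a small sketch of the conversion that the Python wrapper in the next file performs on OperatorSetIdProto entries:

    from onnx.helper import make_opsetid

    opset_imports = [make_opsetid("", 9), make_opsetid("custom.domain", 1)]
    # What infer_node_outputs hands to the binding's opsetImports argument:
    passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}
    assert passed_opset_imports == {"": 9, "custom.domain": 1}
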

onnx/shape_inference.py

Lines changed: 15 additions & 3 deletions

@@ -4,7 +4,7 @@
 
 """
 
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import onnx
 import onnx.onnx_cpp2py_export.shape_inference as C
@@ -88,18 +88,28 @@ def infer_node_outputs(
     input_types: Dict[str, onnx.TypeProto],
     input_data: Optional[Dict[str, onnx.TensorProto]] = None,
     input_sparse_data: Optional[Dict[str, onnx.SparseTensorProto]] = None,
+    opset_imports: Optional[List[onnx.OperatorSetIdProto]] = None,
+    ir_version: int = onnx.IR_VERSION,
 ) -> Dict[str, onnx.TypeProto]:
     if not schema.has_type_and_shape_inference_function:  # type: ignore
         return {}
     if input_data is None:
         input_data = {}
     if input_sparse_data is None:
         input_sparse_data = {}
+    if opset_imports is None:
+        passed_opset_imports = {}
+    else:
+        passed_opset_imports = {opset.domain: opset.version for opset in opset_imports}
 
-    # To avoid copying on C++ side, pass only what is needed for this inference call
+    # catch KeyError if node's input does not exist in input_types
     passed_input_types = {
         key: input_types[key].SerializeToString() for key in node.input
     }
+    # input_types will also be used as outer_scope_value_types so do not filter by node's input here
+    for key in input_types:
+        if key not in passed_input_types:
+            passed_input_types[key] = input_types[key].SerializeToString()
     passed_input_data = {
         key: input_data[key].SerializeToString()
         for key in node.input
@@ -116,7 +126,9 @@ def infer_node_outputs(
         passed_input_types,
         passed_input_data,
         passed_sparse_input_data,
-    )
+        passed_opset_imports,
+        ir_version,
+    )  # type: ignore[call-arg]
     return {key: onnx.TypeProto.FromString(out) for key, out in outputs.items()}
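
A subtlety in the hunk above: passed_input_types is deliberately no longer filtered down to the node's own inputs, because the same dictionary is reused as the outer-scope value types for the GraphInferenceContext, and a name referenced only from inside a subgraph attribute must still reach the C++ side. A sketch of that forwarding logic in isolation (the Identity node and the names here are hypothetical):

    import onnx
    from onnx.helper import make_node, make_tensor_type_proto

    input_types = {
        "x": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2,)),
        # Not an input of the node below, but possibly referenced from a
        # subgraph attribute, so it must still be forwarded.
        "outer_scope_val": make_tensor_type_proto(onnx.TensorProto.FLOAT, (2,)),
    }
    node = make_node("Identity", ["x"], ["y"])

    passed_input_types = {
        key: input_types[key].SerializeToString() for key in node.input
    }
    # The loop added in this commit: forward every remaining entry as well.
    for key in input_types:
        if key not in passed_input_types:
            passed_input_types[key] = input_types[key].SerializeToString()
    assert set(passed_input_types) == {"x", "outer_scope_val"}
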

onnx/test/inference_function_test.py

Lines changed: 98 additions & 51 deletions

@@ -3,123 +3,119 @@
 
 import numpy as np
 
-import onnx
-import onnx.numpy_helper
-import onnx.shape_inference
+from onnx import TensorProto, TypeProto
+from onnx.checker import ValidationError
+from onnx.defs import OpSchema, get_all_schemas_with_history, get_schema
+from onnx.helper import (
+    make_graph,
+    make_node,
+    make_opsetid,
+    make_tensor_type_proto,
+    make_tensor_value_info,
+)
+from onnx.numpy_helper import from_array
+from onnx.shape_inference import InferenceError, infer_node_outputs
 
 ADD_SCHEMA = max(
-    (
-        s
-        for s in onnx.defs.get_all_schemas_with_history()
-        if s.name == "Add" and s.domain == ""
-    ),
+    (s for s in get_all_schemas_with_history() if s.name == "Add" and s.domain == ""),
     key=lambda s: s.since_version,
 )
 RESHAPE_SCHEMA = max(
     (
         s
-        for s in onnx.defs.get_all_schemas_with_history()
+        for s in get_all_schemas_with_history()
         if s.name == "Reshape" and s.domain == ""
     ),
     key=lambda s: s.since_version,
 )
 
-_tensor = onnx.helper.make_tensor_type_proto
-
 
 def _to_tensor_types(
     tensor_types: Dict[str, Tuple[int, Tuple[Union[int, str, None], ...]]]
-) -> Dict[str, onnx.TypeProto]:
-    return {
-        key: onnx.helper.make_tensor_type_proto(*value)
-        for key, value in tensor_types.items()
-    }
+) -> Dict[str, TypeProto]:
+    return {key: make_tensor_type_proto(*value) for key, value in tensor_types.items()}
 
 
 def _run_case(
-    schema: onnx.defs.OpSchema,
+    schema: OpSchema,
     input_names: List[str],
     output_names: List[str],
-    input_types: Dict[str, onnx.TypeProto],
+    input_types: Dict[str, TypeProto],
     input_data: Optional[Dict[str, np.ndarray]] = None,
-) -> Dict[str, onnx.TypeProto]:
+) -> Dict[str, TypeProto]:
     if input_data is None:
         input_data = {}
-    return onnx.shape_inference.infer_node_outputs(
+    return infer_node_outputs(
         schema,
-        onnx.helper.make_node(
-            schema.name, input_names, output_names, domain=schema.domain
-        ),
+        make_node(schema.name, input_names, output_names, domain=schema.domain),
        input_types,
-        {key: onnx.numpy_helper.from_array(arr) for key, arr in input_data.items()},
+        {key: from_array(arr) for key, arr in input_data.items()},
     )
 
 
 class TestInferenceFunctionCall(unittest.TestCase):
     def test_add_inference(self) -> None:
         cases = [
             (
-                {"A": (onnx.TensorProto.FLOAT, ()), "B": (onnx.TensorProto.FLOAT, ())},
-                {"C": (onnx.TensorProto.FLOAT, ())},
+                {"A": (TensorProto.FLOAT, ()), "B": (TensorProto.FLOAT, ())},
+                {"C": (TensorProto.FLOAT, ())},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, (None, 2)),
-                    "B": (onnx.TensorProto.FLOAT, (2,)),
+                    "A": (TensorProto.FLOAT, (None, 2)),
+                    "B": (TensorProto.FLOAT, (2,)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, (None, 2)),
-                    "B": (onnx.TensorProto.FLOAT, (1, 2)),
+                    "A": (TensorProto.FLOAT, (None, 2)),
+                    "B": (TensorProto.FLOAT, (1, 2)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.DOUBLE, ("n", "m")),
-                    "B": (onnx.TensorProto.DOUBLE, (1, "n", "m")),
+                    "A": (TensorProto.DOUBLE, ("n", "m")),
+                    "B": (TensorProto.DOUBLE, (1, "n", "m")),
                 },
-                {"C": (onnx.TensorProto.DOUBLE, (1, "n", "m"))},
+                {"C": (TensorProto.DOUBLE, (1, "n", "m"))},
             ),
             (
                 {
-                    "A": (onnx.TensorProto.FLOAT, ("x", 2)),
-                    "B": (onnx.TensorProto.FLOAT, ("y", 2)),
+                    "A": (TensorProto.FLOAT, ("x", 2)),
+                    "B": (TensorProto.FLOAT, ("y", 2)),
                 },
-                {"C": (onnx.TensorProto.FLOAT, (None, 2))},
+                {"C": (TensorProto.FLOAT, (None, 2))},
             ),
         ]
         for ins, outs in cases:
             assert _run_case(ADD_SCHEMA, ["A", "B"], ["C"], _to_tensor_types(ins)) == _to_tensor_types(outs)  # type: ignore
 
     def test_add_inference_raises_errors(self) -> None:
-        with self.assertRaises(onnx.checker.ValidationError):
+        with self.assertRaises(ValidationError):
             _run_case(
                 ADD_SCHEMA,
                 ["A"],
                 ["C"],
-                _to_tensor_types({"A": (onnx.TensorProto.FLOAT, (3, 4))}),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
             )
-        with self.assertRaises(onnx.checker.ValidationError):
+        with self.assertRaises(ValidationError):
             _run_case(
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
-                _to_tensor_types(
-                    {"A": (onnx.TensorProto.FLOAT, (3, 4)), "B": (2, (3, 4))}
-                ),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4)), "B": (2, (3, 4))}),
             )
-        with self.assertRaises(onnx.shape_inference.InferenceError):
+        with self.assertRaises(InferenceError):
             _run_case(
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
                 _to_tensor_types(
                     {
-                        "A": (onnx.TensorProto.FLOAT, (2, 4)),
-                        "B": (onnx.TensorProto.FLOAT, (3, 4)),
+                        "A": (TensorProto.FLOAT, (2, 4)),
+                        "B": (TensorProto.FLOAT, (3, 4)),
                     }
                 ),
             )
@@ -128,7 +124,7 @@ def test_add_inference_raises_errors(self) -> None:
                 ADD_SCHEMA,
                 ["A", "B"],
                 ["C"],
-                _to_tensor_types({"A": (onnx.TensorProto.FLOAT, (3, 4))}),
+                _to_tensor_types({"A": (TensorProto.FLOAT, (3, 4))}),
             )
 
     def test_reshape_inference(self) -> None:
@@ -138,12 +134,63 @@ def test_reshape_inference(self) -> None:
             ["y"],
             _to_tensor_types(
                 {
-                    "x": (onnx.TensorProto.FLOAT, (5, 4)),
-                    "t": (onnx.TensorProto.INT64, (3,)),
+                    "x": (TensorProto.FLOAT, (5, 4)),
+                    "t": (TensorProto.INT64, (3,)),
                 }
             ),
             {"t": np.array([2, 2, 5], dtype=np.int64)},
-        ) == _to_tensor_types({"y": (onnx.TensorProto.FLOAT, (2, 2, 5))})
+        ) == _to_tensor_types({"y": (TensorProto.FLOAT, (2, 2, 5))})
+
+    def test_scan_inference_with_subgraph(self) -> None:
+        seq_len = "sequence"
+        input_size = 2
+        loop_state_size = 3
+
+        input_value_infos = [
+            make_tensor_value_info("loop_state_in", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("input", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("outer", TensorProto.UNDEFINED, None),
+        ]
+        output_value_infos = [
+            make_tensor_value_info("loop_state_out", TensorProto.UNDEFINED, None),
+            make_tensor_value_info("output", TensorProto.FLOAT, (seq_len, input_size)),
+        ]
+
+        subgraph = make_graph(
+            [
+                make_node("Identity", ["loop_state_in"], ["loop_state_out"]),
+                make_node("Add", ["input", "outer"], ["output"]),
+            ],
+            "subgraph",
+            input_value_infos,
+            output_value_infos,
+        )
+
+        assert infer_node_outputs(
+            get_schema("Scan", 9),
+            make_node(
+                "Scan",
+                ["loop_state_orig", "scan_input", "scan_outer"],
+                ["loop_state_final", "scan_output"],
+                num_scan_inputs=1,
+                body=subgraph,
+            ),
+            _to_tensor_types(
+                {
+                    "loop_state_orig": (TensorProto.FLOAT, (loop_state_size,)),
+                    "scan_input": (TensorProto.FLOAT, (seq_len, input_size)),
+                    "scan_outer": (TensorProto.FLOAT, (input_size,)),
+                }
+            ),
+            # Same as default value in Scan-9
+            opset_imports=[make_opsetid("", 9)],
+            ir_version=4,
+        ) == _to_tensor_types(
+            {
+                "loop_state_final": (TensorProto.FLOAT, (loop_state_size,)),
+                "scan_output": (TensorProto.FLOAT, (seq_len, input_size)),
+            }
        )
 
 
 if __name__ == "__main__":
