3
3
4
4
import numpy as np
5
5
6
- import onnx
7
- import onnx .numpy_helper
8
- import onnx .shape_inference
6
+ from onnx import TensorProto , TypeProto
7
+ from onnx .checker import ValidationError
8
+ from onnx .defs import OpSchema , get_all_schemas_with_history , get_schema
9
+ from onnx .helper import (
10
+ make_graph ,
11
+ make_node ,
12
+ make_opsetid ,
13
+ make_tensor_type_proto ,
14
+ make_tensor_value_info ,
15
+ )
16
+ from onnx .numpy_helper import from_array
17
+ from onnx .shape_inference import InferenceError , infer_node_outputs
9
18
10
19
ADD_SCHEMA = max (
11
- (
12
- s
13
- for s in onnx .defs .get_all_schemas_with_history ()
14
- if s .name == "Add" and s .domain == ""
15
- ),
20
+ (s for s in get_all_schemas_with_history () if s .name == "Add" and s .domain == "" ),
16
21
key = lambda s : s .since_version ,
17
22
)
18
23
RESHAPE_SCHEMA = max (
19
24
(
20
25
s
21
- for s in onnx . defs . get_all_schemas_with_history ()
26
+ for s in get_all_schemas_with_history ()
22
27
if s .name == "Reshape" and s .domain == ""
23
28
),
24
29
key = lambda s : s .since_version ,
25
30
)
26
31
27
- _tensor = onnx .helper .make_tensor_type_proto
28
-
29
32
30
33
def _to_tensor_types (
31
34
tensor_types : Dict [str , Tuple [int , Tuple [Union [int , str , None ], ...]]]
32
- ) -> Dict [str , onnx .TypeProto ]:
33
- return {
34
- key : onnx .helper .make_tensor_type_proto (* value )
35
- for key , value in tensor_types .items ()
36
- }
35
+ ) -> Dict [str , TypeProto ]:
36
+ return {key : make_tensor_type_proto (* value ) for key , value in tensor_types .items ()}
37
37
38
38
39
39
def _run_case (
40
- schema : onnx . defs . OpSchema ,
40
+ schema : OpSchema ,
41
41
input_names : List [str ],
42
42
output_names : List [str ],
43
- input_types : Dict [str , onnx . TypeProto ],
43
+ input_types : Dict [str , TypeProto ],
44
44
input_data : Optional [Dict [str , np .ndarray ]] = None ,
45
- ) -> Dict [str , onnx . TypeProto ]:
45
+ ) -> Dict [str , TypeProto ]:
46
46
if input_data is None :
47
47
input_data = {}
48
- return onnx . shape_inference . infer_node_outputs (
48
+ return infer_node_outputs (
49
49
schema ,
50
- onnx .helper .make_node (
51
- schema .name , input_names , output_names , domain = schema .domain
52
- ),
50
+ make_node (schema .name , input_names , output_names , domain = schema .domain ),
53
51
input_types ,
54
- {key : onnx . numpy_helper . from_array (arr ) for key , arr in input_data .items ()},
52
+ {key : from_array (arr ) for key , arr in input_data .items ()},
55
53
)
56
54
57
55
58
56
class TestInferenceFunctionCall (unittest .TestCase ):
59
57
def test_add_inference (self ) -> None :
60
58
cases = [
61
59
(
62
- {"A" : (onnx . TensorProto .FLOAT , ()), "B" : (onnx . TensorProto .FLOAT , ())},
63
- {"C" : (onnx . TensorProto .FLOAT , ())},
60
+ {"A" : (TensorProto .FLOAT , ()), "B" : (TensorProto .FLOAT , ())},
61
+ {"C" : (TensorProto .FLOAT , ())},
64
62
),
65
63
(
66
64
{
67
- "A" : (onnx . TensorProto .FLOAT , (None , 2 )),
68
- "B" : (onnx . TensorProto .FLOAT , (2 ,)),
65
+ "A" : (TensorProto .FLOAT , (None , 2 )),
66
+ "B" : (TensorProto .FLOAT , (2 ,)),
69
67
},
70
- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
68
+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
71
69
),
72
70
(
73
71
{
74
- "A" : (onnx . TensorProto .FLOAT , (None , 2 )),
75
- "B" : (onnx . TensorProto .FLOAT , (1 , 2 )),
72
+ "A" : (TensorProto .FLOAT , (None , 2 )),
73
+ "B" : (TensorProto .FLOAT , (1 , 2 )),
76
74
},
77
- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
75
+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
78
76
),
79
77
(
80
78
{
81
- "A" : (onnx . TensorProto .DOUBLE , ("n" , "m" )),
82
- "B" : (onnx . TensorProto .DOUBLE , (1 , "n" , "m" )),
79
+ "A" : (TensorProto .DOUBLE , ("n" , "m" )),
80
+ "B" : (TensorProto .DOUBLE , (1 , "n" , "m" )),
83
81
},
84
- {"C" : (onnx . TensorProto .DOUBLE , (1 , "n" , "m" ))},
82
+ {"C" : (TensorProto .DOUBLE , (1 , "n" , "m" ))},
85
83
),
86
84
(
87
85
{
88
- "A" : (onnx . TensorProto .FLOAT , ("x" , 2 )),
89
- "B" : (onnx . TensorProto .FLOAT , ("y" , 2 )),
86
+ "A" : (TensorProto .FLOAT , ("x" , 2 )),
87
+ "B" : (TensorProto .FLOAT , ("y" , 2 )),
90
88
},
91
- {"C" : (onnx . TensorProto .FLOAT , (None , 2 ))},
89
+ {"C" : (TensorProto .FLOAT , (None , 2 ))},
92
90
),
93
91
]
94
92
for ins , outs in cases :
95
93
assert _run_case (ADD_SCHEMA , ["A" , "B" ], ["C" ], _to_tensor_types (ins )) == _to_tensor_types (outs ) # type: ignore
96
94
97
95
def test_add_inference_raises_errors (self ) -> None :
98
- with self .assertRaises (onnx . checker . ValidationError ):
96
+ with self .assertRaises (ValidationError ):
99
97
_run_case (
100
98
ADD_SCHEMA ,
101
99
["A" ],
102
100
["C" ],
103
- _to_tensor_types ({"A" : (onnx . TensorProto .FLOAT , (3 , 4 ))}),
101
+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 ))}),
104
102
)
105
- with self .assertRaises (onnx . checker . ValidationError ):
103
+ with self .assertRaises (ValidationError ):
106
104
_run_case (
107
105
ADD_SCHEMA ,
108
106
["A" , "B" ],
109
107
["C" ],
110
- _to_tensor_types (
111
- {"A" : (onnx .TensorProto .FLOAT , (3 , 4 )), "B" : (2 , (3 , 4 ))}
112
- ),
108
+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 )), "B" : (2 , (3 , 4 ))}),
113
109
)
114
- with self .assertRaises (onnx . shape_inference . InferenceError ):
110
+ with self .assertRaises (InferenceError ):
115
111
_run_case (
116
112
ADD_SCHEMA ,
117
113
["A" , "B" ],
118
114
["C" ],
119
115
_to_tensor_types (
120
116
{
121
- "A" : (onnx . TensorProto .FLOAT , (2 , 4 )),
122
- "B" : (onnx . TensorProto .FLOAT , (3 , 4 )),
117
+ "A" : (TensorProto .FLOAT , (2 , 4 )),
118
+ "B" : (TensorProto .FLOAT , (3 , 4 )),
123
119
}
124
120
),
125
121
)
@@ -128,7 +124,7 @@ def test_add_inference_raises_errors(self) -> None:
128
124
ADD_SCHEMA ,
129
125
["A" , "B" ],
130
126
["C" ],
131
- _to_tensor_types ({"A" : (onnx . TensorProto .FLOAT , (3 , 4 ))}),
127
+ _to_tensor_types ({"A" : (TensorProto .FLOAT , (3 , 4 ))}),
132
128
)
133
129
134
130
def test_reshape_inference (self ) -> None :
@@ -138,12 +134,63 @@ def test_reshape_inference(self) -> None:
138
134
["y" ],
139
135
_to_tensor_types (
140
136
{
141
- "x" : (onnx . TensorProto .FLOAT , (5 , 4 )),
142
- "t" : (onnx . TensorProto .INT64 , (3 ,)),
137
+ "x" : (TensorProto .FLOAT , (5 , 4 )),
138
+ "t" : (TensorProto .INT64 , (3 ,)),
143
139
}
144
140
),
145
141
{"t" : np .array ([2 , 2 , 5 ], dtype = np .int64 )},
146
- ) == _to_tensor_types ({"y" : (onnx .TensorProto .FLOAT , (2 , 2 , 5 ))})
142
+ ) == _to_tensor_types ({"y" : (TensorProto .FLOAT , (2 , 2 , 5 ))})
143
+
144
+ def test_scan_inference_with_subgraph (self ) -> None :
145
+ seq_len = "sequence"
146
+ input_size = 2
147
+ loop_state_size = 3
148
+
149
+ input_value_infos = [
150
+ make_tensor_value_info ("loop_state_in" , TensorProto .UNDEFINED , None ),
151
+ make_tensor_value_info ("input" , TensorProto .UNDEFINED , None ),
152
+ make_tensor_value_info ("outer" , TensorProto .UNDEFINED , None ),
153
+ ]
154
+ output_value_infos = [
155
+ make_tensor_value_info ("loop_state_out" , TensorProto .UNDEFINED , None ),
156
+ make_tensor_value_info ("output" , TensorProto .FLOAT , (seq_len , input_size )),
157
+ ]
158
+
159
+ subgraph = make_graph (
160
+ [
161
+ make_node ("Identity" , ["loop_state_in" ], ["loop_state_out" ]),
162
+ make_node ("Add" , ["input" , "outer" ], ["output" ]),
163
+ ],
164
+ "subgraph" ,
165
+ input_value_infos ,
166
+ output_value_infos ,
167
+ )
168
+
169
+ assert infer_node_outputs (
170
+ get_schema ("Scan" , 9 ),
171
+ make_node (
172
+ "Scan" ,
173
+ ["loop_state_orig" , "scan_input" , "scan_outer" ],
174
+ ["loop_state_final" , "scan_output" ],
175
+ num_scan_inputs = 1 ,
176
+ body = subgraph ,
177
+ ),
178
+ _to_tensor_types (
179
+ {
180
+ "loop_state_orig" : (TensorProto .FLOAT , (loop_state_size ,)),
181
+ "scan_input" : (TensorProto .FLOAT , (seq_len , input_size )),
182
+ "scan_outer" : (TensorProto .FLOAT , (input_size ,)),
183
+ }
184
+ ),
185
+ # Same as default value in Scan-9
186
+ opset_imports = [make_opsetid ("" , 9 )],
187
+ ir_version = 4 ,
188
+ ) == _to_tensor_types (
189
+ {
190
+ "loop_state_final" : (TensorProto .FLOAT , (loop_state_size ,)),
191
+ "scan_output" : (TensorProto .FLOAT , (seq_len , input_size )),
192
+ }
193
+ )
147
194
148
195
149
196
if __name__ == "__main__" :
0 commit comments