18
18
19
19
import logging
20
20
from copy import deepcopy
21
- from typing import Any , Dict , List , NamedTuple , Tuple , Union
21
+ from pathlib import Path
22
+ from typing import Any , Dict , List , NamedTuple , Optional , Tuple , Union
22
23
23
24
import numpy
24
25
import onnx
60
61
61
62
62
63
def extract_nodes_shapes_and_dtypes_ort (
63
- model : ModelProto ,
64
+ model : ModelProto , path : Optional [ str ] = None
64
65
) -> Tuple [Dict [str , List [List [int ]]], Dict [str , numpy .dtype ]]:
65
66
"""
66
67
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
67
68
InferenceSession to obtain the output shape of each node.
68
69
69
70
:param model: an ONNX model
71
+ :param path: absolute path to the original onnx model
70
72
:return: a tuple of two dicts: the output shapes and the output data types
71
73
"""
72
74
import onnxruntime
@@ -79,11 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
79
81
)
80
82
model_copy .graph .output .append (intermediate_layer_value_info )
81
83
84
+ # Passing the serialized ModelProto to the inference session does not work for
85
+ # large models, so the modified model is saved again and the session is created from the new path.
82
86
sess_options = onnxruntime .SessionOptions ()
83
87
sess_options .log_severity_level = 3
84
- sess = onnxruntime .InferenceSession (
85
- model_copy .SerializeToString (), sess_options , providers = ["CPUExecutionProvider" ]
86
- )
88
+
89
+ if path :
90
+ parent_dir = Path (path ).parent .absolute ()
91
+ new_path = parent_dir / "model_new.onnx"
92
+ onnx .save (model_copy , new_path , save_as_external_data = True )
93
+ sess = onnxruntime .InferenceSession (
94
+ new_path , sess_options , providers = onnxruntime .get_available_providers ()
95
+ )
96
+ else :
97
+ sess = onnxruntime .InferenceSession (
98
+ model_copy .SerializeToString (),
99
+ sess_options ,
100
+ providers = onnxruntime .get_available_providers (),
101
+ )
87
102
88
103
input_value_dict = {}
89
104
for input in model_copy .graph .input :
@@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(
166
181
167
182
168
183
def extract_nodes_shapes_and_dtypes (
169
- model : ModelProto ,
184
+ model : ModelProto , path : Optional [ str ] = None
170
185
) -> Tuple [Dict [str , List [List [int ]]], Dict [str , numpy .dtype ]]:
171
186
"""
172
187
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model
173
188
174
189
:param model: model to extract output values from
190
+ :param path: absolute path to the original onnx model
175
191
:return: output shapes and output data types
176
192
"""
177
193
output_shapes = None
178
194
output_dtypes = None
179
195
180
196
try :
181
- output_shapes , output_dtypes = extract_nodes_shapes_and_dtypes_ort (model )
197
+ output_shapes , output_dtypes = extract_nodes_shapes_and_dtypes_ort (model , path )
182
198
except Exception as err :
183
199
_LOGGER .warning (f"Extracting shapes using ONNX Runtime session failed: { err } " )
184
200
@@ -306,18 +322,19 @@ def collate_output_dtypes(
306
322
307
323
308
324
def extract_node_shapes_and_dtypes (
309
- model : ModelProto ,
325
+ model : ModelProto , path : Optional [ str ] = None
310
326
) -> Tuple [Dict [str , NodeShape ], Dict [str , NodeDataType ]]:
311
327
"""
312
328
Extracts the shape and dtype information for each node as NodeShape objects
313
329
and numpy dtypes.
314
330
315
331
:param model: the loaded onnx.ModelProto to extract node shape information from
332
+ :param path: absolute path to the original onnx model
316
333
:return: a mapping of node id to a NodeShape object
317
334
"""
318
335
319
336
# Obtains output shapes for each model's node
320
- output_shapes , output_dtypes = extract_nodes_shapes_and_dtypes (model )
337
+ output_shapes , output_dtypes = extract_nodes_shapes_and_dtypes (model , path )
321
338
322
339
# Package output shapes into each node's inputs and outputs
323
340
node_shapes = collate_output_shapes (model , output_shapes )
0 commit comments