Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit aaeb280

Browse files
dsikka, Dipika Sikka, dbogunowicz
authored
[sparsezoo.analyze] Fix pathway such that it works for larger models (#437) (#438)
* fix analyze to work with larger models
* update for failing tests; add comments
* Update src/sparsezoo/utils/onnx/external_data.py

---------

Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
1 parent fccb742 commit aaeb280

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

src/sparsezoo/analyze_v2/model_analysis.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,20 +146,25 @@ def analyze(path: str, download_path: Optional[str] = None) -> "ModelAnalysis":
146146
:param path: .onnx path or stub
147147
"""
148148
if path.endswith(".onnx"):
149-
onnx_model = load_model(path)
149+
onnx_model = load_model(path, load_external_data=False)
150+
onnx_model_path = path
150151
elif is_stub(path):
151152
model = Model(path, download_path)
152153
onnx_model_path = model.onnx_model.path
153-
onnx_model = onnx.load(onnx_model_path)
154+
onnx_model = onnx.load(onnx_model_path, load_external_data=False)
154155
else:
155156
raise ValueError(f"{path} is not a valid argument")
156157

157-
model_graph = ONNXGraph(onnx_model)
158-
node_shapes, _ = extract_node_shapes_and_dtypes(model_graph.model)
158+
# just need graph to get shape information; don't load external data
159+
node_shapes, _ = extract_node_shapes_and_dtypes(onnx_model, onnx_model_path)
159160

160161
summary_analysis = SummaryAnalysis()
161162
node_analyses = {}
162163

164+
# load external data for node analysis
165+
onnx_model = onnx.load(onnx_model_path)
166+
model_graph = ONNXGraph(onnx_model)
167+
163168
for graph_order, node in enumerate(model_graph.nodes):
164169
node_id = extract_node_id(node)
165170
node_shape = node_shapes.get(node_id)

src/sparsezoo/utils/node_inference.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
import logging
2020
from copy import deepcopy
21-
from typing import Any, Dict, List, NamedTuple, Tuple, Union
21+
from pathlib import Path
22+
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
2223

2324
import numpy
2425
import onnx
@@ -60,13 +61,14 @@
6061

6162

6263
def extract_nodes_shapes_and_dtypes_ort(
63-
model: ModelProto,
64+
model: ModelProto, path: Optional[str] = None
6465
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
6566
"""
6667
Creates a modified model to expose intermediate outputs and runs an ONNX Runtime
6768
InferenceSession to obtain the output shape of each node.
6869
6970
:param model: an ONNX model
71+
:param path: absolute path to the original onnx model
7072
:return: a list of NodeArg with their shape exposed
7173
"""
7274
import onnxruntime
@@ -79,11 +81,24 @@ def extract_nodes_shapes_and_dtypes_ort(
7981
)
8082
model_copy.graph.output.append(intermediate_layer_value_info)
8183

84+
# using the ModelProto does not work for large models when running the session
85+
# have to save again and pass the new path to the inference session
8286
sess_options = onnxruntime.SessionOptions()
8387
sess_options.log_severity_level = 3
84-
sess = onnxruntime.InferenceSession(
85-
model_copy.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
86-
)
88+
89+
if path:
90+
parent_dir = Path(path).parent.absolute()
91+
new_path = parent_dir / "model_new.onnx"
92+
onnx.save(model_copy, new_path, save_as_external_data=True)
93+
sess = onnxruntime.InferenceSession(
94+
new_path, sess_options, providers=onnxruntime.get_available_providers()
95+
)
96+
else:
97+
sess = onnxruntime.InferenceSession(
98+
model_copy.SerializeToString(),
99+
sess_options,
100+
providers=onnxruntime.get_available_providers(),
101+
)
87102

88103
input_value_dict = {}
89104
for input in model_copy.graph.input:
@@ -166,19 +181,20 @@ def extract_nodes_shapes_and_dtypes_shape_inference(
166181

167182

168183
def extract_nodes_shapes_and_dtypes(
169-
model: ModelProto,
184+
model: ModelProto, path: Optional[str] = None
170185
) -> Tuple[Dict[str, List[List[int]]], Dict[str, numpy.dtype]]:
171186
"""
172187
Uses ONNX Runtime or shape inference to infer output shapes and dtypes from model
173188
174189
:param model: model to extract output values from
190+
:param path: absolute path to the original onnx model
175191
:return: output shapes and output data types
176192
"""
177193
output_shapes = None
178194
output_dtypes = None
179195

180196
try:
181-
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model)
197+
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes_ort(model, path)
182198
except Exception as err:
183199
_LOGGER.warning(f"Extracting shapes using ONNX Runtime session failed: {err}")
184200

@@ -306,18 +322,19 @@ def collate_output_dtypes(
306322

307323

308324
def extract_node_shapes_and_dtypes(
309-
model: ModelProto,
325+
model: ModelProto, path: Optional[str] = None
310326
) -> Tuple[Dict[str, NodeShape], Dict[str, NodeDataType]]:
311327
"""
312328
Extracts the shape and dtype information for each node as NodeShape objects
313329
and numpy dtypes.
314330
315331
:param model: the loaded onnx.ModelProto to extract node shape information from
332+
:param path: absolute path to the original onnx model
316333
:return: a mapping of node id to a NodeShape object
317334
"""
318335

319336
# Obtains output shapes for each model's node
320-
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model)
337+
output_shapes, output_dtypes = extract_nodes_shapes_and_dtypes(model, path)
321338

322339
# Package output shapes into each node's inputs and outputs
323340
node_shapes = collate_output_shapes(model, output_shapes)

src/sparsezoo/utils/onnx/external_data.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,23 @@ def validate_onnx(model: Union[str, ModelProto]):
174174
raise ValueError(f"Invalid onnx model: {err}")
175175

176176

177-
def load_model(model: Union[str, ModelProto, Path]) -> ModelProto:
177+
def load_model(
178+
model: Union[str, ModelProto, Path], load_external_data: bool = True
179+
) -> ModelProto:
178180
"""
179181
Load an ONNX model from an onnx model file path. If a ModelProto
180182
is given, then it is returned.
181183
182184
:param model: the model proto or path to the model ONNX file to check for loading
185+
:param load_external_data: if a path is given, whether or not to also load the
186+
external model data
183187
:return: the loaded ONNX ModelProto
184188
"""
185189
if isinstance(model, ModelProto):
186190
return model
187191

188192
if isinstance(model, (Path, str)):
189-
return onnx.load(clean_path(model))
193+
return onnx.load(clean_path(model), load_external_data=load_external_data)
190194

191195
raise TypeError(f"unknown type given for model: {type(model)}")
192196

0 commit comments

Comments
 (0)