Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 051493d

Browse files
authored
add device_info to _extract_throughput (#294) (#295)
1 parent b060d3e commit 051493d

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

src/sparsezoo/deployment_package/utils/extractors.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import logging
2020
from types import MappingProxyType
21+
from typing import Optional
2122

2223
from sparsezoo import Model
2324

@@ -32,14 +33,20 @@ def _size(model: Model) -> float:
3233
return size
3334

3435

35-
def _throughput(model: Model, num_cores: int = 24, batch_size: int = 64) -> float:
36+
def _throughput(
37+
model: Model,
38+
num_cores: int = 24,
39+
batch_size: int = 64,
40+
device_info: Optional[str] = None,
41+
) -> float:
3642
# num_cores : 24, batch_size: 64 are standard defaults in sparsezoo
3743
throughput_results = getattr(model, "validation_results", {}).get("throughput", [])
3844

3945
for throughput_result in throughput_results:
4046
if (
4147
throughput_result.batch_size == batch_size
4248
and throughput_result.num_cores == num_cores
49+
and (device_info is None or (throughput_result.device_info == device_info))
4350
):
4451
return throughput_result.recorded_value
4552

tests/sparsezoo/deployment_package/utils/test_extractors.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,17 @@ def model():
6464

6565

6666
@pytest.mark.parametrize(
67-
"num_cores, batch_size, expected",
67+
"num_cores,batch_size,device_info,expected",
6868
[
69-
(24, 64, 1948.45),
69+
(24, 64, "c6i.12xlarge", 1948.45),
7070
],
7171
)
72-
def test_throughput_extractor(model, num_cores, batch_size, expected):
72+
def test_throughput_extractor(model, num_cores, batch_size, device_info, expected):
7373
actual_throughput = _throughput(
7474
model=model,
7575
num_cores=num_cores,
7676
batch_size=batch_size,
77+
device_info=device_info,
7778
)
7879
assert actual_throughput == expected
7980

0 commit comments

Comments
 (0)