Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 65424cc

Browse files
rahul-tuli authored and dbogunowicz committed
Move evaluator registry (#411)
1 parent 79d5de2 commit 65424cc

File tree

6 files changed

+329
-0
lines changed

6 files changed

+329
-0
lines changed

src/sparsezoo/evaluation/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# flake8: noqa
16+
17+
from .registry import *

src/sparsezoo/evaluation/registry.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Implementation of a registry for evaluation functions
16+
"""
17+
18+
from sparsezoo.utils.registry import RegistryMixin
19+
20+
21+
__all__ = ["EvaluationRegistry"]
22+
23+
24+
class EvaluationRegistry(RegistryMixin):
    """
    Registry for evaluation functions.

    Adds nothing of its own on top of RegistryMixin; it exists so that
    evaluation functions can be registered in, and loaded from,
    a dedicated registry class.
    """

src/sparsezoo/evaluation/results.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, List, Optional, Union
16+
17+
import numpy
18+
import yaml
19+
from pydantic import BaseModel, Field
20+
21+
22+
__all__ = [
23+
"Metric",
24+
"Dataset",
25+
"EvalSample",
26+
"Evaluation",
27+
"Result",
28+
"save_result",
29+
]
30+
31+
32+
def prep_for_serialization(
    data: Union[BaseModel, numpy.ndarray, list, dict]
) -> Union[BaseModel, list, dict]:
    """
    Prepares input data for JSON serialization by converting any numpy array
    it contains to a list. For large numpy arrays, this operation will take
    a while to run.

    BaseModel instances, lists and dicts are updated in place and returned;
    a numpy array is replaced by the result of its ``tolist()`` method.
    Any other value is returned unchanged.

    :param data: data that is to be processed before
        serialization. Nested objects are supported.
    :return: the input data with potential numpy arrays
        converted to lists
    """
    if isinstance(data, BaseModel):
        for field_name in data.__fields__.keys():
            field_value = getattr(data, field_name)
            # dict is part of the check so that arrays stored inside
            # dict-valued fields are converted as well, matching the
            # dict handling further below
            if isinstance(field_value, (numpy.ndarray, BaseModel, list, dict)):
                setattr(
                    data,
                    field_name,
                    prep_for_serialization(field_value),
                )

    elif isinstance(data, numpy.ndarray):
        data = data.tolist()

    elif isinstance(data, list):
        for i, value in enumerate(data):
            data[i] = prep_for_serialization(value)

    elif isinstance(data, dict):
        # only values of existing keys are reassigned, so the dict
        # does not change size while it is being iterated
        for key, value in data.items():
            data[key] = prep_for_serialization(value)

    return data
66+
67+
68+
class Metric(BaseModel):
    """
    A single named metric together with its value.
    """

    name: str = Field(
        description="Name of the metric",
    )
    value: float = Field(
        description="Value of the metric",
    )
71+
72+
73+
class Dataset(BaseModel):
    """
    Description of a dataset: its type, name, configuration and split.
    """

    type: Optional[str] = Field(
        description="Type of dataset",
    )
    name: str = Field(
        description="Name of the dataset",
    )
    config: Any = Field(
        description="Configuration for the dataset",
    )
    split: Optional[str] = Field(
        description="Split of the dataset",
    )
78+
79+
80+
class EvalSample(BaseModel):
    """
    One sample of an evaluation: a model input and the matching output.
    """

    input: Any = Field(
        description="Sample input to the model",
    )
    output: Any = Field(
        description="Sample output from the model",
    )
83+
84+
85+
class Evaluation(BaseModel):
    """
    Structured record of one evaluation: the task, the dataset it was
    performed on, the resulting metrics and, optionally, samples.
    """

    task: str = Field(
        description=(
            "Name of the evaluation integration "
            "that the evaluation was performed on"
        ),
    )
    dataset: Dataset = Field(
        description="Dataset that the evaluation was performed on",
    )
    metrics: List[Metric] = Field(
        description="List of metrics for the evaluation",
    )
    samples: Optional[List[EvalSample]] = Field(
        description="List of samples for the evaluation",
    )
95+
96+
97+
class Result(BaseModel):
    """
    Evaluation result held in two forms: the unified list of Evaluation
    objects and the raw, integration-specific representation.
    """

    formatted: List[Evaluation] = Field(
        description=(
            "Evaluation result represented in the unified, structured format"
        ),
    )
    raw: Any = Field(
        description=(
            "Evaluation result represented in the raw format "
            "(characteristic for the specific evaluation integration)"
        ),
    )
105+
106+
107+
def save_result(
    result: Result,
    save_path: str,
    save_format: str = "json",
):
    """
    Saves a Result object to a file in the specified format.

    The result is passed through prep_for_serialization before it is
    written, which converts numpy arrays to lists on the object itself.

    :param result: Result object to save
    :param save_path: Path to save the result to. Its extension must
        match the chosen format (".json" or ".yaml").
    :param save_format: Format to save the result in,
        either "json" or "yaml".
    :raises NotImplementedError: if save_format is neither
        "json" nor "yaml"
    """
    # reject unsupported formats up front, before the result is modified;
    # previously the error was constructed but never raised, so the call
    # silently wrote nothing
    if save_format not in ("json", "yaml"):
        raise NotImplementedError(
            "Currently only json and yaml formats are supported"
        )

    # prepare the Result object for serialization
    result: Result = prep_for_serialization(result)
    if save_format == "json":
        _save_to_json(result, save_path)
    else:
        _save_to_yaml(result, save_path)
128+
129+
130+
def _save_to_json(result: Result, save_path: str):
    """
    Serialize the result with its ``json()`` method and write the
    string to ``save_path``, which must end with ".json".
    """
    serialized = result.json()
    _save(serialized, save_path, expected_ext=".json")
132+
133+
134+
def _save_to_yaml(result: Result, save_path: str):
    """
    Dump the result's ``dict()`` form with ``yaml.dump`` and write the
    string to ``save_path``, which must end with ".yaml".
    """
    as_dict = result.dict()
    serialized = yaml.dump(as_dict)
    _save(serialized, save_path, expected_ext=".yaml")
136+
137+
138+
def _save(data: str, save_path: str, expected_ext: str):
    """
    Write ``data`` to the file at ``save_path``.

    :param data: string content to write
    :param save_path: destination path
    :param expected_ext: extension that save_path has to end with
    :raises ValueError: if save_path does not end with expected_ext
    """
    has_expected_ext = save_path.endswith(expected_ext)
    if not has_expected_ext:
        raise ValueError(f"save_path must end with extension: {expected_ext}")

    with open(save_path, "w") as out_file:
        out_file.write(data)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from sparsezoo.evaluation.registry import EvaluationRegistry
17+
18+
19+
@pytest.fixture
def registry_with_foo():
    """
    Provide an EvaluationRegistry subclass that has a function
    named ``foo`` registered through ``register()``.
    """

    class Registry(EvaluationRegistry):
        """Local subclass used only by this fixture."""

    @Registry.register()
    def foo(*args, **kwargs):
        return "foo"

    return Registry
29+
30+
31+
@pytest.fixture
def registry_with_buzz():
    """
    Provide an EvaluationRegistry subclass that has one function
    registered under the two names "buzz" and "buzzer".
    """

    class Registry(EvaluationRegistry):
        """Local subclass used only by this fixture."""

    @Registry.register(name=["buzz", "buzzer"])
    def buzz(*args, **kwargs):
        return "buzz"

    return Registry
41+
42+
43+
def test_get_foo_from_registry(registry_with_foo):
    """Loading "foo" from the registry yields the registered function."""
    loaded = registry_with_foo.load_from_registry("foo")
    assert loaded() == "foo"
46+
47+
48+
def test_get_multiple_buzz_from_registry(registry_with_buzz):
    """Both registered names load a function that returns "buzz"."""
    loaded_as_buzz = registry_with_buzz.load_from_registry("buzz")
    loaded_as_buzzer = registry_with_buzz.load_from_registry("buzzer")
    assert loaded_as_buzz() == loaded_as_buzzer() == "buzz"
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
17+
import numpy as np
18+
import pytest
19+
import yaml
20+
21+
from sparsezoo.evaluation.results import (
22+
Dataset,
23+
EvalSample,
24+
Evaluation,
25+
Metric,
26+
Result,
27+
save_result,
28+
)
29+
30+
31+
@pytest.fixture()
def evaluations():
    """
    Build two Evaluation objects whose sample inputs are numpy arrays,
    so that saving them exercises the numpy-to-list conversion.
    """
    first = Evaluation(
        task="task_1",
        dataset=Dataset(
            type="type_1",
            name="name_1",
            config="config_1",
            split="split_1",
        ),
        metrics=[Metric(name="metric_name_1", value=1.0)],
        samples=[EvalSample(input=np.array([[5]]), output=5)],
    )
    second = Evaluation(
        task="task_2",
        dataset=Dataset(
            type="type_2",
            name="name_2",
            config="config_2",
            split="split_2",
        ),
        metrics=[
            Metric(name="metric_name_2", value=2.0),
            Metric(name="metric_name_3", value=3.0),
        ],
        samples=[
            EvalSample(input=np.array([[10.0]]), output=10.0),
            EvalSample(input=np.array([[20.0]]), output=20.0),
        ],
    )
    return [first, second]
57+
58+
59+
@pytest.fixture()
def result(evaluations):
    """Wrap the evaluations fixture in a Result with a dummy raw payload."""
    wrapped = Result(formatted=evaluations, raw="dummy_raw_evaluation")
    return wrapped
62+
63+
64+
def test_serialize_result_json(tmp_path, result):
    """A Result saved as json reloads to the same data as ``result.dict()``."""
    target = (tmp_path / "result.json").as_posix()
    save_result(result=result, save_format="json", save_path=target)

    with open(target, "r") as json_file:
        reloaded = json.load(json_file)
    assert reloaded == result.dict()
71+
72+
73+
def test_serialize_result_yaml(tmp_path, result):
    """A Result saved as yaml reloads to the same data as ``result.dict()``."""
    target = (tmp_path / "result.yaml").as_posix()
    save_result(result=result, save_format="yaml", save_path=target)

    with open(target, "r") as yaml_file:
        reloaded = yaml.safe_load(yaml_file)
    assert reloaded == result.dict()

0 commit comments

Comments
 (0)