Skip to content

Commit 8458a35

Browse files
committed
#237 Add helper dataclass converters
1 parent 30b82e2 commit 8458a35

File tree

2 files changed

+104
-10
lines changed

2 files changed

+104
-10
lines changed

xarray_dataclasses/api.py

Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,31 @@
22

33

44
# standard library
5-
from typing import Any, Callable, overload
6-
5+
from dataclasses import replace
6+
from typing import Any, ForwardRef, Literal, Optional, overload
77

88
# dependencies
9-
from xarray import DataArray, Dataset
10-
from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray
9+
from dataspecs import ID, ROOT, Spec, Specs
10+
from numpy import asarray, array
11+
from typing_extensions import get_args, get_origin
12+
from xarray import DataArray, Dataset, Variable
13+
from .typing import (
14+
DataClass,
15+
DataClassOf,
16+
Factory,
17+
HashDict,
18+
PAny,
19+
TAny,
20+
TDataArray,
21+
TDataset,
22+
TXarray,
23+
Tag,
24+
)
25+
26+
27+
# type hints
28+
Attrs = HashDict[Any]
29+
Vars = HashDict[Variable]
1130

1231

1332
@overload
@@ -24,7 +43,7 @@ def asdataarray(
2443
obj: DataClass[PAny],
2544
/,
2645
*,
27-
factory: Callable[..., TDataArray],
46+
factory: Factory[TDataArray],
2847
) -> TDataArray: ...
2948

3049

@@ -56,7 +75,7 @@ def asdataset(
5675
obj: DataClass[PAny],
5776
/,
5877
*,
59-
factory: Callable[..., TDataset],
78+
factory: Factory[TDataset],
6079
) -> TDataset: ...
6180

6281

@@ -88,10 +107,84 @@ def asxarray(
88107
obj: DataClass[PAny],
89108
/,
90109
*,
91-
factory: Callable[..., TXarray],
110+
factory: Factory[TXarray],
92111
) -> TXarray: ...
93112

94113

95114
def asxarray(obj: Any, /, *, factory: Any = None) -> Any:
96115
"""Create a DataArray/set object from a dataclass object."""
97116
...
117+
118+
119+
def get_attrs(specs: Specs[Spec[Any]], /, *, at: ID = ROOT) -> Attrs:
120+
"""Create attributes from data specs."""
121+
attrs: Attrs = {}
122+
123+
for spec in specs[at.children][Tag.ATTR]:
124+
options = specs[spec.id.children]
125+
factory = maybe(options[Tag.FACTORY].unique).data or identity
126+
name = maybe(options[Tag.NAME].unique).data or spec.id.name
127+
128+
if Tag.MULTIPLE not in spec.tags:
129+
spec = replace(spec, data={name: spec.data})
130+
131+
for name, data in spec[HashDict[Any]].data.items():
132+
attrs[name] = factory(data)
133+
134+
return attrs
135+
136+
137+
def get_vars(specs: Specs[Spec[Any]], of: Tag, /, *, at: ID = ROOT) -> Vars:
138+
"""Create variables of given tag from data specs."""
139+
vars: Vars = {}
140+
141+
for spec in specs[at.children][of]:
142+
options = specs[spec.id.children]
143+
attrs = get_attrs(specs, at=spec.id)
144+
factory = maybe(options[Tag.FACTORY].unique).data or Variable
145+
name = maybe(options[Tag.NAME].unique).data or spec.id.name
146+
147+
if (type_ := maybe(options[Tag.DIMS].unique).type) is None:
148+
raise RuntimeError("Could not find any data spec for dims.")
149+
elif get_origin(type_) is tuple:
150+
dims = tuple(str(unwrap(arg)) for arg in get_args(type_))
151+
else:
152+
dims = (str(unwrap(type_)),)
153+
154+
if (type_ := maybe(options[Tag.DTYPE].unique).type) is None:
155+
raise RuntimeError("Could not find any data spec for dims.")
156+
elif type_ is type(None) or type_ is Any:
157+
dtype = None
158+
else:
159+
dtype = unwrap(type_)
160+
161+
if Tag.MULTIPLE not in spec.tags:
162+
spec = replace(spec, data={name: spec.data})
163+
164+
for name, data in spec[HashDict[Any]].data.items():
165+
if not (data := asarray(data, dtype)).ndim:
166+
data = array(data, ndmin=len(dims))
167+
168+
vars[name] = factory(attrs=attrs, data=data, dims=dims)
169+
170+
return vars
171+
172+
173+
def identity(obj: TAny, /) -> TAny:
174+
"""Identity function used for the default factory."""
175+
return obj
176+
177+
178+
def maybe(obj: Optional[Spec[Any]], /) -> Spec[Any]:
179+
"""Return a dummy (``None``-filled) data spec if an object is not one."""
180+
return Spec(ROOT, (), None, None) if obj is None else obj
181+
182+
183+
def unwrap(obj: Any, /) -> Any:
184+
"""Unwrap if an object is a literal or a forward reference."""
185+
if get_origin(obj) is Literal:
186+
return args[0] if len(args := get_args(obj)) == 1 else obj
187+
elif isinstance(obj, ForwardRef):
188+
return obj.__forward_arg__
189+
else:
190+
return obj

xarray_dataclasses/typing.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
TDtype = TypeVar("TDtype", covariant=True)
4242
THashable = TypeVar("THashable", bound=Hashable)
4343
TXarray = TypeVar("TXarray", covariant=True, bound="Xarray")
44+
HashDict = dict[Hashable, TAny]
4445
Xarray = Union[DataArray, Dataset]
4546

4647

@@ -103,19 +104,19 @@ class Tag(TagBase):
103104
Attr = Annotated[TAny, Tag.ATTR]
104105
"""Type alias for an attribute of DataArray/set."""
105106

106-
Attrs = Annotated[dict[str, TAny], Tag.ATTR, Tag.MULTIPLE]
107+
Attrs = Annotated[HashDict[TAny], Tag.ATTR, Tag.MULTIPLE]
107108
"""Type alias for attributes of DataArray/set."""
108109

109110
Coord = Annotated[Arrayable[TDims, TDtype], Tag.COORD]
110111
"""Type alias for a coordinate of DataArray/set."""
111112

112-
Coords = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.COORD, Tag.MULTIPLE]
113+
Coords = Annotated[HashDict[Arrayable[TDims, TDtype]], Tag.COORD, Tag.MULTIPLE]
113114
"""Type alias for coordinates of DataArray/set."""
114115

115116
Data = Annotated[Arrayable[TDims, TDtype], Tag.DATA]
116117
"""Type alias for a data object of DataArray/set."""
117118

118-
DataVars = Annotated[dict[str, Arrayable[TDims, TDtype]], Tag.DATA, Tag.MULTIPLE]
119+
DataVars = Annotated[HashDict[Arrayable[TDims, TDtype]], Tag.DATA, Tag.MULTIPLE]
119120
"""Type alias for data objects of DataArray/set."""
120121

121122
Factory = Annotated[Callable[..., TXarray], Tag.FACTORY]

0 commit comments

Comments
 (0)