Skip to content

Commit 30b82e2

Browse files
committed
#237 Add api module for dataclass converters
1 parent 5e59752 commit 30b82e2

File tree

3 files changed

+135
-2
lines changed

3 files changed

+135
-2
lines changed

xarray_dataclasses/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
__all__ = [
22
# submodules
3+
"api",
34
"typing",
45
# aliases
56
"Attr",
@@ -11,13 +12,18 @@
1112
"Factory",
1213
"Name",
1314
"Tag",
15+
"asdataarray",
16+
"asdataset",
17+
"asxarray",
1418
]
1519
__version__ = "2.0.0"
1620

1721

1822
# submodules
23+
from . import api
1924
from . import typing
2025

2126

2227
# aliases
28+
from .api import *
2329
from .typing import *

xarray_dataclasses/api.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
__all__ = ["asdataarray", "asdataset", "asxarray"]
2+
3+
4+
# standard library
5+
from typing import Any, Callable, overload
6+
7+
8+
# dependencies
9+
from xarray import DataArray, Dataset
10+
from .typing import DataClass, DataClassOf, PAny, TDataArray, TDataset, TXarray
11+
12+
13+
@overload
14+
def asdataarray(
15+
obj: DataClassOf[PAny, TDataArray],
16+
/,
17+
*,
18+
factory: None = None,
19+
) -> TDataArray: ...
20+
21+
22+
@overload
23+
def asdataarray(
24+
obj: DataClass[PAny],
25+
/,
26+
*,
27+
factory: Callable[..., TDataArray],
28+
) -> TDataArray: ...
29+
30+
31+
@overload
32+
def asdataarray(
33+
obj: DataClass[PAny],
34+
/,
35+
*,
36+
factory: None = None,
37+
) -> DataArray: ...
38+
39+
40+
def asdataarray(obj: Any, /, *, factory: Any = None) -> Any:
41+
"""Create a DataArray object from a dataclass object."""
42+
...
43+
44+
45+
@overload
46+
def asdataset(
47+
obj: DataClassOf[PAny, TDataset],
48+
/,
49+
*,
50+
factory: None = None,
51+
) -> TDataset: ...
52+
53+
54+
@overload
55+
def asdataset(
56+
obj: DataClass[PAny],
57+
/,
58+
*,
59+
factory: Callable[..., TDataset],
60+
) -> TDataset: ...
61+
62+
63+
@overload
64+
def asdataset(
65+
obj: DataClass[PAny],
66+
/,
67+
*,
68+
factory: None = None,
69+
) -> Dataset: ...
70+
71+
72+
def asdataset(obj: Any, /, *, factory: Any = None) -> Any:
73+
"""Create a Dataset object from a dataclass object."""
74+
...
75+
76+
77+
@overload
78+
def asxarray(
79+
obj: DataClassOf[PAny, TXarray],
80+
/,
81+
*,
82+
factory: None = None,
83+
) -> TXarray: ...
84+
85+
86+
@overload
87+
def asxarray(
88+
obj: DataClass[PAny],
89+
/,
90+
*,
91+
factory: Callable[..., TXarray],
92+
) -> TXarray: ...
93+
94+
95+
def asxarray(obj: Any, /, *, factory: Any = None) -> Any:
96+
"""Create a DataArray/set object from a dataclass object."""
97+
...

xarray_dataclasses/typing.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,18 @@
1313

1414
# standard library
1515
from collections.abc import Collection as Collection_, Hashable
16+
from dataclasses import Field
1617
from enum import auto
17-
from typing import Annotated, Callable, Protocol, TypeVar, Union
18+
from typing import (
19+
Annotated,
20+
Any,
21+
Callable,
22+
ClassVar,
23+
ParamSpec,
24+
Protocol,
25+
TypeVar,
26+
Union,
27+
)
1828

1929

2030
# dependencies
@@ -23,11 +33,14 @@
2333

2434

2535
# type hints
36+
PAny = ParamSpec("PAny")
2637
TAny = TypeVar("TAny")
38+
TDataArray = TypeVar("TDataArray", bound=DataArray)
39+
TDataset = TypeVar("TDataset", bound=Dataset)
2740
TDims = TypeVar("TDims", covariant=True)
2841
TDtype = TypeVar("TDtype", covariant=True)
2942
THashable = TypeVar("THashable", bound=Hashable)
30-
TXarray = TypeVar("TXarray", bound="Xarray")
43+
TXarray = TypeVar("TXarray", covariant=True, bound="Xarray")
3144
Xarray = Union[DataArray, Dataset]
3245

3346

@@ -37,6 +50,23 @@ class Collection(Collection_[TDtype], Protocol[TDims, TDtype]):
3750
pass
3851

3952

53+
class DataClass(Protocol[PAny]):
54+
"""Protocol for a dataclass object."""
55+
56+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
57+
58+
def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None: ...
59+
60+
61+
class DataClassOf(Protocol[PAny, TXarray]):
62+
"""Protocol for a dataclass object with an xarray factory."""
63+
64+
_xarray_factory: Callable[..., TXarray]
65+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
66+
67+
def __init__(self, *args: PAny.args, **kwargs: PAny.kwargs) -> None: ...
68+
69+
4070
# constants
4171
class Tag(TagBase):
4272
"""Collection of xarray-related tags for annotating type hints."""

0 commit comments

Comments
 (0)