2
2
3
3
4
4
# standard library
5
- from typing import Any , Callable , overload
6
-
5
+ from dataclasses import replace
6
+ from typing import Any , ForwardRef , Literal , Optional , overload
7
7
8
8
# dependencies
9
- from xarray import DataArray , Dataset
10
- from .typing import DataClass , DataClassOf , PAny , TDataArray , TDataset , TXarray
9
+ from dataspecs import ID , ROOT , Spec , Specs
10
+ from numpy import asarray , array
11
+ from typing_extensions import get_args , get_origin
12
+ from xarray import DataArray , Dataset , Variable
13
+ from .typing import (
14
+ DataClass ,
15
+ DataClassOf ,
16
+ Factory ,
17
+ HashDict ,
18
+ PAny ,
19
+ TAny ,
20
+ TDataArray ,
21
+ TDataset ,
22
+ TXarray ,
23
+ Tag ,
24
+ )
25
+
26
+
27
+ # type hints
28
+ Attrs = HashDict [Any ]
29
+ Vars = HashDict [Variable ]
11
30
12
31
13
32
@overload
@@ -24,7 +43,7 @@ def asdataarray(
24
43
obj : DataClass [PAny ],
25
44
/ ,
26
45
* ,
27
- factory : Callable [..., TDataArray ],
46
+ factory : Factory [ TDataArray ],
28
47
) -> TDataArray : ...
29
48
30
49
@@ -56,7 +75,7 @@ def asdataset(
56
75
obj : DataClass [PAny ],
57
76
/ ,
58
77
* ,
59
- factory : Callable [..., TDataset ],
78
+ factory : Factory [ TDataset ],
60
79
) -> TDataset : ...
61
80
62
81
@@ -88,10 +107,84 @@ def asxarray(
88
107
obj : DataClass [PAny ],
89
108
/ ,
90
109
* ,
91
- factory : Callable [..., TXarray ],
110
+ factory : Factory [ TXarray ],
92
111
) -> TXarray : ...
93
112
94
113
95
114
def asxarray (obj : Any , / , * , factory : Any = None ) -> Any :
96
115
"""Create a DataArray/set object from a dataclass object."""
97
116
...
117
+
118
+
119
+ def get_attrs (specs : Specs [Spec [Any ]], / , * , at : ID = ROOT ) -> Attrs :
120
+ """Create attributes from data specs."""
121
+ attrs : Attrs = {}
122
+
123
+ for spec in specs [at .children ][Tag .ATTR ]:
124
+ options = specs [spec .id .children ]
125
+ factory = maybe (options [Tag .FACTORY ].unique ).data or identity
126
+ name = maybe (options [Tag .NAME ].unique ).data or spec .id .name
127
+
128
+ if Tag .MULTIPLE not in spec .tags :
129
+ spec = replace (spec , data = {name : spec .data })
130
+
131
+ for name , data in spec [HashDict [Any ]].data .items ():
132
+ attrs [name ] = factory (data )
133
+
134
+ return attrs
135
+
136
+
137
+ def get_vars (specs : Specs [Spec [Any ]], of : Tag , / , * , at : ID = ROOT ) -> Vars :
138
+ """Create variables of given tag from data specs."""
139
+ vars : Vars = {}
140
+
141
+ for spec in specs [at .children ][of ]:
142
+ options = specs [spec .id .children ]
143
+ attrs = get_attrs (specs , at = spec .id )
144
+ factory = maybe (options [Tag .FACTORY ].unique ).data or Variable
145
+ name = maybe (options [Tag .NAME ].unique ).data or spec .id .name
146
+
147
+ if (type_ := maybe (options [Tag .DIMS ].unique ).type ) is None :
148
+ raise RuntimeError ("Could not find any data spec for dims." )
149
+ elif get_origin (type_ ) is tuple :
150
+ dims = tuple (str (unwrap (arg )) for arg in get_args (type_ ))
151
+ else :
152
+ dims = (str (unwrap (type_ )),)
153
+
154
+ if (type_ := maybe (options [Tag .DTYPE ].unique ).type ) is None :
155
+ raise RuntimeError ("Could not find any data spec for dims." )
156
+ elif type_ is type (None ) or type_ is Any :
157
+ dtype = None
158
+ else :
159
+ dtype = unwrap (type_ )
160
+
161
+ if Tag .MULTIPLE not in spec .tags :
162
+ spec = replace (spec , data = {name : spec .data })
163
+
164
+ for name , data in spec [HashDict [Any ]].data .items ():
165
+ if not (data := asarray (data , dtype )).ndim :
166
+ data = array (data , ndmin = len (dims ))
167
+
168
+ vars [name ] = factory (attrs = attrs , data = data , dims = dims )
169
+
170
+ return vars
171
+
172
+
173
+ def identity (obj : TAny , / ) -> TAny :
174
+ """Identity function used for the default factory."""
175
+ return obj
176
+
177
+
178
+ def maybe (obj : Optional [Spec [Any ]], / ) -> Spec [Any ]:
179
+ """Return a dummy (``None``-filled) data spec if an object is not one."""
180
+ return Spec (ROOT , (), None , None ) if obj is None else obj
181
+
182
+
183
+ def unwrap (obj : Any , / ) -> Any :
184
+ """Unwrap if an object is a literal or a forward reference."""
185
+ if get_origin (obj ) is Literal :
186
+ return args [0 ] if len (args := get_args (obj )) == 1 else obj
187
+ elif isinstance (obj , ForwardRef ):
188
+ return obj .__forward_arg__
189
+ else :
190
+ return obj
0 commit comments