Skip to content

Commit bb38fb1

Browse files
Turn spatial_partitions into property + validate value in setter (#159)
1 parent 8571ffe commit bb38fb1

File tree

2 files changed

+35
-6
lines changed

2 files changed

+35
-6
lines changed

dask_geopandas/core.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,7 @@ class _Frame(dd.core._Frame, OperatorMethodMixin):
5555

5656
def __init__(self, dsk, name, meta, divisions, spatial_partitions=None):
5757
super().__init__(dsk, name, meta, divisions)
58-
if spatial_partitions is not None and not isinstance(
59-
spatial_partitions, geopandas.GeoSeries
60-
):
61-
spatial_partitions = geopandas.GeoSeries(spatial_partitions)
62-
# TODO make this a property
63-
self.spatial_partitions = spatial_partitions
58+
self._spatial_partitions = spatial_partitions
6459

6560
def to_dask_dataframe(self):
6661
"""Create a dask.dataframe object from a dask_geopandas object"""
@@ -72,6 +67,28 @@ def __dask_postcompute__(self):
7267
def __dask_postpersist__(self):
7368
return type(self), (self._name, self._meta, self.divisions)
7469

70+
@property
71+
def spatial_partitions(self):
72+
"""
73+
The spatial extent of each of the partitions of the dask GeoDataFrame.
74+
"""
75+
return self._spatial_partitions
76+
77+
@spatial_partitions.setter
78+
def spatial_partitions(self, value):
79+
if value is not None:
80+
if not isinstance(value, geopandas.GeoSeries):
81+
raise TypeError(
82+
"Expected a geopandas.GeoSeries for the spatial_partitions, "
83+
f"got {type(value)} instead."
84+
)
85+
if len(value) != self.npartitions:
86+
raise ValueError(
87+
f"Expected spatial partitions of length {self.npartitions}, "
88+
f"got {len(value)} instead."
89+
)
90+
self._spatial_partitions = value
91+
7592
@classmethod
7693
def _bind_property(cls, attr, preserve_spatial_partitions=False):
7794
"""Map property to partitions and bind to class"""

tests/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,18 @@ def check_meta(gdf, name):
427427
check_meta(meta_non_empty, "foo")
428428

429429

430+
def test_spatial_partitions_setter(geodf_points):
431+
dask_obj = dask_geopandas.from_geopandas(geodf_points, npartitions=2)
432+
433+
# needs to be a GeoSeries
434+
with pytest.raises(TypeError):
435+
dask_obj.spatial_partitions = geodf_points
436+
437+
# wrong length
438+
with pytest.raises(ValueError):
439+
dask_obj.spatial_partitions = geodf_points.geometry
440+
441+
430442
def test_to_crs_geodf(geodf_points_crs):
431443
df = geodf_points_crs
432444
dask_obj = dask_geopandas.from_geopandas(df, npartitions=2)

0 commit comments

Comments
 (0)