Skip to content

Commit 541d4c8

Browse files
authored
Merge pull request #557 from OpenCOMPES/fix_normed_dtype
Fix normed dtype
2 parents 60b114d + 6cc7655 commit 541d4c8

File tree

3 files changed

Lines changed: 30 additions & 31 deletions

3 files changed

+30
-31
lines changed

src/sed/binning/binning.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,7 @@ def normalization_histogram_from_timed_dataframe(
465465
axis: str,
466466
bin_centers: np.ndarray,
467467
time_unit: float,
468+
**kwds,
468469
) -> xr.DataArray:
469470
"""Get a normalization histogram from a timed dataframe.
470471
@@ -475,17 +476,12 @@ def normalization_histogram_from_timed_dataframe(
475476
histogram.
476477
bin_centers (np.ndarray): Bin centers used for binning of the axis.
477478
time_unit (float): Time unit the data frame entries are based on.
479+
**kwds: Additional keyword arguments passed to the bin_dataframe function.
478480
479481
Returns:
480482
xr.DataArray: Calculated normalization histogram.
481483
"""
482-
bins = df[axis].map_partitions(
483-
pd.cut,
484-
bins=bin_centers_to_bin_edges(bin_centers),
485-
)
486-
487-
histogram = df[axis].groupby([bins]).count().compute().values * time_unit
488-
# histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers]) * time_unit
484+
histogram = bin_dataframe(df, axes=[axis], bins=[bin_centers], **kwds) * time_unit
489485

490486
data_array = xr.DataArray(
491487
data=histogram,

src/sed/core/processor.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,8 @@ def compute(
22832283
)
22842284
# if the axes are named consistently, xarray aligns the normalization histogram with the binned data
22852285
self._normalized = self._binned / self._normalization_histogram
2286+
# Set datatype of binned data
2287+
self._normalized.data = self._normalized.data.astype(self._binned.data.dtype)
22862288
self._attributes.add(
22872289
self._normalization_histogram.values,
22882290
name="normalization_histogram",
@@ -2353,36 +2355,35 @@ def get_normalization_histogram(
23532355

23542356
if isinstance(df_partitions, int):
23552357
df_partitions = list(range(0, min(df_partitions, self._dataframe.npartitions)))
2358+
23562359
if use_time_stamps or self._timed_dataframe is None:
23572360
if df_partitions is not None:
2358-
self._normalization_histogram = normalization_histogram_from_timestamps(
2359-
self._dataframe.partitions[df_partitions],
2360-
axis,
2361-
self._binned.coords[axis].values,
2362-
self._config["dataframe"]["columns"]["timestamp"],
2363-
)
2361+
dataframe = self._dataframe.partitions[df_partitions]
23642362
else:
2365-
self._normalization_histogram = normalization_histogram_from_timestamps(
2366-
self._dataframe,
2367-
axis,
2368-
self._binned.coords[axis].values,
2369-
self._config["dataframe"]["columns"]["timestamp"],
2370-
)
2363+
dataframe = self._dataframe
2364+
self._normalization_histogram = normalization_histogram_from_timestamps(
2365+
df=dataframe,
2366+
axis=axis,
2367+
bin_centers=self._binned.coords[axis].values,
2368+
time_stamp_column=self._config["dataframe"]["columns"]["timestamp"],
2369+
)
23712370
else:
23722371
if df_partitions is not None:
2373-
self._normalization_histogram = normalization_histogram_from_timed_dataframe(
2374-
self._timed_dataframe.partitions[df_partitions],
2375-
axis,
2376-
self._binned.coords[axis].values,
2377-
self._config["dataframe"]["timed_dataframe_unit_time"],
2378-
)
2372+
timed_dataframe = self._timed_dataframe.partitions[df_partitions]
23792373
else:
2380-
self._normalization_histogram = normalization_histogram_from_timed_dataframe(
2381-
self._timed_dataframe,
2382-
axis,
2383-
self._binned.coords[axis].values,
2384-
self._config["dataframe"]["timed_dataframe_unit_time"],
2385-
)
2374+
timed_dataframe = self._timed_dataframe
2375+
self._normalization_histogram = normalization_histogram_from_timed_dataframe(
2376+
df=timed_dataframe,
2377+
axis=axis,
2378+
bin_centers=self._binned.coords[axis].values,
2379+
time_unit=self._config["dataframe"]["timed_dataframe_unit_time"],
2380+
hist_mode=self.config["binning"]["hist_mode"],
2381+
mode=self.config["binning"]["mode"],
2382+
pbar=self.config["binning"]["pbar"],
2383+
n_cores=self.config["core"]["num_cores"],
2384+
threads_per_worker=self.config["binning"]["threads_per_worker"],
2385+
threadpool_api=self.config["binning"]["threadpool_API"],
2386+
)
23862387

23872388
return self._normalization_histogram
23882389

tests/test_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,8 @@ def test_compute_with_normalization() -> None:
10081008
processor.binned.data,
10091009
(processor.normalized * processor.normalization_histogram).data,
10101010
)
1011+
# check dtype
1012+
assert processor.normalized.dtype == processor.binned.dtype
10111013
# bin only second dataframe partition
10121014
result2 = processor.compute(
10131015
bins=bins,

0 commit comments

Comments
 (0)