
Commit 2c2a9ae

Merge branch 'v1_feature_branch' into pydantic-error-handling
2 parents: 5084bb7 + 82fc11f

File tree

8 files changed: 329 additions & 69 deletions

src/sed/core/config.py

Lines changed: 75 additions & 0 deletions
@@ -254,3 +254,78 @@ def complete_dictionary(dictionary: dict, base_dictionary: dict) -> dict:
             dictionary[k] = v

     return dictionary
+
+
+def _parse_env_file(file_path: Path) -> dict:
+    """Helper function to parse a .env file into a dictionary.
+
+    Args:
+        file_path (Path): Path to the .env file
+
+    Returns:
+        dict: Dictionary of environment variables from the file
+    """
+    env_content = {}
+    if file_path.exists():
+        with open(file_path) as f:
+            for line in f:
+                line = line.strip()
+                if line and "=" in line:
+                    key, val = line.split("=", 1)
+                    env_content[key.strip()] = val.strip()
+    return env_content
+
+
+def read_env_var(var_name: str) -> str | None:
+    """Read an environment variable from multiple locations in order:
+    1. OS environment variables
+    2. .env file in current directory
+    3. .env file in user config directory
+
+    Args:
+        var_name (str): Name of the environment variable to read
+
+    Returns:
+        str | None: Value of the environment variable, or None if not found
+    """
+    # First check OS environment variables
+    value = os.getenv(var_name)
+    if value is not None:
+        logger.debug(f"Found {var_name} in OS environment variables")
+        return value
+
+    # Then check .env in current directory
+    local_vars = _parse_env_file(Path(".env"))
+    if var_name in local_vars:
+        logger.debug(f"Found {var_name} in ./.env file")
+        return local_vars[var_name]
+
+    # Finally check .env in user config directory
+    user_vars = _parse_env_file(USER_CONFIG_PATH / ".env")
+    if var_name in user_vars:
+        logger.debug(f"Found {var_name} in user config .env file")
+        return user_vars[var_name]
+
+    logger.debug(f"Environment variable {var_name} not found in any location")
+    return None
+
+
+def save_env_var(var_name: str, value: str) -> None:
+    """Save an environment variable to the .env file in the user config directory.
+    If the file exists, preserves other variables. If not, creates a new file.
+
+    Args:
+        var_name (str): Name of the environment variable to save
+        value (str): Value to save for the environment variable
+    """
+    env_path = USER_CONFIG_PATH / ".env"
+    env_content = _parse_env_file(env_path)
+
+    # Update or add the new variable
+    env_content[var_name] = value
+
+    # Write all variables back to the file
+    with open(env_path, "w") as f:
+        for key, val in env_content.items():
+            f.write(f"{key}={val}\n")
+    logger.debug(f"Environment variable {var_name} saved to .env file")

src/sed/core/config_model.py

Lines changed: 0 additions & 2 deletions
@@ -14,7 +14,6 @@
 from pydantic import HttpUrl
 from pydantic import NewPath
 from pydantic import PositiveInt
-from pydantic import SecretStr

 from sed.loader.loader_interface import get_names_of_all_loaders

@@ -323,7 +322,6 @@ class MetadataModel(BaseModel):
     model_config = ConfigDict(extra="forbid")

     archiver_url: Optional[HttpUrl] = None
-    token: Optional[SecretStr] = None
     epics_pvs: Optional[Sequence[str]] = None
     fa_in_channel: Optional[str] = None
     fa_hor_channel: Optional[str] = None
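
With token removed from MetadataModel, the token no longer lives in the validated config; per the loader docstring further below, it is resolved from the environment when not passed explicitly. A sketch of that resolution path, assuming the new read_env_var helper is the lookup mechanism and using a hypothetical SCICAT_TOKEN variable name (the actual call site is not shown in this diff):

from sed.core.config import read_env_var

# Hypothetical resolution: OS env, then ./.env, then the user config .env
token = read_env_var("SCICAT_TOKEN")
if token is None:
    # Mirrors the documented behavior: collecting metadata without a token raises
    raise ValueError("Token is required when collect_metadata is True.")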

src/sed/loader/flash/buffer_handler.py

Lines changed: 18 additions & 21 deletions
@@ -2,6 +2,7 @@

 import os
 from pathlib import Path
+import time

 import dask.dataframe as dd
 import pyarrow.parquet as pq

@@ -20,7 +21,7 @@

 DF_TYP = ["electron", "timed"]

-logger = setup_logging(__name__)
+logger = setup_logging("flash_buffer_handler")


 class BufferFilePaths:

@@ -135,16 +136,15 @@ def __init__(
     def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
         """
         Checks the schema of the Parquet files.
-
-        Raises:
-            ValueError: If the schema of the Parquet files does not match the configuration.
         """
+        logger.debug(f"Checking schema for {len(files)} files")
         existing = [file for file in files if file.exists()]
         parquet_schemas = [pq.read_schema(file) for file in existing]

         for filename, schema in zip(existing, parquet_schemas):
             schema_set = set(schema.names)
             if schema_set != expected_schema_set:
+                logger.error(f"Schema mismatch in file: {filename}")
                 missing_in_parquet = expected_schema_set - schema_set
                 missing_in_config = schema_set - expected_schema_set

@@ -159,6 +159,7 @@ def _schema_check(self, files: list[Path], expected_schema_set: set) -> None:
                     f"{' '.join(errors)}. "
                     "Please check the configuration file or set force_recreate to True.",
                 )
+        logger.debug("Schema check passed successfully")

     def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame:
         """Creates the timed dataframe, optionally filtering by electron events.

@@ -185,35 +186,31 @@ def _create_timed_dataframe(self, df: dd.DataFrame) -> dd.DataFrame:
         return df_timed.loc[:, :, 0]

     def _save_buffer_file(self, paths: dict[str, Path]) -> None:
-        """
-        Creates the electron and timed buffer files from the raw H5 file.
-        First the dataframe is accessed and forward filled in the non-electron channels.
-        Then the data types are set. For the electron dataframe, all values not in the electron
-        channels are dropped. For the timed dataframe, only the train and pulse channels are taken
-        and it pulse resolved (no longer electron resolved). Both are saved as parquet files.
-
-        Args:
-            paths (dict[str, Path]): Dictionary containing the paths to the H5 and buffer files.
-        """
-        # Create a DataFrameCreator instance and get the h5 file
+        """Creates the electron and timed buffer files from the raw H5 file."""
+        logger.debug(f"Processing file: {paths['raw'].stem}")
+        start_time = time.time()
+        # Create a DataFrameCreator instance and get the dataframe
         df = DataFrameCreator(config_dataframe=self._config, h5_path=paths["raw"]).df

-        # forward fill all the non-electron channels
+        # Forward fill non-electron channels
+        logger.debug(f"Forward filling {len(self.fill_channels)} channels")
         df[self.fill_channels] = df[self.fill_channels].ffill()

         # Save electron resolved dataframe
         electron_channels = get_channels(self._config, "per_electron")
         dtypes = get_dtypes(self._config, df.columns.values)
-        df.dropna(subset=electron_channels).astype(dtypes).reset_index().to_parquet(
-            paths["electron"],
-        )
+        electron_df = df.dropna(subset=electron_channels).astype(dtypes).reset_index()
+        logger.debug(f"Saving electron buffer with shape: {electron_df.shape}")
+        electron_df.to_parquet(paths["electron"])

         # Create and save timed dataframe
         df_timed = self._create_timed_dataframe(df)
         dtypes = get_dtypes(self._config, df_timed.columns.values)
-        df_timed.astype(dtypes).reset_index().to_parquet(paths["timed"])
+        timed_df = df_timed.astype(dtypes).reset_index()
+        logger.debug(f"Saving timed buffer with shape: {timed_df.shape}")
+        timed_df.to_parquet(paths["timed"])

-        logger.debug(f"Processed {paths['raw'].stem}")
+        logger.debug(f"Processed {paths['raw'].stem} in {time.time() - start_time:.2f}s")

     def _save_buffer_files(self, force_recreate: bool, debug: bool) -> None:
         """

src/sed/loader/flash/dataframe.py

Lines changed: 15 additions & 5 deletions
@@ -14,6 +14,9 @@

 from sed.loader.flash.utils import get_channels
 from sed.loader.flash.utils import InvalidFileError
+from sed.core.logging import setup_logging
+
+logger = setup_logging("flash_dataframe_creator")


 class DataFrameCreator:

@@ -34,6 +37,7 @@ def __init__(self, config_dataframe: dict, h5_path: Path) -> None:
             config_dataframe (dict): The configuration dictionary with only the dataframe key.
             h5_path (Path): Path to the h5 file.
         """
+        logger.debug(f"Initializing DataFrameCreator for file: {h5_path}")
         self.h5_file = h5py.File(h5_path, "r")
         self.multi_index = get_channels(index=True)
         self._config = config_dataframe

@@ -76,6 +80,7 @@ def get_dataset_array(
             tuple[pd.Index, np.ndarray | h5py.Dataset]: A tuple containing the train ID
             pd.Index and the channel's data.
         """
+        logger.debug(f"Getting dataset array for channel: {channel}")
         # Get the data from the necessary h5 file and channel
         index_key, dataset_key = self.get_index_dataset_key(channel)

@@ -85,6 +90,7 @@ def get_dataset_array(
         if slice_:
             slice_index = self._config["channels"][channel].get("slice", None)
             if slice_index is not None:
+                logger.debug(f"Slicing dataset with index: {slice_index}")
                 dataset = np.take(dataset, slice_index, axis=1)
         # If the array is size zero, fill it with NaN values
         # of the same shape as the index

@@ -291,10 +297,14 @@ def df(self) -> pd.DataFrame:
         Returns:
             pd.DataFrame: The combined pandas DataFrame.
         """
-
+        logger.debug("Creating combined DataFrame")
         self.validate_channel_keys()
-        # been tested with merge, join and concat
-        # concat offers best performance, almost 3 times faster
+
         df = pd.concat((self.df_electron, self.df_pulse, self.df_train), axis=1).sort_index()
-        # all the negative pulse values are dropped as they are invalid
-        return df[df.index.get_level_values("pulseId") >= 0]
+        logger.debug(f"Created DataFrame with shape: {df.shape}")
+
+        # Filter negative pulse values
+        df = df[df.index.get_level_values("pulseId") >= 0]
+        logger.debug(f"Filtered DataFrame shape: {df.shape}")
+
+        return df
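
The df property concatenates the electron, pulse, and train frames on a shared MultiIndex and then drops rows whose pulseId index level is negative. A self-contained sketch of that filter on toy data (pulseId and trainId appear in the diff; the electronId level name and all values are assumed for illustration):

import pandas as pd

# Toy MultiIndex mimicking (trainId, pulseId, electronId); pulseId == -1 marks an invalid event
index = pd.MultiIndex.from_tuples(
    [(100, -1, 0), (100, 0, 0), (100, 1, 0)],
    names=["trainId", "pulseId", "electronId"],
)
df = pd.DataFrame({"dldPosX": [0.1, 0.2, 0.3]}, index=index)

# Keep only rows with a non-negative pulseId, as in DataFrameCreator.df
df = df[df.index.get_level_values("pulseId") >= 0]
print(df.shape)  # (2, 1): the pulseId == -1 row is dropped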

src/sed/loader/flash/loader.py

Lines changed: 9 additions & 6 deletions
@@ -207,14 +207,14 @@ def get_files_from_run_id(  # type: ignore[override]
         # Return the list of found files
         return [str(file.resolve()) for file in files]

-    def parse_metadata(self, scicat_token: str = None) -> dict:
+    def parse_metadata(self, token: str = None) -> dict:
         """Uses the MetadataRetriever class to fetch metadata from scicat for each run.

         Returns:
             dict: Metadata dictionary
-            scicat_token (str, optional): The scicat token to use for fetching metadata
+            token (str, optional): The scicat token to use for fetching metadata
         """
-        metadata_retriever = MetadataRetriever(self._config["metadata"], scicat_token)
+        metadata_retriever = MetadataRetriever(self._config["metadata"], token)
         metadata = metadata_retriever.get_metadata(
             beamtime_id=self._config["core"]["beamtime_id"],
             runs=self.runs,

@@ -329,7 +329,9 @@ def read_dataframe(
             debug (bool, optional): Whether to run buffer creation in serial. Defaults to False.
             remove_invalid_files (bool, optional): Whether to exclude invalid files.
                 Defaults to False.
-            scicat_token (str, optional): The scicat token to use for fetching metadata.
+            token (str, optional): The scicat token to use for fetching metadata. If provided,
+                it will be saved to the .env file for future use. If not provided, environment
+                variables will be checked when collect_metadata is True.
             filter_timed_by_electron (bool, optional): When True, the timed dataframe will only
                 contain data points where valid electron events were detected. When False, all
                 timed data points are included regardless of electron detection. Defaults to True.

@@ -341,13 +343,14 @@ def read_dataframe(
         Raises:
             ValueError: If neither 'runs' nor 'files'/'raw_dir' is provided.
             FileNotFoundError: If the conversion fails for some files or no data is available.
+            ValueError: If collect_metadata is True and no token is available.
         """
         detector = kwds.pop("detector", "")
         force_recreate = kwds.pop("force_recreate", False)
         processed_dir = kwds.pop("processed_dir", None)
         debug = kwds.pop("debug", False)
         remove_invalid_files = kwds.pop("remove_invalid_files", False)
-        scicat_token = kwds.pop("scicat_token", None)
+        token = kwds.pop("token", None)
         filter_timed_by_electron = kwds.pop("filter_timed_by_electron", True)

         if len(kwds) > 0:

@@ -401,7 +404,7 @@ def read_dataframe(
         if self.instrument == "wespe":
             df, df_timed = wespe_convert(df, df_timed)

-        self.metadata.update(self.parse_metadata(scicat_token) if collect_metadata else {})
+        self.metadata.update(self.parse_metadata(token) if collect_metadata else {})
         self.metadata.update(bh.metadata)

         print(f"loading complete in {time.time() - t0: .2f} s")
