Optimize with dask #1981

Merged 6 commits on Jul 1, 2024
8 changes: 6 additions & 2 deletions doctor_visits/delphi_doctor_visits/config.py
@@ -19,13 +19,17 @@ class Config:
# data columns
CLI_COLS = ["Covid_like", "Flu_like", "Mixed"]
FLU1_COL = ["Flu1"]
COUNT_COLS = ["Denominator"] + FLU1_COL + CLI_COLS
COUNT_COLS = CLI_COLS + FLU1_COL + ["Denominator"]
DATE_COL = "ServiceDate"
GEO_COL = "PatCountyFIPS"
AGE_COL = "PatAgeGroup"
HRR_COLS = ["Pat HRR Name", "Pat HRR ID"]
# as of 2020-05-11, input file expected to have 10 columns
# id cols: ServiceDate, PatCountyFIPS, PatAgeGroup, Pat HRR ID/Pat HRR Name
# value cols: Denominator, Covid_like, Flu_like, Flu1, Mixed
ID_COLS = [DATE_COL] + [GEO_COL] + HRR_COLS + [AGE_COL]
FILT_COLS = ID_COLS + COUNT_COLS
# drop HRR columns - unused for now since we assign HRRs by FIPS
FILT_COLS = [DATE_COL] + [GEO_COL] + [AGE_COL] + COUNT_COLS
DTYPES = {
"ServiceDate": str,
"PatCountyFIPS": str,
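For clarity, here is what the reordered constants evaluate to after this change, as a plain-Python sketch that assumes only the values shown in the hunk above:

```python
# Sketch of the Config constants after this change; values copied from the hunk above.
CLI_COLS = ["Covid_like", "Flu_like", "Mixed"]
FLU1_COL = ["Flu1"]
COUNT_COLS = CLI_COLS + FLU1_COL + ["Denominator"]
# -> ["Covid_like", "Flu_like", "Mixed", "Flu1", "Denominator"]

DATE_COL = "ServiceDate"
GEO_COL = "PatCountyFIPS"
AGE_COL = "PatAgeGroup"

# The HRR columns are no longer part of FILT_COLS, since HRRs are assigned by FIPS downstream.
FILT_COLS = [DATE_COL] + [GEO_COL] + [AGE_COL] + COUNT_COLS
# -> ["ServiceDate", "PatCountyFIPS", "PatAgeGroup",
#     "Covid_like", "Flu_like", "Mixed", "Flu1", "Denominator"]
```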
Binary file not shown.
101 changes: 101 additions & 0 deletions doctor_visits/delphi_doctor_visits/process_data.py
@@ -0,0 +1,101 @@
from datetime import datetime
from pathlib import Path

import dask.dataframe as dd
import numpy as np
import pandas as pd

from .config import Config


def write_to_csv(output_df: pd.DataFrame, geo_level: str, se: bool, out_name: str, logger, output_path="."):
"""Write sensor values to csv.

Args:
output_df: dataframe containing sensor rates, standard errors, dates, and geo_id
geo_level: geographic resolution, one of ["county", "state", "msa", "hrr", "nation", "hhs"]
se: boolean to write out standard errors, if true, use an obfuscated name
out_name: name of the output file
logger: the structured logger
output_path: outfile path to write the csv (default is current directory)
"""
if se:
logger.info(f"========= WARNING: WRITING SEs TO {out_name} =========")

out_n = 0
for d in set(output_df["date"]):
filename = "%s/%s_%s_%s.csv" % (output_path,
(d + Config.DAY_SHIFT).strftime("%Y%m%d"),
geo_level,
out_name)
single_date_df = output_df[output_df["date"] == d]
with open(filename, "w") as outfile:
outfile.write("geo_id,val,se,direction,sample_size\n")

for line in single_date_df.itertuples():
geo_id = line.geo_id
sensor = 100 * line.val # report percentages
se_val = 100 * line.se
assert not np.isnan(sensor), "sensor value is nan, check pipeline"
assert sensor < 90, f"strangely high percentage {geo_id, sensor}"
if not np.isnan(se_val):
assert se_val < 5, f"standard error suspiciously high! investigate {geo_id}"

if se:
assert sensor > 0 and se_val > 0, "p=0, std_err=0 invalid"
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, sensor, se_val, "NA", "NA"))
else:
# for privacy reasons we will not report the standard error
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, sensor, "NA", "NA", "NA"))
out_n += 1
logger.debug(f"wrote {out_n} rows for {geo_level}")


def csv_to_df(filepath: str, startdate: datetime, enddate: datetime, dropdate: datetime, logger) -> pd.DataFrame:
'''
Read the aggregated doctor-visits csv with Dask, drop currently unused columns,
filter to the relevant date range, and return the result as a pandas dataframe.

Parameters
----------
filepath: path to the aggregated doctor-visits data
startdate: first sensor date (YYYY-mm-dd)
enddate: last sensor date (YYYY-mm-dd)
dropdate: data drop date (YYYY-mm-dd)
logger: the structured logger

Returns
-------
pd.DataFrame of filtered claims data, aggregated by service date and FIPS
'''
filepath = Path(filepath)
logger.info(f"Processing {filepath}")

ddata = dd.read_csv(
filepath,
compression="gzip",
dtype=Config.DTYPES,
blocksize=None,
)

ddata = ddata.dropna()
# rename inconsistent column names to match config column names
ddata = ddata.rename(columns=Config.DEVIANT_COLS_MAP)

ddata = ddata[Config.FILT_COLS]
ddata[Config.DATE_COL] = dd.to_datetime(ddata[Config.DATE_COL])

# restrict to training start and end date
startdate = startdate - Config.DAY_SHIFT

assert startdate > Config.FIRST_DATA_DATE, "Start date <= first day of data"
assert startdate < enddate, "Start date >= end date"
assert enddate <= dropdate, "End date > drop date"

date_filter = ((ddata[Config.DATE_COL] >= Config.FIRST_DATA_DATE) & (ddata[Config.DATE_COL] < dropdate))

df = ddata[date_filter].compute()

# aggregate age groups (so data is unique by service date and FIPS)
df = df.groupby([Config.DATE_COL, Config.GEO_COL]).sum(numeric_only=True).reset_index()
assert np.sum(df.duplicated()) == 0, "Duplicates after age group aggregation"
assert (df[Config.COUNT_COLS] >= 0).all().all(), "Counts must be nonnegative"

logger.info(f"Done processing {filepath}")
return df
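The core of the optimization is reading the gzipped claims file with dask.dataframe and only materializing a pandas DataFrame after the column and date filters have been applied. A minimal, self-contained sketch of that pattern follows; the file name, dtypes, and dates are illustrative assumptions rather than the indicator's real configuration:

```python
# Sketch of the read-with-Dask-then-compute pattern used above.
# NOTE: "claims.csv.gz", the dtypes, and the dates are placeholder assumptions.
from datetime import datetime

import dask.dataframe as dd

ddata = dd.read_csv(
    "claims.csv.gz",
    compression="gzip",
    dtype={"ServiceDate": str, "PatCountyFIPS": str},
    blocksize=None,  # gzip is not splittable, so read the file as a single partition
)
ddata = ddata.dropna()
ddata["ServiceDate"] = dd.to_datetime(ddata["ServiceDate"])

# Filters are built lazily; nothing is materialized yet.
date_filter = (ddata["ServiceDate"] >= datetime(2020, 2, 1)) & (
    ddata["ServiceDate"] < datetime(2020, 6, 1)
)

# .compute() triggers the actual work and returns a pandas DataFrame.
df = ddata[date_filter].compute()
print(type(df))  # <class 'pandas.core.frame.DataFrame'>
```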
12 changes: 7 additions & 5 deletions doctor_visits/delphi_doctor_visits/run.py
@@ -14,7 +14,8 @@
from delphi_utils import get_structured_logger

# first party
from .update_sensor import update_sensor, write_to_csv
from .update_sensor import update_sensor
from .process_data import csv_to_df, write_to_csv
from .download_claims_ftp_files import download
from .get_latest_claims_name import get_latest_filename

Expand Down Expand Up @@ -85,6 +86,7 @@ def run_module(params): # pylint: disable=too-many-statements
## geographies
geos = ["state", "msa", "hrr", "county", "hhs", "nation"]

claims_df = csv_to_df(claims_file, startdate_dt, enddate_dt, dropdate_dt, logger)

## print out other vars
logger.info("outpath:\t\t%s", export_dir)
@@ -103,10 +105,10 @@
else:
logger.info("starting %s, no adj", geo)
sensor = update_sensor(
filepath=claims_file,
startdate=startdate,
enddate=enddate,
dropdate=dropdate,
data=claims_df,
startdate=startdate_dt,
enddate=enddate_dt,
dropdate=dropdate_dt,
geo=geo,
parallel=params["indicator"]["parallel"],
weekday=weekday,
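Taken together, the run.py changes mean the claims file is parsed once and the resulting DataFrame is shared across all geographies, rather than being re-read inside update_sensor for each geo. A simplified sketch of the resulting flow (out_name and the export call are assumptions about the surrounding run_module code, which is not fully shown in this diff):

```python
# Simplified sketch of the new control flow in run_module; not the full implementation.
geos = ["state", "msa", "hrr", "county", "hhs", "nation"]

# Parse and filter the raw claims file exactly once.
claims_df = csv_to_df(claims_file, startdate_dt, enddate_dt, dropdate_dt, logger)

for geo in geos:
    sensor = update_sensor(
        data=claims_df,  # shared, already-cleaned DataFrame
        startdate=startdate_dt,
        enddate=enddate_dt,
        dropdate=dropdate_dt,
        geo=geo,
        parallel=params["indicator"]["parallel"],
        weekday=weekday,
        se=se,
        logger=logger,
    )
    # out_name is assumed to be computed earlier in run_module (obfuscated when se=True).
    write_to_csv(sensor, geo, se, out_name, logger, output_path=export_dir)
```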
91 changes: 6 additions & 85 deletions doctor_visits/delphi_doctor_visits/update_sensor.py
@@ -9,9 +9,8 @@
"""

# standard packages
from datetime import timedelta
from datetime import timedelta, datetime
from multiprocessing import Pool, cpu_count
from pathlib import Path

# third party
import numpy as np
@@ -24,57 +23,14 @@
from .sensor import DoctorVisitsSensor


def write_to_csv(output_df: pd.DataFrame, geo_level, se, out_name, logger, output_path="."):
"""Write sensor values to csv.

Args:
output_dict: dictionary containing sensor rates, se, unique dates, and unique geo_id
se: boolean to write out standard errors, if true, use an obfuscated name
out_name: name of the output file
output_path: outfile path to write the csv (default is current directory)
"""
if se:
logger.info(f"========= WARNING: WRITING SEs TO {out_name} =========")

out_n = 0
for d in set(output_df["date"]):
filename = "%s/%s_%s_%s.csv" % (output_path,
(d + Config.DAY_SHIFT).strftime("%Y%m%d"),
geo_level,
out_name)
single_date_df = output_df[output_df["date"] == d]
with open(filename, "w") as outfile:
outfile.write("geo_id,val,se,direction,sample_size\n")

for line in single_date_df.itertuples():
geo_id = line.geo_id
sensor = 100 * line.val # report percentages
se_val = 100 * line.se
assert not np.isnan(sensor), "sensor value is nan, check pipeline"
assert sensor < 90, f"strangely high percentage {geo_id, sensor}"
if not np.isnan(se_val):
assert se_val < 5, f"standard error suspiciously high! investigate {geo_id}"

if se:
assert sensor > 0 and se_val > 0, "p=0, std_err=0 invalid"
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, sensor, se_val, "NA", "NA"))
else:
# for privacy reasons we will not report the standard error
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, sensor, "NA", "NA", "NA"))
out_n += 1
logger.debug(f"wrote {out_n} rows for {geo_level}")


def update_sensor(
filepath, startdate, enddate, dropdate, geo, parallel,
weekday, se, logger
data: pd.DataFrame, startdate: datetime, enddate: datetime, dropdate: datetime, geo: str, parallel: bool,
Contributor: 👍 we should start doing type specification.

Contributor Author: Yeah, but I think that should be a separate ticket; I don't want to add more PR confusion to an already scoped-out feature.

weekday: bool, se: bool, logger
):
"""Generate sensor values.

Args:
filepath: path to the aggregated doctor-visits data
data: dataframe of the cleaned claims file
startdate: first sensor date (YYYY-mm-dd)
enddate: last sensor date (YYYY-mm-dd)
dropdate: data drop date (YYYY-mm-dd)
@@ -84,45 +40,10 @@ def update_sensor(
se: boolean to write out standard errors, if true, use an obfuscated name
logger: the structured logger
"""
# as of 2020-05-11, input file expected to have 10 columns
# id cols: ServiceDate, PatCountyFIPS, PatAgeGroup, Pat HRR ID/Pat HRR Name
# value cols: Denominator, Covid_like, Flu_like, Flu1, Mixed
filename = Path(filepath).name
data = pd.read_csv(
filepath,
dtype=Config.DTYPES,
)
logger.info(f"Starting processing {filename} ")
data.rename(columns=Config.DEVIANT_COLS_MAP, inplace=True)
data = data[Config.FILT_COLS]
data[Config.DATE_COL] = data[Config.DATE_COL].apply(pd.to_datetime)
logger.info(f"finished processing {filename} ")
assert (
np.sum(data.duplicated(subset=Config.ID_COLS)) == 0
), "Duplicated data! Check the input file"

# drop HRR columns - unused for now since we assign HRRs by FIPS
data.drop(columns=Config.HRR_COLS, inplace=True)
data.dropna(inplace=True) # drop rows with any missing entries

# aggregate age groups (so data is unique by service date and FIPS)
data = data.groupby([Config.DATE_COL, Config.GEO_COL]).sum(numeric_only=True).reset_index()
assert np.sum(data.duplicated()) == 0, "Duplicates after age group aggregation"
assert (data[Config.COUNT_COLS] >= 0).all().all(), "Counts must be nonnegative"

## collect dates
# restrict to training start and end date

drange = lambda s, e: np.array([s + timedelta(days=x) for x in range((e - s).days)])
startdate = pd.to_datetime(startdate) - Config.DAY_SHIFT
burnindate = startdate - Config.DAY_SHIFT
enddate = pd.to_datetime(enddate)
dropdate = pd.to_datetime(dropdate)
assert startdate > Config.FIRST_DATA_DATE, "Start date <= first day of data"
assert startdate < enddate, "Start date >= end date"
assert enddate <= dropdate, "End date > drop date"
data = data[(data[Config.DATE_COL] >= Config.FIRST_DATA_DATE) & \
(data[Config.DATE_COL] < dropdate)]
fit_dates = drange(Config.FIRST_DATA_DATE, dropdate)
burnindate = startdate - Config.DAY_SHIFT
burn_in_dates = drange(burnindate, dropdate)
sensor_dates = drange(startdate, enddate)
# The ordering of sensor dates corresponds to the order of burn-in dates
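Since the date bookkeeping above survives the refactor unchanged, here is a small standalone sketch of how drange and the three date ranges relate. The concrete dates and the DAY_SHIFT value are illustrative assumptions, not the indicator's actual config, and the lambda form simply mirrors the original code:

```python
# Illustrative sketch of the drange helper and the nested date ranges in update_sensor.
from datetime import datetime, timedelta

import numpy as np

DAY_SHIFT = timedelta(days=1)           # assumed shift, standing in for Config.DAY_SHIFT
FIRST_DATA_DATE = datetime(2020, 2, 1)  # assumed first day of data

# drange(s, e) yields every day from s up to (but not including) e.
drange = lambda s, e: np.array([s + timedelta(days=x) for x in range((e - s).days)])

startdate = datetime(2020, 2, 10) - DAY_SHIFT
enddate = datetime(2020, 2, 15)
dropdate = datetime(2020, 2, 17)

fit_dates = drange(FIRST_DATA_DATE, dropdate)            # all days used for fitting
burn_in_dates = drange(startdate - DAY_SHIFT, dropdate)  # sensor days plus a burn-in lead-in
sensor_dates = drange(startdate, enddate)                # days actually reported

# Sensor dates are contained in the burn-in dates, which are contained in the fit dates.
assert set(sensor_dates) <= set(burn_in_dates) <= set(fit_dates)
```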
1 change: 1 addition & 0 deletions doctor_visits/setup.py
@@ -11,6 +11,7 @@
"pytest-cov",
"pytest",
"scikit-learn",
"dask",
]

setup(
Binary file not shown.
Binary file not shown.
21 changes: 21 additions & 0 deletions doctor_visits/tests/test_process_data.py
@@ -0,0 +1,21 @@
"""Tests for update_sensor.py."""
from datetime import datetime
import logging
import pandas as pd

from delphi_doctor_visits.process_data import csv_to_df

TEST_LOGGER = logging.getLogger()

class TestProcessData:
def test_csv_to_df(self):
actual = csv_to_df(
filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
startdate=datetime(2020, 2, 4),
enddate=datetime(2020, 2, 5),
dropdate=datetime(2020, 2, 6),
logger=TEST_LOGGER,
)

comparison = pd.read_pickle("./comparison/process_data/main_after_date_SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl")
pd.testing.assert_frame_equal(actual.reset_index(drop=True), comparison)
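The comparison pickle is a stored snapshot of csv_to_df output. If it ever needs to be regenerated (for example after an intentional change to the filtering logic), a one-off script along these lines could produce it; the script is hypothetical and not part of this PR, and the paths simply mirror the test above:

```python
# Hypothetical one-off helper to regenerate the comparison fixture used in
# test_csv_to_df; shown only to document where the fixture would come from.
import logging
from datetime import datetime

from delphi_doctor_visits.process_data import csv_to_df

df = csv_to_df(
    filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
    startdate=datetime(2020, 2, 4),
    enddate=datetime(2020, 2, 5),
    dropdate=datetime(2020, 2, 6),
    logger=logging.getLogger(),
)
df.reset_index(drop=True).to_pickle(
    "./comparison/process_data/main_after_date_SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl"
)
```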
10 changes: 6 additions & 4 deletions doctor_visits/tests/test_update_sensor.py
@@ -1,4 +1,5 @@
"""Tests for update_sensor.py."""
from datetime import datetime
import logging
import pandas as pd

@@ -8,11 +9,12 @@

class TestUpdateSensor:
def test_update_sensor(self):
df = pd.read_pickle("./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.pkl")
actual = update_sensor(
filepath="./test_data/SYNEDI_AGG_OUTPATIENT_07022020_1455CDT.csv.gz",
startdate="2020-02-04",
enddate="2020-02-05",
dropdate="2020-02-06",
data=df,
startdate=datetime(2020, 2, 4),
enddate=datetime(2020, 2, 5),
dropdate=datetime(2020, 2, 6),
geo="state",
parallel=False,
weekday=False,