diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 95fa40b1..bbc64b9c 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -35,7 +35,7 @@ jobs: #---------------------------------------------- - name: Run tests run: | - pytest -v --doctest-modules --cov=src --junitxml=junit.xml -s --ignore=docs + pytest src/ tests/ -v --doctest-modules --cov=src --junitxml=junit.xml -s - name: Upload coverage to Codecov uses: codecov/codecov-action@v4.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 61bde520..0d7f40da 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,11 +1,11 @@ default_language_version: python: python3.12 -exclude: "docs/index.md|MIMIC-IV_Example/README.md|eICU_Example/README.md" +exclude: "docs/index.md|MIMIC-IV_Example/README.md|eICU_Example/README.md|AUMCdb_Example/README.md" repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: # list of supported hooks: https://pre-commit.com/hooks.html - id: trailing-whitespace @@ -22,27 +22,27 @@ repos: # python code formatting - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 24.10.0 hooks: - id: black args: [--line-length, "110"] # python import sorting - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black", "--filter-files", "-o", "wandb"] - repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 + rev: v2.3.1 hooks: - id: autoflake args: [--in-place, --remove-all-unused-imports] # python upgrading syntax to newer version - repo: https://github.com/asottile/pyupgrade - rev: v3.10.1 + rev: v3.19.0 hooks: - id: pyupgrade args: [--py311-plus] @@ -56,7 +56,7 @@ repos: # python check (PEP8), programming errors and code complexity - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.1.1 hooks: - id: flake8 args: @@ -73,7 +73,7 @@ repos: # yaml formatting - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 + rev: v4.0.0-alpha.8 hooks: - id: prettier types: [yaml] @@ -81,13 +81,13 @@ repos: # shell scripts linter - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.9.0.5 + rev: v0.10.0.1 hooks: - id: shellcheck # md formatting - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.18 hooks: - id: mdformat args: ["--number"] @@ -104,7 +104,7 @@ repos: # word spelling linter - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.3.0 hooks: - id: codespell args: @@ -113,17 +113,21 @@ repos: # jupyter notebook cell output clearing - repo: https://github.com/kynan/nbstripout - rev: 0.6.1 + rev: 0.7.1 hooks: - id: nbstripout # jupyter notebook linting - repo: https://github.com/nbQA-dev/nbQA - rev: 1.7.0 + rev: 1.8.7 hooks: - id: nbqa-black args: ["--line-length=110"] - id: nbqa-isort args: ["--profile=black"] - id: nbqa-flake8 - args: ["--extend-ignore=E203,E402,E501,F401,F841", "--exclude=logs/*,data/*"] + args: + [ + "--extend-ignore=E203,E402,E501,F401,F841", + "--exclude=logs/*,data/*", + ] diff --git a/AUMCdb_Example/README.md b/AUMCdb_Example/README.md new file mode 100644 index 00000000..85a378c3 --- /dev/null +++ b/AUMCdb_Example/README.md @@ -0,0 +1,89 @@ +# AUMC Example + +This is an example of how to extract a MEDS dataset from AUMCdb (https://github.com/AmsterdamUMC/AmsterdamUMCdb). All scripts in this README are assumed to be run from this directory or from the directory in which the files in Step 0.5. were downloaded. 
+
+## Step 0: Installation
+
+```bash
+conda create -n MEDS python=3.12
+conda activate MEDS
+pip install "MEDS_transforms[local_parallelism,slurm_parallelism]"
+```
+
+If you want to profile the time and memory costs of your ETL, also install: `pip install hydra-profiler`.
+
+## Step 0.5: Set-up
+
+Set some environment variables and download the necessary files:
+
+```bash
+export AUMC_RAW_DIR=??? # set to the directory in which you want to store the raw data
+export AUMC_PRE_MEDS_DIR=??? # set to the directory in which you want to store the intermediate MEDS data
+export AUMC_MEDS_COHORT_DIR=??? # set to the directory in which you want to store the final MEDS data
+
+export VERSION=0.0.8 # or whatever version you want
+export URL="https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/$VERSION/AUMCdb_Example"
+
+wget $URL/run.sh
+wget $URL/pre_MEDS.py
+wget $URL/local_parallelism_runner.yaml
+wget $URL/slurm_runner.yaml
+mkdir configs
+cd configs
+wget $URL/configs/extract_AUMC.yaml
+wget $URL/configs/pre_MEDS.yaml
+wget $URL/configs/event_configs.yaml
+wget $URL/configs/table_preprocessors.yaml
+cd ..
+chmod +x run.sh
+chmod +x pre_MEDS.py
+```
+
+## Step 1: Download AUMC
+
+Download the AUMCdb dataset by following the instructions at https://github.com/AmsterdamUMC/AmsterdamUMCdb?tab=readme-ov-file. You will need the raw `.csv` files for this example. We will use `$AUMC_RAW_DIR` to denote the root directory in which the resulting _core data files_ are stored.
+
+## Step 2: Run the MEDS ETL
+
+To run the MEDS ETL, run the following command:
+
+```bash
+./run.sh $AUMC_RAW_DIR $AUMC_PRE_MEDS_DIR $AUMC_MEDS_COHORT_DIR
+```
+
+> \[!NOTE\]
+> This step can consume large amounts of memory if it is not parallelized. You can reduce memory usage by lowering the shard size via the `n_subjects_per_shard` parameter (under `stage_configs.split_and_shard_subjects`) in the `extract_AUMC.yaml` file.
+> Check that your environment variables are set correctly.
+
+To use a specific stage runner file (e.g., to set different parallelism options), you can specify it as an
+additional argument:
+
+```bash
+export N_WORKERS=5
+./run.sh $AUMC_RAW_DIR $AUMC_PRE_MEDS_DIR $AUMC_MEDS_COHORT_DIR \
+    stage_runner_fp=slurm_runner.yaml
+```
+
+The `N_WORKERS` environment variable set before the command controls the maximum number of parallel workers.
+
+The `slurm_runner.yaml` file (downloaded above) runs each stage across several workers on separate slurm
+worker nodes using the `submitit` launcher. _**You will need to customize this file to your own slurm system
+so that the partition names are correct before use.**_ The memory and time costs are viable in the current
+configuration, but if your nodes are sufficiently different you may need to adjust those as well.
+
+The `local_parallelism_runner.yaml` file (downloaded above) runs each stage via separate processes on the
+launching machine. No arguments are needed beyond the `N_WORKERS` environment variable, and there is
+nothing to customize in this file.
+
+To profile the time and memory costs of your ETL, add the `do_profile=true` flag at the end.
+
+## Notes
+
+If you use slurm and launch the hydra submitit jobs from an interactive slurm node, you may need to run
+`unset SLURM_CPU_BIND` in your terminal first to avoid errors.
+
+## Future Work
+
+Check with the AUMCdb authors:
+
+- How should we deal with `registeredat` and `updatedat`?
+- We **IGNORE** several flags for the `drugitems` table -- this may be a mistake!
+- When is the administered dose recorded? Is this done after the fact?
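For concreteness, the run options described in Steps 0.5–2 above can be combined into a single invocation. The sketch below is only illustrative: the three directory variables are the placeholders set in Step 0.5, `N_WORKERS=4` is an arbitrary choice, and the extra flags are the ones documented above.

```bash
# Minimal local run, assuming the Step 0.5 downloads are in the current directory.
export N_WORKERS=4   # maximum number of parallel workers used by the stage runner

./run.sh "$AUMC_RAW_DIR" "$AUMC_PRE_MEDS_DIR" "$AUMC_MEDS_COHORT_DIR" \
    stage_runner_fp=local_parallelism_runner.yaml \
    do_profile=true   # optional; requires `hydra-profiler` from Step 0
```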
diff --git a/AUMCdb_Example/__init__.py b/AUMCdb_Example/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/AUMCdb_Example/configs/event_configs.yaml b/AUMCdb_Example/configs/event_configs.yaml new file mode 100644 index 00000000..7233e20c --- /dev/null +++ b/AUMCdb_Example/configs/event_configs.yaml @@ -0,0 +1,123 @@ +subject_id_col: patientid + +patient: + dob: + code: "MEDS_BIRTH" + time: col(dateofbirth) + gender: + code: ["GENDER", "col(gender)"] + time: null + dod: + code: "MEDS_DEATH" + time: col(dateofdeath) + +admissions: + icu_admission: + code: + - "ICU_ADMISSION" + - col(location) + - col(urgency) + - col(origin) + - col(specialty) + time: col(admittedattime) + icu_discharge: + code: + - "ICU_DISCHARGE" + - col(destination) + time: col(dischargedattime) + weight: + code: + - "WEIGHT_AT_ADMISSION" + - col(weightsource) + - col(weightgroup) + time: col(admittedattime) + height: + code: + - "HEIGHT_AT_ADMISSION" + - col(heightsource) + - col(heightgroup) + time: col(admittedattime) + +numericitems: + event: + code: + - MEASURE + - col(item) + - col(unit) + time: col(measuredattime) + numeric_value: value + +listitems: + event: + code: + - MEASURE + - col(item) + - col(islabresult) + - col(value) + time: col(measuredattime) + +freetextitems: + event: + code: + - MEASURE + - col(item) + - col(islabresult) + time: col(measuredattime) + text_value: value + +procedureorderitems: + event: + code: + - PROCEDURE + - col(ordercategoryname) + - col(item) + time: col(registeredattime) + +processitems: + start: + code: + - PROCESS + - START + - col(item) + time: col(starttime) + end: + code: + - PROCESS + - END + - col(item) + time: col(stoptime) + +drugitems: + start: + code: + - DRUG + - START + - col(ordercategory) + - col(item) + - col(action) + time: col(starttime) + rate: + code: + - DRUG + - RATE + - col(ordercategory) + - col(item) + - col(rateunit) + time: col(starttime) + numeric_value: col(rate) + dose: + code: + - DRUG + - DOSE + - col(ordercategory) + - col(item) + - col(doseunit) + time: col(starttime) + numeric_value: col(dose) + end: + code: + - DRUG + - END + - col(ordercategory) + - col(item) + time: col(stoptime) diff --git a/AUMCdb_Example/configs/extract_AUMC.yaml b/AUMCdb_Example/configs/extract_AUMC.yaml new file mode 100644 index 00000000..cf364327 --- /dev/null +++ b/AUMCdb_Example/configs/extract_AUMC.yaml @@ -0,0 +1,35 @@ +defaults: + - _extract + - _self_ + +description: |- + This pipeline extracts the AUMCdb dataset in longitudinal, sparse form from an input dataset meeting + select criteria and converts them to the flattened, MEDS format. You can control the key arguments to this + pipeline by setting environment variables: + ```bash + export EVENT_CONVERSION_CONFIG_FP=# Path to your event conversion config + export AUMC_PRE_MEDS_DIR=# Path to the output dir of the pre-MEDS step + export AUMC_MEDS_COHORT_DIR=# Path to where you want the dataset to live + ``` + +# The event conversion configuration file is used throughout the pipeline to define the events to extract. 
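+# `run.sh` exports `EVENT_CONVERSION_CONFIG_FP` to point at `configs/event_configs.yaml` before invoking the runner.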
+event_conversion_config_fp: ${oc.env:EVENT_CONVERSION_CONFIG_FP} + +input_dir: ${oc.env:AUMC_PRE_MEDS_DIR} +cohort_dir: ${oc.env:AUMC_MEDS_COHORT_DIR} + +etl_metadata: + dataset_name: AUMCdb + dataset_version: 1.0.2 + +stage_configs: + split_and_shard_subjects: + n_subjects_per_shard: 1000 + +stages: + - shard_events + - split_and_shard_subjects + - convert_to_sharded_events + - merge_to_MEDS_cohort + - finalize_MEDS_metadata + - finalize_MEDS_data diff --git a/AUMCdb_Example/configs/pre_MEDS.yaml b/AUMCdb_Example/configs/pre_MEDS.yaml new file mode 100644 index 00000000..ac311776 --- /dev/null +++ b/AUMCdb_Example/configs/pre_MEDS.yaml @@ -0,0 +1,13 @@ +input_dir: ${oc.env:AUMC_RAW_DIR} +cohort_dir: ${oc.env:AUMC_PRE_MEDS_DIR} + +log_dir: ${cohort_dir}/.logs + +# Hydra +hydra: + job: + name: pre_MEDS_${now:%Y-%m-%d_%H-%M-%S} + run: + dir: ${log_dir} + sweep: + dir: ${log_dir} diff --git a/AUMCdb_Example/configs/table_preprocessors.yaml b/AUMCdb_Example/configs/table_preprocessors.yaml new file mode 100644 index 00000000..6c253bce --- /dev/null +++ b/AUMCdb_Example/configs/table_preprocessors.yaml @@ -0,0 +1,116 @@ +admissions: + offset_col: + - "admittedat" + - "dischargedat" + pseudotime_col: + - "admittedattime" + - "dischargedattime" + output_data_cols: + - "location" + - "urgency" + - "origin" + - "destination" + - "weightgroup" + - "weightsource" + - "heightgroup" + - "heightsource" + - "specialty" + +numericitems: + offset_col: + - "measuredat" + - "registeredat" + - "updatedat" + pseudotime_col: + - "measuredattime" + - "registeredattime" + - "updatedattime" + output_data_cols: + - "item" + - "value" + - "unit" + - "registeredby" + - "updatedby" + warning_items: + - "How should we deal with `registeredat` and `updatedat`?" + +listitems: + offset_col: + - "measuredat" + - "registeredat" + - "updatedat" + pseudotime_col: + - "measuredattime" + - "registeredattime" + - "updatedattime" + output_data_cols: + - "item" + - "value" + - "islabresult" + - "registeredby" + - "updatedby" + warning_items: + - "How should we deal with `registeredat` and `updatedat`?" + +freetextitems: + offset_col: + - "measuredat" + - "registeredat" + - "updatedat" + pseudotime_col: + - "measuredattime" + - "registeredattime" + - "updatedattime" + output_data_cols: + - "item" + - "value" + - "comment" + - "islabresult" + - "registeredby" + - "updatedby" + warning_items: + - "How should we deal with `registeredat` and `updatedat`?" + +drugitems: + offset_col: + - "start" + - "stop" + pseudotime_col: + - "starttime" + - "stoptime" + output_data_cols: + - "orderid" + - "ordercategory" + - "item" + - "rate" + - "rateunit" + - "ratetimeunitid" + - "dose" + - "doseunit" + - "doserateunit" + - "duration" + - "administered" + - "administeredunit" + - "action" + warning_items: + - "We **IGNORE** several flags here -- this may be a mistake!" + - "When is the administered dose recorded? Is this done after the fact?" 
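+# NOTE: `offset_col` and `pseudotime_col` may also be given as single strings rather than lists
+# (as for `procedureorderitems` below); `pre_MEDS.py` normalizes scalars to one-element lists.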
+ +procedureorderitems: + offset_col: "registeredat" + pseudotime_col: "registeredattime" + output_data_cols: + - "orderid" + - "ordercategoryname" + - "item" + - "registeredby" + +processitems: + offset_col: + - "start" + - "stop" + pseudotime_col: + - "starttime" + - "stoptime" + output_data_cols: + - "item" diff --git a/AUMCdb_Example/local_parallelism_runner.yaml b/AUMCdb_Example/local_parallelism_runner.yaml new file mode 100644 index 00000000..a1d9a6c1 --- /dev/null +++ b/AUMCdb_Example/local_parallelism_runner.yaml @@ -0,0 +1,3 @@ +parallelize: + n_workers: ${oc.env:N_WORKERS} + launcher: "joblib" diff --git a/AUMCdb_Example/pre_MEDS.py b/AUMCdb_Example/pre_MEDS.py new file mode 100755 index 00000000..8341f897 --- /dev/null +++ b/AUMCdb_Example/pre_MEDS.py @@ -0,0 +1,255 @@ +#!/usr/bin/env python + +"""Performs pre-MEDS data wrangling for AUMCdb. + +See the docstring of `main` for more information. +""" + +from collections.abc import Callable +from datetime import datetime +from pathlib import Path + +import hydra +import polars as pl +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from MEDS_transforms.utils import get_shard_prefix, hydra_loguru_init, write_lazyframe + +ADMISSION_ID = "admissionid" +PATIENT_ID = "patientid" + + +def load_raw_aumc_file(fp: Path, **kwargs) -> pl.LazyFrame: + """Load a raw AUMCdb file into a Polars DataFrame. + + Args: + fp: The path to the AUMCdb file. + + Returns: + The Polars DataFrame containing the AUMCdb data. + Example: + >>> load_raw_aumc_file(Path("processitems.csv")).collect() + ┌─────────────┬────────┬──────────────────────┬──────────┬───────────┬──────────┐ + │ admissionid ┆ itemid ┆ item ┆ start ┆ stop ┆ duration │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ i64 ┆ str ┆ i64 ┆ i64 ┆ i64 │ + ╞═════════════╪════════╪══════════════════════╪══════════╪═══════════╪══════════╡ + │ 1 ┆ 1 ┆ "Pulse" ┆ 0 ┆ 100000 ┆ 100000 │ + └─────────────┴────────┴──────────────────────┴──────────┴───────────┴──────────┘ + """ + return pl.scan_csv(fp, infer_schema_length=10000000, encoding="utf8-lossy", **kwargs) + + +def process_patient_and_admissions(df: pl.LazyFrame) -> pl.LazyFrame: + """Takes the admissions table and converts it to a form that includes timestamps. + + As AUMCdb stores only offset times, note here that we add a CONSTANT TIME ACROSS ALL PATIENTS for the true + timestamp of their health system admission. This is acceptable because in AUMCdb ONLY RELATIVE TIME + DIFFERENCES ARE MEANINGFUL, NOT ABSOLUTE TIMES. + + The output of this process is ultimately converted to events via the `patient` key in the + `configs/event_configs.yaml` file. + """ + + origin_pseudotime = pl.datetime( + year=pl.col("admissionyeargroup").str.extract(r"(2003|2010)").cast(pl.Int32), month=1, day=1 + ) + + # TODO: consider using better logic to infer date of birth for patients + # with more than one admission. 
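+    # `agegroup` is stored as a bracketed range; the expression below approximates the age in years as
+    # the rounded-up average of the two-digit lower and upper bounds extracted from that string.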
+ age_in_years = ( + ( + pl.col("agegroup").str.extract("(\\d{2}).?$").cast(pl.Int32) + + pl.col("agegroup").str.extract("^(\\d{2})").cast(pl.Int32) + ) + / 2 + ).ceil() + age_in_days = age_in_years * 365.25 + # We assume that the patient was born at the midpoint of the year as we don't know the actual birthdate + pseudo_date_of_birth = origin_pseudotime - pl.duration(days=(age_in_days - 365.25 / 2)) + pseudo_date_of_death = origin_pseudotime + pl.duration(milliseconds=pl.col("dateofdeath")) + + return df.filter(pl.col("admissioncount") == 1).select( + PATIENT_ID, + pseudo_date_of_birth.alias("dateofbirth"), + "gender", + origin_pseudotime.alias("firstadmittedattime"), + pseudo_date_of_death.alias("dateofdeath"), + ), df.select(PATIENT_ID, ADMISSION_ID) + + +def join_and_get_pseudotime_fntr( + table_name: str, + offset_col: str | list[str], + pseudotime_col: str | list[str], + output_data_cols: list[str] | None = None, + warning_items: list[str] | None = None, +) -> Callable[[pl.LazyFrame, pl.LazyFrame], pl.LazyFrame]: + """Returns a function that joins a dataframe to the `patient` table and adds pseudotimes. + Also raises specified warning strings via the logger for uncertain columns. + All args except `table_name` are taken from the table_preprocessors.yaml. For example, for the + table `numericitems`, we have the following yaml configuration: + numericitems: + offset_col: + - "measuredat" + - "registeredat" + - "updatedat" + pseudotime_col: + - "measuredattime" + - "registeredattime" + - "updatedattime" + output_data_cols: + - "item" + - "value" + - "unit" + - "registeredby" + - "updatedby" + warning_items: + - "How should we deal with `registeredat` and `updatedat`?" + + Args: + table_name: name of the AUMCdb table that should be joined + offset_col: list of all columns that contain time offsets since the patient's first admission + pseudotime_col: list of all timestamp columns derived from `offset_col` and the linked `patient` + table + output_data_cols: list of all data columns included in the output + warning_items: any warnings noted in the table_preprocessors.yaml + + Returns: + Function that expects the raw data stored in the `table_name` table and the joined output of the + `process_patient_and_admissions` function. Both inputs are expected to be `pl.DataFrame`s. + + Examples: + >>> func = join_and_get_pseudotime_fntr("numericitems", ["measuredat", "registeredat", "updatedat"], + ["measuredattime", "registeredattime", "updatedattime"], + ["item", "value", "unit", "registeredby", "updatedby"], + ["How should we deal with `registeredat` and `updatedat`?"])` + >>> df = load_raw_aumc_file(in_fp) + >>> raw_admissions_df = load_raw_aumc_file(Path("admissions.csv")) + >>> patient_df, link_df = process_patient_and_admissions(raw_admissions_df) + >>> processed_df = func(df, patient_df) + >>> type(processed_df) + + """ + + if output_data_cols is None: + output_data_cols = [] + + if isinstance(offset_col, str): + offset_col = [offset_col] + if isinstance(pseudotime_col, str): + pseudotime_col = [pseudotime_col] + + if len(offset_col) != len(pseudotime_col): + raise ValueError( + "There must be the same number of `offset_col`s and `pseudotime_col`s specified. Got " + f"{len(offset_col)} and {len(pseudotime_col)}, respectively." + ) + + def fn(df: pl.LazyFrame, patient_df: pl.LazyFrame) -> pl.LazyFrame: + f"""Takes the {table_name} table and converts it to a form that includes pseudo-timestamps. 
+ + The output of this process is ultimately converted to events via the `{table_name}` key in the + `configs/event_configs.yaml` file. + """ + pseudotimes = [ + (pl.col("firstadmittedattime") + pl.duration(milliseconds=pl.col(offset))).alias(pseudotime) + for pseudotime, offset in zip(pseudotime_col, offset_col) + ] + + if warning_items: + warning_lines = [ + f"NOT SURE ABOUT THE FOLLOWING for {table_name} table. Check with the AUMCdb team:", + *(f" - {item}" for item in warning_items), + ] + logger.warning("\n".join(warning_lines)) + + return df.join(patient_df, on=ADMISSION_ID, how="inner").select( + PATIENT_ID, + ADMISSION_ID, + *pseudotimes, + *output_data_cols, + ) + + return fn + + +@hydra.main(version_base=None, config_path="configs", config_name="pre_MEDS") +def main(cfg: DictConfig): + """Performs pre-MEDS data wrangling for AUMCdb. + + Inputs are the raw AUMCdb files, read from the `input_dir` config parameter. Output files are written + in processed form and as Parquet files to the `cohort_dir` config parameter. Hydra is used to manage + configuration parameters and logging. + """ + + hydra_loguru_init() + + table_preprocessors_config_fp = Path("./configs/table_preprocessors.yaml") + logger.info(f"Loading table preprocessors from {str(table_preprocessors_config_fp.resolve())}...") + preprocessors = OmegaConf.load(table_preprocessors_config_fp) + functions = {} + for table_name, preprocessor_cfg in preprocessors.items(): + logger.info(f" Adding preprocessor for {table_name}:\n{OmegaConf.to_yaml(preprocessor_cfg)}") + functions[table_name] = join_and_get_pseudotime_fntr(table_name=table_name, **preprocessor_cfg) + + raw_cohort_dir = Path(cfg.input_dir) + MEDS_input_dir = Path(cfg.cohort_dir) + + patient_out_fp = MEDS_input_dir / "patient.parquet" + link_out_fp = MEDS_input_dir / "link_patient_to_admission.parquet" + + if patient_out_fp.is_file(): + logger.info(f"Reloading processed patient df from {str(patient_out_fp.resolve())}") + patient_df = pl.read_parquet(patient_out_fp, use_pyarrow=True).lazy() + link_df = pl.read_parquet(link_out_fp, use_pyarrow=True).lazy() + else: + logger.info("Processing patient table first...") + + admissions_fp = raw_cohort_dir / "admissions.csv" + logger.info(f"Loading {str(admissions_fp.resolve())}...") + raw_admissions_df = load_raw_aumc_file(admissions_fp) + + logger.info("Processing patient table...") + patient_df, link_df = process_patient_and_admissions(raw_admissions_df) + write_lazyframe(patient_df, patient_out_fp) + write_lazyframe(link_df, link_out_fp) + + patient_df = patient_df.join(link_df, on=PATIENT_ID) + + all_fps = [fp for fp in raw_cohort_dir.glob("*.csv")] + + unused_tables = {} + + for in_fp in all_fps: + pfx = get_shard_prefix(raw_cohort_dir, in_fp) + if pfx in unused_tables: + logger.warning(f"Skipping {pfx} as it is not supported in this pipeline.") + continue + elif pfx not in functions: + logger.warning(f"No function needed for {pfx}. For AUMCdb, THIS IS UNEXPECTED") + continue + + out_fp = MEDS_input_dir / f"{pfx}.parquet" + + if out_fp.is_file(): + print(f"Done with {pfx}. Continuing") + continue + + out_fp.parent.mkdir(parents=True, exist_ok=True) + + fn = functions[pfx] + + st = datetime.now() + logger.info(f"Processing {pfx}...") + df = load_raw_aumc_file(in_fp) + processed_df = fn(df, patient_df) + processed_df.sink_parquet(out_fp) + logger.info(f" * Processed and wrote to {str(out_fp.resolve())} in {datetime.now() - st}") + + logger.info(f"Done! 
All dataframes processed and written to {str(MEDS_input_dir.resolve())}") + + +if __name__ == "__main__": + main() diff --git a/AUMCdb_Example/run.sh b/AUMCdb_Example/run.sh new file mode 100755 index 00000000..4097ab7c --- /dev/null +++ b/AUMCdb_Example/run.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes the AUMCdb (AmsterdamUMCdb, Amsterdam University Medical Center database, short version: AUMC) data through several steps," + echo "handling raw data conversion, sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." + echo + echo "Arguments:" + echo " AUMC_RAW_DIR Directory containing raw AUMCdb data files." + echo " AUMC_PREMEDS_DIR Output directory for pre-MEDS data." + echo " AUMC_MEDS_DIR Output directory for processed MEDS data." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +echo "Unsetting SLURM_CPU_BIND in case you're running this on a slurm interactive node with slurm parallelism" +unset SLURM_CPU_BIND + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -lt 3 ]; then + echo "Error: Incorrect number of arguments provided." + display_help +fi + +export AUMC_RAW_DIR=$1 +export AUMC_PRE_MEDS_DIR=$2 +export AUMC_MEDS_COHORT_DIR=$3 +shift 3 + +# TODO: Add wget blocks once testing is validated. + +EVENT_CONVERSION_CONFIG_FP="$(pwd)/configs/event_configs.yaml" +PIPELINE_CONFIG_FP="$(pwd)/configs/extract_AUMC.yaml" +PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" + +# We export these variables separately from their assignment so that any errors during assignment are caught. +export EVENT_CONVERSION_CONFIG_FP +export PIPELINE_CONFIG_FP +export PRE_MEDS_PY_FP + + +echo "Running pre-MEDS conversion." +python "$PRE_MEDS_PY_FP" input_dir="$AUMC_RAW_DIR" cohort_dir="$AUMC_PRE_MEDS_DIR" + +if [ -z "$N_WORKERS" ]; then + echo "Setting N_WORKERS to 1 to avoid issues with the runners." + export N_WORKERS="1" +fi + +echo "Running extraction pipeline." 
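+# The runner executes the stages listed in the pipeline config exported above; extra arguments passed to
+# this script (e.g. `stage_runner_fp=slurm_runner.yaml` or `do_profile=true`) are forwarded via "$@".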
+MEDS_transform-runner "pipeline_config_fp=$PIPELINE_CONFIG_FP" "$@" diff --git a/AUMCdb_Example/slurm_runner.yaml b/AUMCdb_Example/slurm_runner.yaml new file mode 100644 index 00000000..4dbed261 --- /dev/null +++ b/AUMCdb_Example/slurm_runner.yaml @@ -0,0 +1,61 @@ +parallelize: + n_workers: ${oc.env:N_WORKERS} + launcher: "submitit_slurm" + +shard_events: + parallelize: + launcher_params: + timeout_min: 50 + cpus_per_task: 10 + mem_gb: 40 + partition: "short" + +split_and_shard_subjects: + parallelize: + n_workers: 1 + launcher_params: + timeout_min: 10 + cpus_per_task: 10 + mem_gb: 7 + partition: "short" + +convert_to_sharded_events: + parallelize: + launcher_params: + timeout_min: 10 + cpus_per_task: 10 + mem_gb: 25 + partition: "short" + +merge_to_MEDS_cohort: + parallelize: + launcher_params: + timeout_min: 15 + cpus_per_task: 10 + mem_gb: 85 + partition: "short" + +extract_code_metadata: + parallelize: + launcher_params: + timeout_min: 10 + cpus_per_task: 10 + mem_gb: 25 + partition: "short" + +finalize_MEDS_metadata: + parallelize: + n_workers: 1 + launcher_params: + timeout_min: 10 + cpus_per_task: 5 + mem_gb: 10 + partition: "short" + +finalize_MEDS_data: + parallelize: + launcher_params: + timeout_min: 10 + cpus_per_task: 10 + mem_gb: 70 + partition: "short" diff --git a/MIMIC-IV_Example/README.md b/MIMIC-IV_Example/README.md index 465c56a9..cd90e45c 100644 --- a/MIMIC-IV_Example/README.md +++ b/MIMIC-IV_Example/README.md @@ -1,8 +1,7 @@ # MIMIC-IV Example This is an example of how to extract a MEDS dataset from MIMIC-IV. All scripts in this README are assumed to -be run **not** from this directory but from the root directory of this entire repository (e.g., one directory -up from this one). +be run from this directory or from the directory in which the files in Step 0.5. were downloaded. ## Step 0: Installation @@ -21,8 +20,8 @@ If you want to profile the time and memory costs of your ETL, also install: `pip Set some environment variables and download the necessary files: ```bash export MIMICIV_RAW_DIR=??? # set to the directory in which you want to store the raw MIMIC-IV data -export MIMICIV_PRE_MEDS_DIR=??? # set to the directory in which you want to store the raw MIMIC-IV data -export MIMICIV_MEDS_COHORT_DIR=??? # set to the directory in which you want to store the raw MIMIC-IV data +export MIMICIV_PRE_MEDS_DIR=??? # set to the directory in which you want to store the intermediate MEDS MIMIC-IV data +export MIMICIV_MEDS_COHORT_DIR=??? 
# set to the directory in which you want to store the final MEDS MIMIC-IV data
export URL="https://raw.githubusercontent.com/mmcdermott/MEDS_transforms/$VERSION/MIMIC-IV_Example"

@@ -50,18 +49,18 @@ the root directory of where the resulting _core data files_ are stored -- e.g.,

## Step 1.5: Download MIMIC-IV Metadata files

```bash
-cd $MIMIC_RAW_DIR
-export MIMIC_URL=https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map
-wget $MIMIC_URL/d_labitems_to_loinc.csv
-wget $MIMIC_URL/inputevents_to_rxnorm.csv
-wget $MIMIC_URL/lab_itemid_to_loinc.csv
-wget $MIMIC_URL/meas_chartevents_main.csv
-wget $MIMIC_URL/meas_chartevents_value.csv
-wget $MIMIC_URL/numerics-summary.csv
-wget $MIMIC_URL/outputevents_to_loinc.csv
-wget $MIMIC_URL/proc_datetimeevents.csv
-wget $MIMIC_URL/proc_itemid.csv
-wget $MIMIC_URL/waveforms-summary.csv
+cd $MIMICIV_RAW_DIR
+export MIMICIV_URL=https://raw.githubusercontent.com/MIT-LCP/mimic-code/v2.4.0/mimic-iv/concepts/concept_map
+wget $MIMICIV_URL/d_labitems_to_loinc.csv
+wget $MIMICIV_URL/inputevents_to_rxnorm.csv
+wget $MIMICIV_URL/lab_itemid_to_loinc.csv
+wget $MIMICIV_URL/meas_chartevents_main.csv
+wget $MIMICIV_URL/meas_chartevents_value.csv
+wget $MIMICIV_URL/numerics-summary.csv
+wget $MIMICIV_URL/outputevents_to_loinc.csv
+wget $MIMICIV_URL/proc_datetimeevents.csv
+wget $MIMICIV_URL/proc_itemid.csv
+wget $MIMICIV_URL/waveforms-summary.csv
```

## Step 2: Run the MEDS ETL

@@ -69,9 +68,11 @@ wget $MIMIC_URL/waveforms-summary.csv

To run the MEDS ETL, run the following command:

```bash
-./run.sh $MIMICIV_RAW_DIR $MIMICIV_PRE_MEDS_DIR $MIMICIV_MEDS_DIR do_unzip=true
+./run.sh $MIMICIV_RAW_DIR $MIMICIV_PRE_MEDS_DIR $MIMICIV_MEDS_COHORT_DIR do_unzip=true
```
-
+> [!NOTE]
+> This step can consume large amounts of memory if it is not parallelized. You can reduce memory usage by lowering the shard size via the `n_subjects_per_shard` parameter (under `stage_configs.split_and_shard_subjects`) in the `extract_MIMIC.yaml` file.
+> Check that your environment variables are set correctly.

To not unzip the `.csv.gz` files, set `do_unzip=false` instead of `do_unzip=true`.
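If you prefer not to edit `extract_MIMIC.yaml`, the same shard-size setting can in principle be passed as an extra argument, since additional arguments are forwarded to the pipeline runner as Hydra-style overrides (the value of 500 below is purely illustrative):

```bash
./run.sh $MIMICIV_RAW_DIR $MIMICIV_PRE_MEDS_DIR $MIMICIV_MEDS_COHORT_DIR do_unzip=true \
    stage_configs.split_and_shard_subjects.n_subjects_per_shard=500
```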
To use a specific stage runner file (e.g., to set different parallelism options), you can specify it as an diff --git a/MIMIC-IV_Example/configs/extract_MIMIC.yaml b/MIMIC-IV_Example/configs/extract_MIMIC.yaml index eb9b32ee..650d6e56 100644 --- a/MIMIC-IV_Example/configs/extract_MIMIC.yaml +++ b/MIMIC-IV_Example/configs/extract_MIMIC.yaml @@ -25,6 +25,8 @@ etl_metadata: stage_configs: shard_events: infer_schema_length: 999999999 + split_and_shard_subjects: + n_subjects_per_shard: 1000 stages: - shard_events diff --git a/MIMIC-IV_Example/pre_MEDS.py b/MIMIC-IV_Example/pre_MEDS.py index 846c3a9d..d8b39ba3 100755 --- a/MIMIC-IV_Example/pre_MEDS.py +++ b/MIMIC-IV_Example/pre_MEDS.py @@ -229,7 +229,7 @@ def main(cfg: DictConfig): ) exit(0) - all_fps = list(input_dir.glob("**/*.*")) + all_fps = list(input_dir.rglob("*/*.*")) dfs_to_load = {} seen_fps = {} @@ -299,9 +299,9 @@ def main(cfg: DictConfig): logger.info(f"Loading {str(df_to_load_fp.resolve())} for manipulating other dataframes...") if df_to_load_fp.suffix in [".csv.gz"]: - df = read_fn(df_to_load_fp, columns=cols) + df = df_to_load_read_fn(df_to_load_fp, columns=cols) else: - df = read_fn(df_to_load_fp) + df = df_to_load_read_fn(df_to_load_fp) logger.info(f" Loaded in {datetime.now() - st}") for fp in fps: diff --git a/MIMIC-IV_Example/run.sh b/MIMIC-IV_Example/run.sh index 9c06c7e9..e5c0513b 100755 --- a/MIMIC-IV_Example/run.sh +++ b/MIMIC-IV_Example/run.sh @@ -35,6 +35,12 @@ if [ "$#" -lt 3 ]; then display_help fi +# Check that the do_unzip flag is not set as a positional argument +if [[ "$1" == "do_unzip=true" || "$1" == "do_unzip=false" || "$2" == "do_unzip=true" || "$2" == "do_unzip=false" || "$3" == "do_unzip=true" || "$3" == "do_unzip=false" ]]; then + echo "Error: Incorrect number of arguments provided. Check if your environment variables are set correctly." + display_help +fi + export MIMICIV_RAW_DIR=$1 export MIMICIV_PRE_MEDS_DIR=$2 export MIMICIV_MEDS_COHORT_DIR=$3 diff --git a/eICU_Example/README.md b/eICU_Example/README.md index cc820067..fb1fb036 100644 --- a/eICU_Example/README.md +++ b/eICU_Example/README.md @@ -49,35 +49,22 @@ cd .. Download the eICU-CRD dataset (version 2.0) from https://physionet.org/content/eicu-crd/2.0/ following the instructions on that page. You will need the raw `.csv.gz` files for this example. We will use -`$EICU_RAW_DIR` to denote the root directory of where the resulting _core data files_ are stored -- e.g., -there should be a `hosp` and `icu` subdirectory of `$EICU_RAW_DIR`. +`$EICU_RAW_DIR` to denote the root directory of where the resulting _core data files_ are stored -## Step 2: Get the data ready for base MEDS extraction +## Step 2: Run the MEDS extraction ETL -This is a step in a few parts: - -1. Join a few tables by `hadm_id` to get the right timestamps in the right rows for processing. In - particular, we need to join: - - TODO -2. Convert the subject's static data to a more parseable form. This entails: - - Get the subject's DOB in a format that is usable for MEDS, rather than the integral `anchor_year` and - `anchor_offset` fields. - - Merge the subject's `dod` with the `deathtime` from the `admissions` table. - -After these steps, modified files or symlinks to the original files will be written in a new directory which -will be used as the input to the actual MEDS extraction ETL. We'll use `$EICU_PREMEDS_DIR` to denote this -directory. 
- -To run this step, you can use the following script (assumed to be run **not** from this directory but from the -root directory of this repository): +To run the MEDS ETL, run the following command: ```bash -./eICU_Example/pre_MEDS.py raw_cohort_dir=$EICU_RAW_DIR output_dir=$EICU_PREMEDS_DIR +./run.sh $EICU_RAW_DIR $EICU_PRE_MEDS_DIR $EICU_MEDS_COHORT_DIR $N_PARALLEL_WORKERS do_unzip=true ``` -In practice, on a machine with 150 GB of RAM and 10 cores, this step takes less than 5 minutes in total. +To not unzip the `.csv.gz` files, set `do_unzip=false` instead of `do_unzip=true`. + +To use a specific stage runner file (e.g., to set different parallelism options), you can specify it as an +additional argument -## Step 3: Run the MEDS extraction ETL +The `N_PARALLEL_WORKERS` variable controls how many parallel workers should be used at maximum. Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable memory burden (e.g., \< 150GB per worker), you will want a smaller shard size, as well as to turn off the final unique diff --git a/eICU_Example/configs/event_configs.yaml b/eICU_Example/configs/event_configs.yaml index fb7901cf..708863c1 100644 --- a/eICU_Example/configs/event_configs.yaml +++ b/eICU_Example/configs/event_configs.yaml @@ -21,51 +21,51 @@ patient: - col(hospitalregion) - col(hospitalteachingstatus) - col(hospitalnumbedscategory) - time: col(hospitaladmittime) + time: col(hospitaladmittimestamp) hospital_id: "hospitalid" hosp_discharge: code: - "HOSPITAL_DISCHARGE" - col(hospitaldischargestatus) - col(hospitaldischargelocation) - time: col(hospitaldischargetime) + time: col(hospitaldischargetimestamp) unit_admission: code: - "UNIT_ADMISSION" - col(unitadmitsource) - col(unitstaytype) - time: col(unitadmittime) + time: col(unitadmittimestamp) ward_id: "wardid" unit_stay_id: "patientunitstayid" unit_admission_weight: code: - "UNIT_ADMISSION_WEIGHT" - time: col(unitadmittime) + time: col(unitadmittimestamp) numeric_value: "unitadmissionweight" unit_admission_height: code: - "UNIT_ADMISSION_HEIGHT" - time: col(unitadmittime) + time: col(unitadmittimestamp) numeric_value: "unitadmissionheight" unit_discharge: code: - "UNIT_DISCHARGE" - col(unitdischargestatus) - col(unitdischargelocation) - time: col(unitdischargetime) + time: col(unitdischargetimestamp) unit_discharge_weight: code: - "UNIT_DISCHARGE_WEIGHT" - time: col(unitdischargetime) + time: col(unitdischargetimestamp) numeric_value: "unitdischargeweight" -admissiondx: +admissionDx: admission_diagnosis: code: - "ADMISSION_DX" - col(admitdxname) time: col(admitDxEnteredTimestamp) - admission_dx_id: "admitDxID" + admission_dx_id: "admissiondxid" unit_stay_id: "patientunitstayid" allergy: @@ -153,7 +153,7 @@ medication: - "MEDICATION" - "ORDERED" - col(drugname) - time: col(drugordertime) + time: col(drugordertimestamp) medication_id: "medicationid" drug_iv_admixture: "drugivadmixture" dosage: "dosage" @@ -167,14 +167,14 @@ medication: - "MEDICATION" - "STARTED" - col(drugname) - time: col(drugstarttime) + time: col(drugstarttimestamp) medication_id: "medicationid" drug_stopped: code: - "MEDICATION" - "STOPPED" - col(drugname) - time: col(drugstoptime) + time: col(drugstoptimestamp) medication_id: "medicationid" nurseAssessment: diff --git a/eICU_Example/configs/extract_eICU.yaml b/eICU_Example/configs/extract_eICU.yaml new file mode 100644 index 00000000..9abe9153 --- /dev/null +++ b/eICU_Example/configs/extract_eICU.yaml @@ -0,0 +1,39 @@ +defaults: + - _extract + - _self_ + 
+description: |- + This pipeline extracts the eICU dataset in longitudinal, sparse form from an input dataset meeting + select criteria and converts them to the flattened, MEDS format. You can control the key arguments to this + pipeline by setting environment variables: + ```bash + export EVENT_CONVERSION_CONFIG_FP=# Path to your event conversion config + export EICU_PRE_MEDS_DIR=# Path to the output dir of the pre-MEDS step + export EICU_MEDS_COHORT_DIR=# Path to where you want the dataset to live + ``` + +# The event conversion configuration file is used throughout the pipeline to define the events to extract. +event_conversion_config_fp: ${oc.env:EVENT_CONVERSION_CONFIG_FP} + +input_dir: ${oc.env:EICU_PRE_MEDS_DIR} +cohort_dir: ${oc.env:EICU_MEDS_COHORT_DIR} + +etl_metadata: + dataset_name: eICU + dataset_version: 2.0 + +stage_configs: + shard_events: + infer_schema_length: 999999999 + split_and_shard_subjects: + n_subjects_per_shard: 10000 + merge_to_MEDS_cohort: + unique_by: null + +stages: + - shard_events + - split_and_shard_subjects + - convert_to_sharded_events + - merge_to_MEDS_cohort + - finalize_MEDS_metadata + - finalize_MEDS_data diff --git a/eICU_Example/configs/table_preprocessors.yaml b/eICU_Example/configs/table_preprocessors.yaml index a3ad2c30..29049887 100644 --- a/eICU_Example/configs/table_preprocessors.yaml +++ b/eICU_Example/configs/table_preprocessors.yaml @@ -1,8 +1,10 @@ -admissiondx: +admissionDx: offset_col: "admitdxenteredoffset" pseudotime_col: "admitDxEnteredTimestamp" - output_data_cols: ["admitdxname", "admitdxid"] - warning_items: ["How should we use `admitdxtest`?", "How should we use `admitdxpath`?"] + output_data_cols: ["admitdxname", "admissiondxid"] + warning_items: + - "How should we use `admitdxtest`?" + - "How should we use `admitdxpath`?" allergy: offset_col: "allergyenteredoffset" diff --git a/eICU_Example/joint_script.sh b/eICU_Example/joint_script.sh deleted file mode 100755 index 0b3ad6c5..00000000 --- a/eICU_Example/joint_script.sh +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env bash - -# This makes the script fail if any internal script fails -set -e - -# Function to display help message -function display_help() { - echo "Usage: $0 " - echo - echo "This script processes eICU data through several steps, handling raw data conversion," - echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." - echo - echo "Arguments:" - echo " EICU_RAW_DIR Directory containing raw eICU data files." - echo " EICU_PREMEDS_DIR Output directory for pre-MEDS data." - echo " EICU_MEDS_DIR Output directory for processed MEDS data." - echo " N_PARALLEL_WORKERS Number of parallel workers for processing." - echo - echo "Options:" - echo " -h, --help Display this help message and exit." - exit 1 -} - -# Check if the first parameter is '-h' or '--help' -if [[ "$1" == "-h" || "$1" == "--help" ]]; then - display_help -fi - -# Check for mandatory parameters -if [ "$#" -lt 4 ]; then - echo "Error: Incorrect number of arguments provided." - display_help -fi - -EICU_RAW_DIR="$1" -EICU_PREMEDS_DIR="$2" -EICU_MEDS_DIR="$3" -N_PARALLEL_WORKERS="$4" - -shift 4 - -echo "Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable " -echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " -echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " -echo "in the merge stage. 
You can do this by setting the following parameters at the end of the mandatory " -echo "args when running this script:" -echo " * stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000" -echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" - -echo "Running pre-MEDS conversion." -./eICU_Example/pre_MEDS.py raw_cohort_dir="$EICU_RAW_DIR" output_dir="$EICU_PREMEDS_DIR" - -echo "Running shard_events.py with $N_PARALLEL_WORKERS workers in parallel" -./scripts/extraction/shard_events.py \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$EICU_PREMEDS_DIR" \ - cohort_dir="$EICU_MEDS_DIR" \ - event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" - -echo "Splitting subjects in serial" -./scripts/extraction/split_and_shard_subjects.py \ - input_dir="$EICU_PREMEDS_DIR" \ - cohort_dir="$EICU_MEDS_DIR" \ - event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" - -echo "Converting to sharded events with $N_PARALLEL_WORKERS workers in parallel" -./scripts/extraction/convert_to_sharded_events.py \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$EICU_PREMEDS_DIR" \ - cohort_dir="$EICU_MEDS_DIR" \ - event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" - -echo "Merging to a MEDS cohort with $N_PARALLEL_WORKERS workers in parallel" -./scripts/extraction/merge_to_MEDS_cohort.py \ - --multirun \ - worker="range(0,$N_PARALLEL_WORKERS)" \ - hydra/launcher=joblib \ - input_dir="$EICU_PREMEDS_DIR" \ - cohort_dir="$EICU_MEDS_DIR" \ - event_conversion_config_fp=./eICU_Example/configs/event_configs.yaml "$@" diff --git a/eICU_Example/pre_MEDS.py b/eICU_Example/pre_MEDS.py index 5ebe0582..e00fe72a 100755 --- a/eICU_Example/pre_MEDS.py +++ b/eICU_Example/pre_MEDS.py @@ -274,7 +274,7 @@ def main(cfg: DictConfig): hydra_loguru_init() - table_preprocessors_config_fp = Path("./eICU_Example/configs/table_preprocessors.yaml") + table_preprocessors_config_fp = Path("./configs/table_preprocessors.yaml") logger.info(f"Loading table preprocessors from {str(table_preprocessors_config_fp.resolve())}...") preprocessors = OmegaConf.load(table_preprocessors_config_fp) functions = {} diff --git a/eICU_Example/run.sh b/eICU_Example/run.sh new file mode 100644 index 00000000..c236d3db --- /dev/null +++ b/eICU_Example/run.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +# This makes the script fail if any internal script fails +set -e + +# Function to display help message +function display_help() { + echo "Usage: $0 " + echo + echo "This script processes eICU data through several steps, handling raw data conversion," + echo "sharding events, splitting subjects, converting to sharded events, and merging into a MEDS cohort." + echo + echo "Arguments:" + echo " EICU_RAW_DIR Directory containing raw eICU data files." + echo " EICU_PREMEDS_DIR Output directory for pre-MEDS data." + echo " EICU_MEDS_DIR Output directory for processed MEDS data." + echo " (OPTIONAL) do_unzip=true OR do_unzip=false Optional flag to unzip files before processing." + echo + echo "Options:" + echo " -h, --help Display this help message and exit." + exit 1 +} + +# Check if the first parameter is '-h' or '--help' +if [[ "$1" == "-h" || "$1" == "--help" ]]; then + display_help +fi + +# Check for mandatory parameters +if [ "$#" -lt 4 ]; then + echo "Error: Incorrect number of arguments provided." 
+ display_help +fi + +EICU_RAW_DIR="$1" +EICU_PRE_MEDS_DIR="$2" +EICU_MEDS_COHORT_DIR="$3" + +export EICU_PRE_MEDS_DIR="$2" +export EICU_MEDS_COHORT_DIR="$3" + +shift 4 + +echo "Note that eICU has a lot more observations per subject than does MIMIC-IV, so to keep to a reasonable " +echo "memory burden (e.g., < 150GB per worker), you will want a smaller shard size, as well as to turn off " +echo "the final unique check (which should not be necessary given the structure of eICU and is expensive) " +echo "in the merge stage. You can do this by setting the following parameters at the end of the mandatory " +echo "args when running this script:" +echo " * stage_configs.split_and_shard_subjects.n_subjects_per_shard=10000" +echo " * stage_configs.merge_to_MEDS_cohort.unique_by=null" +echo "Additionally, consider reducing N_PARALLEL_WORKERS if > 1" + +# Defaults +_DO_UNZIP_ARG_STR="" + +if [ $# -ge 1 ]; then + case "$1" in + do_unzip=*) + _DO_UNZIP_ARG_STR="$1" + shift 1 + ;; + esac +fi + +DO_UNZIP="false" + +if [ -n "$_DO_UNZIP_ARG_STR" ]; then + case "$_DO_UNZIP_ARG_STR" in + do_unzip=true) + DO_UNZIP="true" + ;; + do_unzip=false) + DO_UNZIP="false" + ;; + *) + echo "Error: Invalid do_unzip value. Use 'do_unzip=true' or 'do_unzip=false'." + exit 1 + ;; + esac + echo "Setting DO_UNZIP=$DO_UNZIP" +fi + +# TODO: Add wget blocks once testing is validated. +EVENT_CONVERSION_CONFIG_FP="$(pwd)/configs/event_configs.yaml" +PIPELINE_CONFIG_FP="$(pwd)/configs/extract_eICU.yaml" +PRE_MEDS_PY_FP="$(pwd)/pre_MEDS.py" + +# We export these variables separately from their assignment so that any errors during assignment are caught. +export EVENT_CONVERSION_CONFIG_FP +export PIPELINE_CONFIG_FP +export PRE_MEDS_PY_FP + + +if [ "$DO_UNZIP" == "true" ]; then + GZ_FILES="${EICU_RAW_DIR}/*.csv.gz" + if compgen -G "$GZ_FILES" > /dev/null; then + echo "Unzipping csv.gz files matching $GZ_FILES." + for file in $GZ_FILES; do gzip -d --force "$file"; done + else + echo "No csz.gz files to unzip at $GZ_FILES." + fi +else + echo "Skipping unzipping." +fi + +echo "Running pre-MEDS conversion." +./pre_MEDS.py raw_cohort_dir="$EICU_RAW_DIR" output_dir="$EICU_PRE_MEDS_DIR" + +if [ -z "$N_WORKERS" ]; then + echo "Setting N_WORKERS to 1 to avoid issues with the runners." + export N_WORKERS="1" +fi + +echo "Running extraction pipeline." 
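+# Note: the settings recommended in the echoes above (n_subjects_per_shard=10000 and unique_by=null) are
+# already the defaults in configs/extract_eICU.yaml, so they only need to be passed via the extra
+# arguments ("$@") if you want different values.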
+MEDS_transform-runner "pipeline_config_fp=$PIPELINE_CONFIG_FP" "$@" diff --git a/src/MEDS_transforms/aggregate_code_metadata.py b/src/MEDS_transforms/aggregate_code_metadata.py index 1ac8cdfd..c235d2c7 100755 --- a/src/MEDS_transforms/aggregate_code_metadata.py +++ b/src/MEDS_transforms/aggregate_code_metadata.py @@ -695,7 +695,7 @@ def run_map_reduce(cfg: DictConfig): logger.info("Starting reduction process") - while not all(is_complete_parquet_file(fp) for fp in all_out_fps): + while not all(is_complete_parquet_file(fp) for fp in all_out_fps): # pragma: no cover logger.info("Waiting to begin reduction for all files to be written...") time.sleep(cfg.polling_time) diff --git a/src/MEDS_transforms/extract/convert_to_sharded_events.py b/src/MEDS_transforms/extract/convert_to_sharded_events.py index 39aea54f..1cbb665e 100755 --- a/src/MEDS_transforms/extract/convert_to_sharded_events.py +++ b/src/MEDS_transforms/extract/convert_to_sharded_events.py @@ -344,6 +344,52 @@ def extract_event( │ 2 ┆ DISCHARGE//Home ┆ 2021-01-05 15:23:45 ┆ AOx4 ┆ Home │ │ 3 ┆ DISCHARGE//SNF ┆ 2021-01-06 16:34:56 ┆ AOx4 ┆ SNF │ └────────────┴─────────────────┴─────────────────────┴───────────────────┴────────────┘ + + If we make a non-key field use the `col(...)` syntax, it will log a warning but parse the field. + >>> valid_discharge_event_cfg = { + ... "code": ["DISCHARGE", "col(discharge_location)"], + ... "time": "col(discharge_time)", + ... "categorical_value": "col(discharge_status)", # Note the raw dtype of this col is str + ... "text_value": "discharge_location", # Note the raw dtype of this col is categorical + ... } + >>> extract_event(complex_raw_data, valid_discharge_event_cfg) + shape: (6, 5) + ┌────────────┬─────────────────┬─────────────────────┬───────────────────┬────────────┐ + │ subject_id ┆ code ┆ time ┆ categorical_value ┆ text_value │ + │ --- ┆ --- ┆ --- ┆ --- ┆ --- │ + │ u8 ┆ str ┆ datetime[μs] ┆ str ┆ str │ + ╞════════════╪═════════════════╪═════════════════════╪═══════════════════╪════════════╡ + │ 1 ┆ DISCHARGE//Home ┆ 2021-01-01 11:23:45 ┆ AOx4 ┆ Home │ + │ 1 ┆ DISCHARGE//SNF ┆ 2021-01-02 12:34:56 ┆ AO ┆ SNF │ + │ 2 ┆ DISCHARGE//Home ┆ 2021-01-03 13:45:56 ┆ AAO ┆ Home │ + │ 2 ┆ DISCHARGE//SNF ┆ 2021-01-04 14:56:45 ┆ AOx3 ┆ SNF │ + │ 2 ┆ DISCHARGE//Home ┆ 2021-01-05 15:23:45 ┆ AOx4 ┆ Home │ + │ 3 ┆ DISCHARGE//SNF ┆ 2021-01-06 16:34:56 ┆ AOx4 ┆ SNF │ + └────────────┴─────────────────┴─────────────────────┴───────────────────┴────────────┘ + + If a `categorical_value` field is of non-string type, it will be converted. + >>> valid_admission_event_cfg = { + ... "code": ["ADMISSION", "col(admission_type)"], + ... "time": "col(admission_time)", + ... "time_format": "%Y-%m-%d %H:%M:%S", + ... "categorical_value": "severity_score", + ... 
} + >>> extract_event(complex_raw_data, valid_admission_event_cfg) + shape: (6, 4) + ┌────────────┬──────────────┬─────────────────────┬───────────────────┐ + │ subject_id ┆ code ┆ time ┆ categorical_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ u8 ┆ str ┆ datetime[μs] ┆ str │ + ╞════════════╪══════════════╪═════════════════════╪═══════════════════╡ + │ 1 ┆ ADMISSION//A ┆ 2021-01-01 00:00:00 ┆ 1.0 │ + │ 1 ┆ ADMISSION//B ┆ 2021-01-02 00:00:00 ┆ 2.0 │ + │ 2 ┆ ADMISSION//C ┆ 2021-01-03 00:00:00 ┆ 3.0 │ + │ 2 ┆ ADMISSION//D ┆ 2021-01-04 00:00:00 ┆ 4.0 │ + │ 2 ┆ ADMISSION//E ┆ 2021-01-05 00:00:00 ┆ 5.0 │ + │ 3 ┆ ADMISSION//F ┆ 2021-01-06 00:00:00 ┆ 6.0 │ + └────────────┴──────────────┴─────────────────────┴───────────────────┘ + + More examples: >>> extract_event(complex_raw_data, valid_death_event_cfg) shape: (3, 3) ┌────────────┬───────┬─────────────────────┐ @@ -395,6 +441,10 @@ def extract_event( Traceback (most recent call last): ... ValueError: Source column 'discharge_time' for event column foobar is not numeric, string, or categorical! Cannot be used as an event col. + >>> extract_event(complex_raw_data, {"code": "col(NOT_PRESENT)", "time": None}) + Traceback (most recent call last): + ... + KeyError: "Source column 'NOT_PRESENT' for event column code not found in DataFrame schema." """ # noqa: E501 event_cfg = copy.deepcopy(event_cfg) event_exprs = {"subject_id": pl.col("subject_id")} @@ -761,7 +811,7 @@ def compute_fn(df: pl.LazyFrame) -> pl.LazyFrame: event_cfgs=copy.deepcopy(event_cfgs), do_dedup_text_and_numeric=cfg.stage_cfg.get("do_dedup_text_and_numeric", False), ) - except Exception as e: + except Exception as e: # pragma: no cover raise ValueError( f"Error converting {str(shard_fp.resolve())} for {sp}/{input_prefix}: {e}" ) from e diff --git a/src/MEDS_transforms/extract/extract_code_metadata.py b/src/MEDS_transforms/extract/extract_code_metadata.py index 3460cfbd..279ce513 100644 --- a/src/MEDS_transforms/extract/extract_code_metadata.py +++ b/src/MEDS_transforms/extract/extract_code_metadata.py @@ -397,13 +397,13 @@ def main(cfg: DictConfig): logger.info("Extracted metadata for all events. Merging.") - if cfg.worker != 0: + if cfg.worker != 0: # pragma: no cover logger.info("Code metadata extraction completed. Exiting") return logger.info("Starting reduction process") - while not all(fp.exists() for fp in all_out_fps): + while not all(fp.exists() for fp in all_out_fps): # pragma: no cover missing_files_str = "\n".join(f" - {str(fp.resolve())}" for fp in all_out_fps if not fp.exists()) logger.info("Waiting to begin reduction for all files to be written...\n" f"{missing_files_str}") time.sleep(cfg.polling_time) diff --git a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py index a0201803..a02932b8 100755 --- a/src/MEDS_transforms/extract/finalize_MEDS_metadata.py +++ b/src/MEDS_transforms/extract/finalize_MEDS_metadata.py @@ -2,6 +2,7 @@ """Utilities for finalizing the metadata files for extracted MEDS datasets.""" import json +from collections import defaultdict from datetime import datetime from pathlib import Path @@ -17,12 +18,9 @@ code_metadata_schema, dataset_metadata_filepath, dataset_metadata_schema, - held_out_split, subject_id_field, subject_split_schema, subject_splits_filepath, - train_split, - tuning_split, ) from omegaconf import DictConfig @@ -147,7 +145,7 @@ def main(cfg: DictConfig): etl_metadata.dataset_version: The version of the dataset being extracted. 
""" - if cfg.worker != 0: + if cfg.worker != 0: # pragma: no cover logger.info("Non-zero worker found in reduce-only stage. Exiting") return @@ -206,12 +204,10 @@ def main(cfg: DictConfig): logger.info("Creating subject splits from {str(shards_map_fp.resolve())}") shards_map = json.loads(shards_map_fp.read_text()) subject_splits = [] - seen_splits = {train_split: 0, tuning_split: 0, held_out_split: 0} + seen_splits = defaultdict(int) for shard, subject_ids in shards_map.items(): split = "/".join(shard.split("/")[:-1]) - if split not in seen_splits: - seen_splits[split] = 0 seen_splits[split] += len(subject_ids) subject_splits.extend([{subject_id_field: pid, "split": split} for pid in subject_ids]) @@ -219,7 +215,7 @@ def main(cfg: DictConfig): for split, cnt in seen_splits.items(): if cnt: logger.info(f"Split {split} has {cnt} subjects") - else: + else: # pragma: no cover logger.warning(f"Split {split} not found in shards map") subject_splits_tbl = pa.Table.from_pylist(subject_splits, schema=subject_split_schema) diff --git a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py index 2b75eb0f..adf36cfc 100755 --- a/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py +++ b/src/MEDS_transforms/extract/merge_to_MEDS_cohort.py @@ -137,7 +137,7 @@ def merge_subdirs_and_sort( ... merge_subdirs_and_sort( ... sp_dir, ... event_subsets=["subdir1", "subdir2"], - ... unique_by=["subject_id", "time", "code"], + ... unique_by=["subject_id", "time", "code", "missing_col_will_not_error"], ... additional_sort_by=["code", "numeric_value"] ... ).select("subject_id", "time", "code").collect() shape: (6, 3) @@ -153,6 +153,42 @@ def merge_subdirs_and_sort( │ 2 ┆ 20 ┆ B │ │ 3 ┆ 8 ┆ E │ └────────────┴──────┴──────┘ + >>> with TemporaryDirectory() as tmpdir: + ... sp_dir = Path(tmpdir) + ... (sp_dir / "subdir1").mkdir() + ... df1.write_parquet(sp_dir / "subdir1" / "file1.parquet") + ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") + ... (sp_dir / "subdir2").mkdir() + ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") + ... # We just display the subject ID, time, and code columns as the numeric value column + ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to + ... # the unique-by constraint. + ... merge_subdirs_and_sort( + ... sp_dir, + ... event_subsets=["subdir1", "subdir2"], + ... unique_by=352.2, # This will error + ... ) + Traceback (most recent call last): + ... + ValueError: Invalid unique_by value: 352.2 + >>> with TemporaryDirectory() as tmpdir: + ... sp_dir = Path(tmpdir) + ... (sp_dir / "subdir1").mkdir() + ... df1.write_parquet(sp_dir / "subdir1" / "file1.parquet") + ... df2.write_parquet(sp_dir / "subdir1" / "file2.parquet") + ... (sp_dir / "subdir2").mkdir() + ... df3.write_parquet(sp_dir / "subdir2" / "df.parquet") + ... # We just display the subject ID, time, and code columns as the numeric value column + ... # is not guaranteed to be deterministic in the output given some rows will be dropped due to + ... # the unique-by constraint. + ... merge_subdirs_and_sort( + ... sp_dir, + ... event_subsets=["subdir1", "subdir2", "subdir3", "this is missing so will error"], + ... unique_by=None, + ... ) + Traceback (most recent call last): + ... + RuntimeError: Number of found subsets (2) does not match number of subsets in event_config (4): ... 
""" files_to_read = [fp for es in event_subsets for fp in (sp_dir / es).glob("*.parquet")] if not files_to_read: diff --git a/src/MEDS_transforms/extract/split_and_shard_subjects.py b/src/MEDS_transforms/extract/split_and_shard_subjects.py index 0dc3342f..cf968e52 100755 --- a/src/MEDS_transforms/extract/split_and_shard_subjects.py +++ b/src/MEDS_transforms/extract/split_and_shard_subjects.py @@ -251,7 +251,7 @@ def main(cfg: DictConfig): if not external_splits_json_fp.exists(): raise FileNotFoundError(f"External splits JSON file not found at {external_splits_json_fp}") - logger.info(f"Reading external splits from {str(cfg.stage_cfg.external_splits_json_fp.resolve())}") + logger.info(f"Reading external splits from {str(external_splits_json_fp.resolve())}") external_splits = json.loads(external_splits_json_fp.read_text()) size_strs = ", ".join(f"{k}: {len(v)}" for k, v in external_splits.items()) diff --git a/src/MEDS_transforms/mapreduce/mapper.py b/src/MEDS_transforms/mapreduce/mapper.py index ade9910a..9ce69c8d 100644 --- a/src/MEDS_transforms/mapreduce/mapper.py +++ b/src/MEDS_transforms/mapreduce/mapper.py @@ -426,12 +426,26 @@ def match_revise_fntr(cfg: DictConfig, stage_cfg: DictConfig, compute_fn: ANY_CO ValueError: Missing needed columns {'missing'} for local matcher 0: [(col("missing")) == (String(CODE//TEMP_2))].all_horizontal() Columns available: 'code', 'initial_idx', 'subject_id', 'time' + + It will throw an error if the match and revise configuration is missing. >>> stage_cfg = DictConfig({"global_code_end": "foo"}) >>> cfg = DictConfig({"stage_cfg": stage_cfg}) >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) Traceback (most recent call last): ... ValueError: Invalid match and revise configuration... + + It does not accept invalid modes. + >>> stage_cfg = DictConfig({ + ... "global_code_end": "foo", + ... "_match_revise_mode": "foobar", + ... "_match_revise": [{"_matcher": {"code": "CODE//TEMP_2"}}] + ... }) + >>> cfg = DictConfig({"stage_cfg": stage_cfg}) + >>> match_revise_fn = match_revise_fntr(cfg, stage_cfg, compute_fn) + Traceback (most recent call last): + ... + ValueError: Invalid match and revise mode: foobar """ try: validate_match_revise(stage_cfg) @@ -639,13 +653,13 @@ def map_over( .collect()[subject_id_field] .to_list() ) - read_fn = read_and_filter_fntr(train_subjects, read_fn) + read_fn = read_and_filter_fntr(pl.col("subject_id").is_in(train_subjects), read_fn) else: raise FileNotFoundError( f"Train split requested, but shard prefixes can't be used and " f"subject split file not found at {str(split_fp.resolve())}." ) - elif includes_only_train: + elif includes_only_train: # pragma: no cover raise ValueError("All splits should be used, but shard iterator is returning only train splits?!?") if is_match_revise(cfg.stage_cfg): diff --git a/src/MEDS_transforms/mapreduce/utils.py b/src/MEDS_transforms/mapreduce/utils.py index 716ddc09..cfd184c6 100644 --- a/src/MEDS_transforms/mapreduce/utils.py +++ b/src/MEDS_transforms/mapreduce/utils.py @@ -134,6 +134,7 @@ def rwlock_wrap( compute_fn: Callable[[DF_T], DF_T], do_overwrite: bool = False, out_fp_checker: Callable[[Path], bool] = default_file_checker, + register_lock_fn: Callable[[Path], tuple[datetime, Path]] = register_lock, # For dependency injection ) -> bool: """Wrap a series of file-in file-out map transformations on a dataframe with caching and locking. 
@@ -161,6 +162,8 @@ def rwlock_wrap( >>> import polars as pl >>> import tempfile >>> directory = tempfile.TemporaryDirectory() + >>> read_fn = pl.read_csv + >>> write_fn = pl.DataFrame.write_csv >>> root = Path(directory.name) >>> # For this example we'll use a simple CSV file, but in practice we *strongly* recommend using >>> # Parquet files for performance reasons. @@ -168,9 +171,8 @@ def rwlock_wrap( >>> out_fp = root / "output.csv" >>> in_df = pl.DataFrame({"a": [1, 3, 3], "b": [2, 4, 5], "c": [3, -1, 6]}) >>> in_df.write_csv(in_fp) - >>> read_fn = pl.read_csv - >>> write_fn = pl.DataFrame.write_csv - >>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4) + >>> def compute_fn(df: pl.DataFrame) -> pl.DataFrame: + ... return df.with_columns(pl.col("c") * 2).filter(pl.col("c") > 4) >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) >>> assert result_computed >>> print(out_fp.read_text()) @@ -178,21 +180,108 @@ def rwlock_wrap( 1,2,6 3,5,12 + >>> in_df_2 = pl.DataFrame({"a": [1], "b": [3], "c": [-1]}) + >>> in_fp_2 = root / "input_2.csv" + >>> in_df_2.write_csv(in_fp_2) + >>> compute_fn = lambda df: df + >>> result_computed = rwlock_wrap(in_fp_2, out_fp, read_fn, write_fn, compute_fn, do_overwrite=True) + >>> assert result_computed + >>> print(out_fp.read_text()) + a,b,c + 1,3,-1 + >>> out_fp.unlink() >>> compute_fn = lambda df: df.with_columns(pl.col("c") * 2).filter(pl.col("d") > 4) >>> rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) Traceback (most recent call last): ... polars.exceptions.ColumnNotFoundError: unable to find column "d"; valid columns: ["a", "b", "c"] + >>> assert not out_fp.is_file() # Out file should not be created when the process crashes + + If we check the locks during computation, one should be present >>> cache_directory = root / f".output.csv_cache" >>> lock_dir = cache_directory / "locks" - >>> assert not list(lock_dir.iterdir()) + >>> assert not list(lock_dir.iterdir()), "Lock dir starts empty" >>> def lock_dir_checker_fn(df: pl.DataFrame) -> pl.DataFrame: ... print(f"Lock dir empty? {not (list(lock_dir.iterdir()))}") ... return df >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, lock_dir_checker_fn) Lock dir empty? False - >>> assert result_computed + >>> result_computed + True + >>> assert not list(lock_dir.iterdir()), "Lock dir should be empty again" + >>> out_fp.unlink() + + If we register a lock before we run, the process won't actually compute + >>> compute_fn = lambda df: df + >>> lock_time, lock_fp = register_lock(cache_directory) + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn, write_fn, compute_fn) + >>> result_computed + False + >>> len(list(lock_dir.iterdir())) # The lock file at lock_fp should still exist + 1 + >>> lock_fp.unlink() + >>> assert not list(lock_dir.iterdir()), "Lock dir should be empty again" + + If two processes collide when writing locks during lock registration before reading, the one that + writes a lock with an earlier timestamp wins and the later one does not read: + >>> def read_fn_and_print(in_fp: Path) -> pl.DataFrame: + ... print("Reading!") + ... return read_fn(in_fp) + >>> def register_lock_with_conflict_fntr(early: bool) -> callable: + ... fake_lock_time = datetime(2021, 1, 1, 0, 0, 0) if early else datetime(5000, 1, 2, 0, 0, 0) + ... def fn(cache_directory: Path) -> tuple[datetime, Path]: + ... lock_fp = cache_directory / "locks" / f"{fake_lock_time.strftime(LOCK_TIME_FMT)}.json" + ... 
lock_fp.write_text(json.dumps({"start": fake_lock_time.strftime(LOCK_TIME_FMT)})) + ... return register_lock(cache_directory) + ... return fn + >>> result_computed = rwlock_wrap( + ... in_fp, out_fp, read_fn_and_print, write_fn, compute_fn, + ... register_lock_fn=register_lock_with_conflict_fntr(early=True) + ... ) + >>> result_computed + False + >>> len(list(lock_dir.iterdir())) # The lock file added during the registration should still exist. + 1 + >>> next(lock_dir.iterdir()).unlink() + >>> result_computed = rwlock_wrap( + ... in_fp, out_fp, read_fn_and_print, write_fn, compute_fn, + ... register_lock_fn=register_lock_with_conflict_fntr(early=False) + ... ) + Reading! + >>> result_computed + True + >>> len(list(lock_dir.iterdir())) # The lock file added during the registration should still exist. + 1 + >>> next(lock_dir.iterdir()).unlink() + >>> out_fp.unlink() + + If two processes collide when writing locks during reading, the one that writes a lock with an earlier + timestamp wins: + >>> def read_fn_with_lock_fntr(early: bool) -> callable: + ... fake_lock_time = datetime(2021, 1, 1, 0, 0, 0) if early else datetime(5000, 1, 2, 0, 0, 0) + ... def fn(in_fp: Path) -> pl.DataFrame: + ... print("Reading!") + ... df = read_fn(in_fp) + ... lock_fp = lock_dir / f"{fake_lock_time.strftime(LOCK_TIME_FMT)}.json" + ... lock_fp.write_text(json.dumps({"start": fake_lock_time.strftime(LOCK_TIME_FMT)})) + ... return df + ... return fn + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn_with_lock_fntr(True), write_fn, compute_fn) + Reading! + >>> result_computed + False + >>> len(list(lock_dir.iterdir())) # The lock file added during the read should still exist. + 1 + >>> next(lock_dir.iterdir()).unlink() + >>> result_computed = rwlock_wrap(in_fp, out_fp, read_fn_with_lock_fntr(False), write_fn, compute_fn) + Reading! + >>> result_computed + True + >>> len(list(lock_dir.iterdir())) # The lock file added during the read should still exist. + 1 + >>> next(lock_dir.iterdir()).unlink() + >>> out_fp.unlink() >>> directory.cleanup() """ @@ -212,7 +301,7 @@ def rwlock_wrap( logger.info(f"{out_fp} is in progress as of {earliest_lock_time}. Returning.") return False - st_time, lock_fp = register_lock(cache_directory) + st_time, lock_fp = register_lock_fn(cache_directory) logger.info(f"Registered lock at {st_time}. Double checking no earlier locks have been registered.") earliest_lock_time = get_earliest_lock(cache_directory) @@ -263,13 +352,23 @@ def shuffle_shards(shards: list[str], cfg: DictConfig) -> list[str]: The shuffled list of shards. Examples: - >>> cfg = DictConfig({"worker": 1}) >>> shards = ["train/0", "train/1", "tuning", "held_out"] - >>> shuffle_shards(shards, cfg) + >>> shuffle_shards(shards, DictConfig({"worker": 1})) ['train/1', 'held_out', 'tuning', 'train/0'] - >>> cfg = DictConfig({"worker": 2}) - >>> shuffle_shards(shards, cfg) + >>> shuffle_shards(shards, DictConfig({"worker": 2})) ['tuning', 'held_out', 'train/1', 'train/0'] + + It can also shuffle the shards without a worker ID, but the order is then based on the time, which + is not consistent across runs. + >>> sorted(shuffle_shards(shards, DictConfig({}))) + ['held_out', 'train/0', 'train/1', 'tuning'] + + If the shards aren't unique, it will error + >>> shards = ["train/0", "train/0", "tuning", "held_out"] + >>> shuffle_shards(shards, DictConfig({"worker": 1})) + Traceback (most recent call last): + ... + ValueError: Hash collision for shard train/0 with add_str 1! 
""" if "worker" in cfg: @@ -279,10 +378,10 @@ def shuffle_shards(shards: list[str], cfg: DictConfig) -> list[str]: shard_keys = [] for shard in shards: - shard_hash = hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest() + shard_hash = int(hashlib.sha256((add_str + shard).encode("utf-8")).hexdigest(), 16) if shard_hash in shard_keys: raise ValueError(f"Hash collision for shard {shard} with add_str {add_str}!") - shard_keys.append(int(shard_hash, 16)) + shard_keys.append(shard_hash) return [shard for _, shard in sorted(zip(shard_keys, shards))] diff --git a/src/MEDS_transforms/reshard_to_split.py b/src/MEDS_transforms/reshard_to_split.py index deccc49f..650aa57e 100644 --- a/src/MEDS_transforms/reshard_to_split.py +++ b/src/MEDS_transforms/reshard_to_split.py @@ -112,7 +112,7 @@ def main(cfg: DictConfig): max_iters = cfg.get("max_iters", 10) iters = 0 - while not valid_json_file(shards_fp) and iters < max_iters: + while not valid_json_file(shards_fp) and iters < max_iters: # pragma: no cover logger.info(f"Waiting to begin until shards map is written. Iteration {iters}/{max_iters}...") time.sleep(cfg.polling_time) iters += 1 diff --git a/src/MEDS_transforms/runner.py b/src/MEDS_transforms/runner.py index e99e014a..ca281223 100755 --- a/src/MEDS_transforms/runner.py +++ b/src/MEDS_transforms/runner.py @@ -33,6 +33,20 @@ def get_script_from_name(stage_name: str) -> str | None: Returns: The script name for the given stage name. + + Examples: + >>> get_script_from_name("shard_events") + 'MEDS_extract-shard_events' + >>> get_script_from_name("fit_vocabulary_indices") + 'MEDS_transform-fit_vocabulary_indices' + >>> get_script_from_name("filter_subjects") + 'MEDS_transform-filter_subjects' + >>> get_script_from_name("reorder_measurements") + 'MEDS_transform-reorder_measurements' + >>> get_script_from_name("nonexistent_stage") + Traceback (most recent call last): + ... + ValueError: Could not find a script for stage nonexistent_stage. """ try: @@ -54,7 +68,43 @@ def get_script_from_name(stage_name: str) -> str | None: def get_parallelization_args( parallelization_cfg: dict | DictConfig | None, default_parallelization_cfg: dict | DictConfig ) -> list[str]: - """Gets the parallelization args.""" + """Extracts the specific parallelization arguments given the default and stage-specific configurations. + + Args: + parallelization_cfg: The stage-specific parallelization configuration. + default_parallelization_cfg: The default parallelization configuration. + + Returns: + A list of command-line arguments for parallelization. + + Examples: + >>> get_parallelization_args({}, {}) + [] + >>> get_parallelization_args(None, {"n_workers": 4}) + [] + >>> get_parallelization_args({"launcher": "joblib"}, {}) + ['--multirun', 'worker="range(0,1)"', 'hydra/launcher=joblib'] + >>> get_parallelization_args({"n_workers": 2, "launcher_params": 'foo'}, {}) + Traceback (most recent call last): + ... + ValueError: If launcher_params is provided, launcher must also be provided. + >>> get_parallelization_args({"n_workers": 2}, {}) + ['--multirun', 'worker="range(0,2)"'] + >>> get_parallelization_args( + ... {"launcher": "slurm"}, + ... {"n_workers": 3, "launcher": "joblib"} + ... ) + ['--multirun', 'worker="range(0,3)"', 'hydra/launcher=slurm'] + >>> get_parallelization_args( + ... {"n_workers": 2, "launcher": "joblib"}, + ... {"n_workers": 5, "launcher_params": {"foo": "bar"}}, + ... ) + ['--multirun', 'worker="range(0,2)"', 'hydra/launcher=joblib', 'hydra.launcher.foo=bar'] + >>> get_parallelization_args( + ... 
{"n_workers": 5, "launcher_params": {"biz": "baz"}, "launcher": "slurm"}, {} + ... ) + ['--multirun', 'worker="range(0,5)"', 'hydra/launcher=slurm', 'hydra.launcher.biz=baz'] + """ if parallelization_cfg is None: return [] @@ -82,11 +132,11 @@ def get_parallelization_args( launcher = None if launcher is None: - return parallelization_args - if "launcher_params" in parallelization_cfg: raise ValueError("If launcher_params is provided, launcher must also be provided.") + return parallelization_args + parallelization_args.append(f"hydra/launcher={launcher}") if "launcher_params" in parallelization_cfg: @@ -102,12 +152,68 @@ def get_parallelization_args( return parallelization_args -def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dict | DictConfig | None = None): +def run_stage( + cfg: DictConfig, + stage_name: str, + default_parallelization_cfg: dict | DictConfig | None = None, + runner_fn: callable = subprocess.run, # For dependency injection +): """Runs a single stage of the pipeline. Args: cfg: The configuration for the entire pipeline. stage_name: The name of the stage to run. + + Raises: + ValueError: If the stage fails to run. + + Examples: + >>> def fake_shell_succeed(cmd, shell, capture_output): + ... print(cmd) + ... return subprocess.CompletedProcess(args=cmd, returncode=0, stdout=b"", stderr=b"") + >>> def fake_shell_fail(cmd, shell, capture_output): + ... print(cmd) + ... return subprocess.CompletedProcess(args=cmd, returncode=1, stdout=b"", stderr=b"") + >>> cfg = OmegaConf.create({ + ... "pipeline_config_fp": "pipeline_config.yaml", + ... "do_profile": False, + ... "_local_pipeline_config": { + ... "stage_configs": { + ... "shard_events": {}, + ... "fit_vocabulary_indices": {"_script": "foobar"}, + ... }, + ... }, + ... "_stage_runners": { + ... "shard_events": {"_script": "not used"}, + ... "fit_vocabulary_indices": {}, + ... "baz": {"script": "baz_script"}, + ... }, + ... }) + >>> run_stage(cfg, "shard_events", runner_fn=fake_shell_succeed) # doctest: +NORMALIZE_WHITESPACE + MEDS_extract-shard_events --config-dir=... --config-name=pipeline_config + 'hydra.searchpath=[pkg://MEDS_transforms.configs]' stage=shard_events + >>> run_stage( + ... cfg, "fit_vocabulary_indices", runner_fn=fake_shell_succeed + ... ) # doctest: +NORMALIZE_WHITESPACE + foobar --config-dir=... --config-name=pipeline_config + 'hydra.searchpath=[pkg://MEDS_transforms.configs]' stage=fit_vocabulary_indices + >>> run_stage(cfg, "baz", runner_fn=fake_shell_succeed) # doctest: +NORMALIZE_WHITESPACE + baz_script --config-dir=... --config-name=pipeline_config + 'hydra.searchpath=[pkg://MEDS_transforms.configs]' stage=baz + >>> cfg.do_profile = True + >>> run_stage(cfg, "baz", runner_fn=fake_shell_succeed) # doctest: +NORMALIZE_WHITESPACE + baz_script --config-dir=... --config-name=pipeline_config + 'hydra.searchpath=[pkg://MEDS_transforms.configs]' stage=baz + ++hydra.callbacks.profiler._target_=hydra_profiler.profiler.ProfilerCallback + >>> cfg._stage_runners.baz.parallelize = {"n_workers": 2} + >>> cfg.do_profile = False + >>> run_stage(cfg, "baz", runner_fn=fake_shell_succeed) # doctest: +NORMALIZE_WHITESPACE + baz_script --config-dir=... --config-name=pipeline_config --multirun + 'hydra.searchpath=[pkg://MEDS_transforms.configs]' stage=baz worker="range(0,2)" + >>> run_stage(cfg, "baz", runner_fn=fake_shell_fail) + Traceback (most recent call last): + ... + ValueError: Stage baz failed via ... 
""" if default_parallelization_cfg is None: @@ -147,7 +253,7 @@ def run_stage(cfg: DictConfig, stage_name: str, default_parallelization_cfg: dic full_cmd = " ".join(command_parts) logger.info(f"Running command: {full_cmd}") - command_out = subprocess.run(full_cmd, shell=True, capture_output=True) + command_out = runner_fn(full_cmd, shell=True, capture_output=True) # https://stackoverflow.com/questions/21953835/run-subprocess-and-print-output-to-logging # https://loguru.readthedocs.io/en/stable/api/logger.html#loguru._logger.Logger.parse @@ -173,13 +279,7 @@ def main(cfg: DictConfig): pipeline. """ - hydra_loguru_init() - pipeline_config_fp = Path(cfg.pipeline_config_fp) - if not pipeline_config_fp.exists(): - raise FileNotFoundError(f"Pipeline configuration file {pipeline_config_fp} does not exist.") - if not pipeline_config_fp.suffix == ".yaml": - raise ValueError(f"Pipeline configuration file {pipeline_config_fp} must have a .yaml extension.") if pipeline_config_fp.stem in RESERVED_CONFIG_NAMES: raise ValueError( f"Pipeline configuration file {pipeline_config_fp} must not have a name in " @@ -191,9 +291,11 @@ def main(cfg: DictConfig): if not stages: raise ValueError("Pipeline configuration must specify at least one stage.") + hydra_loguru_init() + log_dir = Path(cfg.log_dir) - if cfg.get("do_profile", False): + if cfg.get("do_profile", False): # pragma: no cover try: import hydra_profiler # noqa: F401 except ImportError as e: @@ -229,6 +331,37 @@ def main(cfg: DictConfig): def load_yaml_file(path: str | None) -> dict | DictConfig: + """Loads a YAML file as an OmegaConf object. + + Args: + path: The path to the YAML file. + + Returns: + The OmegaConf object representing the YAML file, or None if no path is provided. + + Raises: + FileNotFoundError: If the file does not exist. + + Examples: + >>> load_yaml_file(None) + {} + >>> load_yaml_file("nonexistent_file.yaml") + Traceback (most recent call last): + ... + FileNotFoundError: File nonexistent_file.yaml does not exist. + >>> import tempfile + >>> with tempfile.NamedTemporaryFile(suffix=".yaml") as f: + ... _ = f.write(b"foo: bar") + ... f.flush() + ... load_yaml_file(f.name) + {'foo': 'bar'} + >>> with tempfile.NamedTemporaryFile(suffix=".yaml") as f: + ... cfg = OmegaConf.create({"foo": "bar"}) + ... OmegaConf.save(cfg, f.name) + ... load_yaml_file(f.name) + {'foo': 'bar'} + """ + if not path: return {} @@ -238,7 +371,7 @@ def load_yaml_file(path: str | None) -> dict | DictConfig: try: return OmegaConf.load(path) - except Exception as e: + except Exception as e: # pragma: no cover logger.warning(f"Failed to load {path} as an OmegaConf: {e}. Trying as a plain YAML file.") yaml_text = path.read_text() return yaml.load(yaml_text, Loader=Loader) diff --git a/src/MEDS_transforms/transforms/add_time_derived_measurements.py b/src/MEDS_transforms/transforms/add_time_derived_measurements.py index d1d3d1e1..a785336c 100644 --- a/src/MEDS_transforms/transforms/add_time_derived_measurements.py +++ b/src/MEDS_transforms/transforms/add_time_derived_measurements.py @@ -330,6 +330,26 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: │ 2 ┆ 2023-01-03 12:00:00 ┆ time_of_day//[12,18) │ │ 3 ┆ 2022-01-01 18:00:00 ┆ time_of_day//[18,24) │ └────────────┴─────────────────────┴──────────────────────┘ + >>> time_of_day_fntr(DictConfig({"endpoints": []})) + Traceback (most recent call last): + ... + ValueError: The 'endpoints' key must contain at least one endpoint for time of day categories. 
+ >>> time_of_day_fntr(DictConfig({"endpoints": [6, 12, 36]})) + Traceback (most recent call last): + ... + ValueError: All endpoints must be between 0 and 24 inclusive. Got: [6, 12, 36] + >>> time_of_day_fntr(DictConfig({"endpoints": [6, 1.2]})) + Traceback (most recent call last): + ... + ValueError: All endpoints must be integer, whole-hour boundaries, but got: [6, 1.2] + >>> time_of_day_fntr(DictConfig({"endpoints": [6, 6]})) + Traceback (most recent call last): + ... + ValueError: All endpoints must be unique. Got: [6, 6] + >>> time_of_day_fntr(DictConfig({"endpoints": [6, 12, 10]})) + Traceback (most recent call last): + ... + ValueError: All endpoints must be in sorted order. Got: [6, 12, 10] """ if not cfg.endpoints: raise ValueError("The 'endpoints' key must contain at least one endpoint for time of day categories.") @@ -337,8 +357,10 @@ def time_of_day_fntr(cfg: DictConfig) -> Callable[[pl.DataFrame], pl.DataFrame]: raise ValueError(f"All endpoints must be between 0 and 24 inclusive. Got: {cfg.endpoints}") if not all(isinstance(endpoint, int) for endpoint in cfg.endpoints): raise ValueError(f"All endpoints must be integer, whole-hour boundaries, but got: {cfg.endpoints}") - if len(cfg.endpoints) != len(set(cfg.endpoints)) or cfg.endpoints != sorted(cfg.endpoints): - raise ValueError(f"All endpoints must be unique and in sorted order. Got: {cfg.endpoints}") + if len(cfg.endpoints) != len(set(cfg.endpoints)): + raise ValueError(f"All endpoints must be unique. Got: {cfg.endpoints}") + if cfg.endpoints != sorted(cfg.endpoints): + raise ValueError(f"All endpoints must be in sorted order. Got: {cfg.endpoints}") def fn(df: pl.LazyFrame) -> pl.LazyFrame: hour = pl.col("time").dt.hour() @@ -365,6 +387,27 @@ def tod_code(start: int, end: int) -> str: def add_time_derived_measurements_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.LazyFrame]: + """Adds all requested time-derived measurements to a DataFrame. + + Args: + stage_cfg: The configuration for the time-derived measurements. Recognized time derived functors + include the following keys: + - "age": The configuration for the age function. + - "time_of_day": The configuration for the time of day function. + + Returns: + A function that adds all requested time-derived measurements to a DataFrame. + + Raises: + ValueError: If an unrecognized time-derived measurement is requested. + + Examples: + >>> add_time_derived_measurements_fntr(DictConfig({"buzz": {}})) + Traceback (most recent call last): + ... + ValueError: Unknown time-derived measurement: buzz + """ + compute_fns = [] # We use the raw stages object as the induced `stage_cfg` has extra properties like the input and output # directories. diff --git a/src/MEDS_transforms/transforms/extract_values.py b/src/MEDS_transforms/transforms/extract_values.py index f2335504..033dcc6f 100644 --- a/src/MEDS_transforms/transforms/extract_values.py +++ b/src/MEDS_transforms/transforms/extract_values.py @@ -95,18 +95,18 @@ def extract_values_fntr(stage_cfg: DictConfig) -> Callable[[pl.LazyFrame], pl.La match out_col_n: case str() if out_col_n in MANDATORY_TYPES: expr = expr.cast(MANDATORY_TYPES[out_col_n]) - if out_col_n == subject_id_field: + if out_col_n == subject_id_field: # pragma: no cover logger.warning( f"You should almost CERTAINLY not be extracting {subject_id_field} as a value." 
) - if out_col_n == "time": + if out_col_n == "time": # pragma: no cover logger.warning("Warning: `time` is being extracted post-hoc!") - case str() if out_col_n in DEPRECATED_NAMES: + case str() if out_col_n in DEPRECATED_NAMES: # pragma: no cover logger.warning( f"Deprecated column name: {out_col_n} -> {DEPRECATED_NAMES[out_col_n]}. " "This column name will not be re-typed." ) - case str(): + case str(): # pragma: no cover pass case _: raise ValueError(f"Invalid column name: {out_col_n}") diff --git a/src/MEDS_transforms/transforms/normalization.py b/src/MEDS_transforms/transforms/normalization.py index 5109861a..9e5e9262 100644 --- a/src/MEDS_transforms/transforms/normalization.py +++ b/src/MEDS_transforms/transforms/normalization.py @@ -165,6 +165,20 @@ def normalize( │ 2 ┆ 2022-10-02 00:00:00 ┆ 2 ┆ null │ │ 3 ┆ 2022-10-02 00:00:00 ┆ 5 ┆ null │ └────────────┴─────────────────────┴──────┴───────────────┘ + + Note that while this function is robust to the inclusion of the default row index column name, it + doesn't retain any extra columns after the operation. If you want to retain the row index, you should + file a GitHub issue with this request and we can add it in a future release. + >>> MEDS_df = MEDS_df.with_columns(pl.lit(1).alias("_row_idx"), pl.lit(2).alias("foobar")) + >>> normalize(MEDS_df.head(1).lazy(), code_metadata, ["unit"]).collect() + shape: (1, 4) + ┌────────────┬─────────────────────┬──────┬───────────────┐ + │ subject_id ┆ time ┆ code ┆ numeric_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ u32 ┆ datetime[μs] ┆ u32 ┆ f64 │ + ╞════════════╪═════════════════════╪══════╪═══════════════╡ + │ 1 ┆ 2021-01-01 00:00:00 ┆ 0 ┆ -2.0 │ + └────────────┴─────────────────────┴──────┴───────────────┘ """ if code_modifiers is None: diff --git a/src/MEDS_transforms/transforms/occlude_outliers.py b/src/MEDS_transforms/transforms/occlude_outliers.py index d65ecd59..b670f9e7 100644 --- a/src/MEDS_transforms/transforms/occlude_outliers.py +++ b/src/MEDS_transforms/transforms/occlude_outliers.py @@ -54,6 +54,22 @@ def occlude_outliers_fntr( │ 2 ┆ A ┆ 2 ┆ null ┆ false │ │ 2 ┆ C ┆ 2 ┆ 1.0 ┆ true │ └────────────┴──────┴───────────┴───────────────┴─────────────────────────┘ + + If no standard deviation cutoff is provided, the function should return the input DataFrame unchanged: + >>> stage_cfg = DictConfig({}) + >>> fn = occlude_outliers_fntr(stage_cfg, code_metadata_df, ["modifier1"]) + >>> fn(data).collect() + shape: (4, 4) + ┌────────────┬──────┬───────────┬───────────────┐ + │ subject_id ┆ code ┆ modifier1 ┆ numeric_value │ + │ --- ┆ --- ┆ --- ┆ --- │ + │ i64 ┆ str ┆ i64 ┆ f64 │ + ╞════════════╪══════╪═══════════╪═══════════════╡ + │ 1 ┆ A ┆ 1 ┆ 15.0 │ + │ 1 ┆ B ┆ 1 ┆ 16.0 │ + │ 2 ┆ A ┆ 2 ┆ 3.9 │ + │ 2 ┆ C ┆ 2 ┆ 1.0 │ + └────────────┴──────┴───────────┴───────────────┘ """ stddev_cutoff = stage_cfg.get("stddev_cutoff", None) diff --git a/src/MEDS_transforms/transforms/tensorization.py b/src/MEDS_transforms/transforms/tensorization.py index 5fd73899..1c780663 100644 --- a/src/MEDS_transforms/transforms/tensorization.py +++ b/src/MEDS_transforms/transforms/tensorization.py @@ -75,6 +75,19 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: time_delta_days [[nan 12.] [nan 0.]] + + With the wrong number of time delta columns, it doesn't work: + >>> nrt = convert_to_NRT(df.drop("time_delta_days").lazy()) + Traceback (most recent call last): + ... + ValueError: Expected at least one time delta column, found none + >>> nrt = convert_to_NRT( + ... 
df.with_columns(pl.lit([1, 2]).alias("time_delta_hours")).lazy() + ... ) # doctest: +NORMALIZE_WHITESPACE + Traceback (most recent call last): + ... + ValueError: Expected exactly one time delta column, found columns: + ['time_delta_days', 'time_delta_hours'] """ # There should only be one time delta column, but this ensures we catch it regardless of the unit of time @@ -94,10 +107,6 @@ def convert_to_NRT(df: pl.LazyFrame) -> JointNestedRaggedTensorDict: logger.warning("All columns are empty. Returning an empty tensor dict.") return JointNestedRaggedTensorDict({}) - for k, v in tensors_dict.items(): - if not v: - raise ValueError(f"Column {k} is empty") - return JointNestedRaggedTensorDict(tensors_dict) diff --git a/src/MEDS_transforms/utils.py b/src/MEDS_transforms/utils.py index 871b90a0..ebffb03b 100644 --- a/src/MEDS_transforms/utils.py +++ b/src/MEDS_transforms/utils.py @@ -110,7 +110,21 @@ def get_package_version() -> str: def get_script_docstring(filename: str | None = None) -> str: - """Returns the docstring of the main function of the script from which this function was called.""" + """Returns the docstring of the main function of the calling script or the file specified. + + Args: + filename: The name of the file to get the docstring from. If None, the calling script's docstring is + returned. + + Returns: + str: The docstring of the main function of the specified file, if it exists. + + Examples: + >>> get_script_docstring() + '' + >>> get_script_docstring("reshard_to_split") + 'Re-shard a MEDS cohort to in a manner that subdivides subject splits.' + """ if filename is not None: main_module = importlib.import_module(f"MEDS_transforms.{filename}") @@ -129,7 +143,7 @@ def current_script_name() -> str: main_func = getattr(main_module, "main", None) if main_func and callable(main_func): func_module = main_func.__module__ - if func_module == "__main__": + if func_module == "__main__": # pragma: no cover return Path(sys.argv[0]).stem else: return func_module.split(".")[-1] diff --git a/tests/MEDS_Extract/test_convert_to_sharded_events.py b/tests/MEDS_Extract/test_convert_to_sharded_events.py index 653ce737..6b0de6bd 100644 --- a/tests/MEDS_Extract/test_convert_to_sharded_events.py +++ b/tests/MEDS_Extract/test_convert_to_sharded_events.py @@ -336,6 +336,24 @@ def test_convert_to_sharded_events(): df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": False}, ) + # If we don't provide the event_cfgs.yaml file, the script should error. 
+ single_stage_tester( + script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, + stage_name="convert_to_sharded_events", + stage_kwargs={"do_dedup_text_and_numeric": True}, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + test_name="Stage tester: convert_to_sharded_events ; with dedup", + should_error=True, + ) + single_stage_tester( script=CONVERT_TO_SHARDED_EVENTS_SCRIPT, stage_name="convert_to_sharded_events", diff --git a/tests/MEDS_Extract/test_extract_code_metadata.py b/tests/MEDS_Extract/test_extract_code_metadata.py index 7700426f..d75f9ea8 100644 --- a/tests/MEDS_Extract/test_extract_code_metadata.py +++ b/tests/MEDS_Extract/test_extract_code_metadata.py @@ -4,7 +4,6 @@ scripts. """ - import polars as pl from tests.MEDS_Extract import EXTRACT_CODE_METADATA_SCRIPT @@ -202,3 +201,20 @@ def test_convert_to_sharded_events(): df_check_kwargs={"check_row_order": False, "check_column_order": False, "check_dtypes": True}, assert_no_other_outputs=False, ) + + # The script should error if the event config file is missing. + single_stage_tester( + script=EXTRACT_CODE_METADATA_SCRIPT, + stage_name="extract_code_metadata", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "demo_metadata.csv": DEMO_METADATA_FILE, + "input_metadata.csv": INPUT_METADATA_FILE, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + should_error=True, + ) diff --git a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py index 74688043..6339da79 100644 --- a/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py +++ b/tests/MEDS_Extract/test_merge_to_MEDS_cohort.py @@ -250,7 +250,7 @@ ) -def test_convert_to_sharded_events(): +def test_merge_to_MEDS_cohort(): single_stage_tester( script=MERGE_TO_MEDS_COHORT_SCRIPT, stage_name="merge_to_MEDS_cohort", @@ -266,3 +266,18 @@ def test_convert_to_sharded_events(): want_outputs=WANT_OUTPUTS, df_check_kwargs={"check_column_order": False}, ) + + # Should error without event conversion file + single_stage_tester( + script=MERGE_TO_MEDS_COHORT_SCRIPT, + stage_name="merge_to_MEDS_cohort", + stage_kwargs=None, + config_name="extract", + input_files={ + **INPUT_SHARDS, + "metadata/.shards.json": SHARDS_JSON, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + shards_map_fp="{input_dir}/metadata/.shards.json", + should_error=True, + ) diff --git a/tests/MEDS_Extract/test_shard_events.py b/tests/MEDS_Extract/test_shard_events.py index f19746ec..7e1c696a 100644 --- a/tests/MEDS_Extract/test_shard_events.py +++ b/tests/MEDS_Extract/test_shard_events.py @@ -21,6 +21,10 @@ 68729,03/09/1978,HAZEL,160.3953106166676 """ +EMPTY_SUBJECTS_CSV = """ +MRN,dob,eye_color,height +""" + ADMIT_VITALS_CSV = """ subject_id,admit_date,disch_date,department,vitals_date,HR,temp 239684,"05/11/2010, 17:41:51","05/11/2010, 19:27:19",CARDIAC,"05/11/2010, 18:57:18",112.6,95.5 @@ -112,3 +116,39 @@ def test_shard_events(): }, df_check_kwargs={"check_column_order": False}, ) + + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + 
stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + config_name="extract", + input_files={ + "subjects.csv": SUBJECTS_CSV, + "admit_vitals.csv": ADMIT_VITALS_CSV, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + should_error=True, + test_name="Shard events should error without event conversion config", + ) + + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + config_name="extract", + input_files={"event_cfgs.yaml": EVENT_CFGS_YAML}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + should_error=True, + test_name="Shard events should error when missing all input files", + ) + + single_stage_tester( + script=SHARD_EVENTS_SCRIPT, + stage_name="shard_events", + stage_kwargs={"row_chunksize": 10}, + config_name="extract", + input_files={"subjects.csv": EMPTY_SUBJECTS_CSV, "event_cfgs.yaml": EVENT_CFGS_YAML}, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + should_error=True, + test_name="Shard events should error when an input file is empty", + ) diff --git a/tests/MEDS_Extract/test_split_and_shard_subjects.py b/tests/MEDS_Extract/test_split_and_shard_subjects.py index db74896d..ef80f7b9 100644 --- a/tests/MEDS_Extract/test_split_and_shard_subjects.py +++ b/tests/MEDS_Extract/test_split_and_shard_subjects.py @@ -103,6 +103,12 @@ "held_out/0": [1500733], } +EXTERNAL_SPLITS = { + "train": [239684, 1195293, 68729, 814703], + "tuning": [754281], + "held_out": [1500733], +} + SUBJECT_SPLITS_DF = pl.DataFrame( { "subject_id": [239684, 1195293, 68729, 814703, 754281, 1500733], @@ -131,3 +137,68 @@ def test_split_and_shard(): event_conversion_config_fp="{input_dir}/event_cfgs.yaml", want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, ) + + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + "split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + "external_splits_json_fp": "{input_dir}/external_splits.json", + }, + config_name="extract", + input_files={ + "external_splits.json": EXTERNAL_SPLITS, + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + want_outputs={"metadata/.shards.json": EXPECTED_SPLITS}, + test_name="Split and shard events should work with an external splits file.", + ) + + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + "split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + }, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + should_error=True, + test_name="Split and shard events should error without an event config file.", + ) + + single_stage_tester( + script=SPLIT_AND_SHARD_SCRIPT, + stage_name="split_and_shard_subjects", + stage_kwargs={ + "split_fracs.train": 4 / 6, + "split_fracs.tuning": 1 / 6, + 
"split_fracs.held_out": 1 / 6, + "n_subjects_per_shard": 2, + "external_splits_json_fp": "{input_dir}/external_splits.json", + }, + config_name="extract", + input_files={ + "data/subjects/[0-6).parquet": pl.read_csv(StringIO(SUBJECTS_CSV)), + "data/admit_vitals/[0-10).parquet": pl.read_csv(StringIO(ADMIT_VITALS_0_10_CSV)), + "data/admit_vitals/[10-16).parquet": pl.read_csv(StringIO(ADMIT_VITALS_10_16_CSV)), + "event_cfgs.yaml": EVENT_CFGS_YAML, + }, + event_conversion_config_fp="{input_dir}/event_cfgs.yaml", + should_error=True, + test_name="Split and shard events should error if an external splits file is requested but absent.", + ) diff --git a/tests/MEDS_Transforms/test_aggregate_code_metadata.py b/tests/MEDS_Transforms/test_aggregate_code_metadata.py index a2abce52..90850c3c 100644 --- a/tests/MEDS_Transforms/test_aggregate_code_metadata.py +++ b/tests/MEDS_Transforms/test_aggregate_code_metadata.py @@ -9,6 +9,7 @@ from tests.MEDS_Transforms import AGGREGATE_CODE_METADATA_SCRIPT from tests.MEDS_Transforms.transform_tester_base import ( MEDS_CODE_METADATA_SCHEMA, + MEDS_SHARDS, single_stage_transform_tester, ) @@ -186,3 +187,29 @@ def test_aggregate_code_metadata(): assert_no_other_outputs=False, df_check_kwargs={"check_column_order": False}, ) + + # Test with shards re-mapped so it has to use the splits file. + remapped_shards = {str(i): v for i, v in enumerate(MEDS_SHARDS.values())} + single_stage_transform_tester( + transform_script=AGGREGATE_CODE_METADATA_SCRIPT, + stage_name="aggregate_code_metadata", + transform_stage_kwargs={"aggregations": AGGREGATIONS, "do_summarize_over_all_codes": True}, + want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, + input_code_metadata=MEDS_CODE_METADATA_FILE, + do_use_config_yaml=True, + assert_no_other_outputs=False, + df_check_kwargs={"check_column_order": False}, + input_shards=remapped_shards, + ) + + single_stage_transform_tester( + transform_script=AGGREGATE_CODE_METADATA_SCRIPT, + stage_name="aggregate_code_metadata", + transform_stage_kwargs={"aggregations": AGGREGATIONS, "do_summarize_over_all_codes": True}, + want_metadata=WANT_OUTPUT_CODE_METADATA_FILE, + input_code_metadata=MEDS_CODE_METADATA_FILE, + do_use_config_yaml=True, + input_shards=remapped_shards, + splits_fp=None, + should_error=True, + ) diff --git a/tests/MEDS_Transforms/test_reshard_to_split.py b/tests/MEDS_Transforms/test_reshard_to_split.py index d0094a96..bd6d1f29 100644 --- a/tests/MEDS_Transforms/test_reshard_to_split.py +++ b/tests/MEDS_Transforms/test_reshard_to_split.py @@ -4,7 +4,6 @@ scripts. 
""" - from meds import subject_id_field from tests.MEDS_Transforms import RESHARD_TO_SPLIT_SCRIPT @@ -207,7 +206,7 @@ def test_reshard_to_split(): single_stage_transform_tester( transform_script=RESHARD_TO_SPLIT_SCRIPT, stage_name="reshard_to_split", - transform_stage_kwargs={"n_patients_per_shard": 2, "+train_only": True}, + transform_stage_kwargs={"n_subjects_per_shard": 2, "+train_only": True}, want_data=WANT_SHARDS, input_shards=IN_SHARDS, input_shards_map=IN_SHARDS_MAP, diff --git a/tests/MEDS_Transforms/test_tokenization.py b/tests/MEDS_Transforms/test_tokenization.py index 4a416c83..802ada2d 100644 --- a/tests/MEDS_Transforms/test_tokenization.py +++ b/tests/MEDS_Transforms/test_tokenization.py @@ -296,7 +296,7 @@ def test_tokenization(): single_stage_transform_tester( transform_script=TOKENIZATION_SCRIPT, stage_name="tokenization", - transform_stage_kwargs={"train_only": True}, + transform_stage_kwargs={"++train_only": True}, input_shards=NORMALIZED_SHARDS, want_data={**WANT_SCHEMAS, **WANT_EVENT_SEQS}, should_error=True, diff --git a/tests/MEDS_Transforms/transform_tester_base.py b/tests/MEDS_Transforms/transform_tester_base.py index 7a26c855..e955946c 100644 --- a/tests/MEDS_Transforms/transform_tester_base.py +++ b/tests/MEDS_Transforms/transform_tester_base.py @@ -4,7 +4,6 @@ scripts. """ - from collections import defaultdict from io import StringIO from pathlib import Path @@ -158,6 +157,7 @@ def remap_inputs_for_transform( input_shards: dict[str, pl.DataFrame] | None = None, input_shards_map: dict[str, list[int]] | None = None, input_splits_map: dict[str, list[int]] | None = None, + splits_fp: Path | str | None = subject_splits_filepath, ) -> dict[str, FILE_T]: unified_inputs = {} @@ -192,7 +192,9 @@ def remap_inputs_for_transform( input_splits_df = pl.DataFrame(input_splits_as_df) - unified_inputs[subject_splits_filepath] = input_splits_df + if splits_fp is not None: + # This case is added for error testing; not for general use. 
+ unified_inputs[splits_fp] = input_splits_df return unified_inputs diff --git a/tests/test_with_runner.py b/tests/test_with_runner.py index 7913cb13..84d9ad63 100644 --- a/tests/test_with_runner.py +++ b/tests/test_with_runner.py @@ -17,7 +17,6 @@ The stage configuration arguments will be as given in the yaml block below: """ - from functools import partial from meds import code_metadata_filepath, subject_splits_filepath @@ -87,6 +86,14 @@ {STAGE_RUNNER_YAML} """ +PIPELINE_NO_STAGES_YAML = """ +defaults: + - _preprocess + - _self_ + +input_dir: {{input_dir}} +cohort_dir: {{cohort_dir}} +""" PIPELINE_YAML = f""" defaults: @@ -273,3 +280,41 @@ def test_pipeline(): do_include_dirs=False, df_check_kwargs={"check_column_order": False}, ) + + single_stage_tester( + script=RUNNER_SCRIPT, + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={ + **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, + code_metadata_filepath: MEDS_CODE_METADATA, + subject_splits_filepath: SPLITS_DF, + "_preprocess.yaml": partial(add_params, PIPELINE_YAML), + }, + do_include_dirs=False, + should_error=True, + pipeline_config_fp="{input_dir}/_preprocess.yaml", + test_name="Runner should fail if the pipeline config has an invalid name", + ) + + single_stage_tester( + script=RUNNER_SCRIPT, + config_name="runner", + stage_name=None, + stage_kwargs=None, + do_pass_stage_name=False, + do_use_config_yaml=False, + input_files={ + **{f"data/{k}": v for k, v in MEDS_SHARDS.items()}, + code_metadata_filepath: MEDS_CODE_METADATA, + subject_splits_filepath: SPLITS_DF, + "pipeline.yaml": partial(add_params, PIPELINE_NO_STAGES_YAML), + }, + do_include_dirs=False, + should_error=True, + pipeline_config_fp="{input_dir}/pipeline.yaml", + test_name="Runner should fail if the pipeline has no stages", + ) diff --git a/tests/utils.py b/tests/utils.py index 0e9ae943..4bb9741c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -405,10 +405,16 @@ def single_stage_tester( if df_check_kwargs is None: df_check_kwargs = {} + if stage_kwargs is None: + stage_kwargs = {} + with input_dataset(input_files) as (input_dir, cohort_dir): for k, v in pipeline_kwargs.items(): if type(v) is str and "{input_dir}" in v: pipeline_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) + for k, v in stage_kwargs.items(): + if type(v) is str and "{input_dir}" in v: + stage_kwargs[k] = v.format(input_dir=str(input_dir.resolve())) pipeline_config_kwargs = { "hydra.verbose": hydra_verbose,