Test cases and (eventually) fixes for #114 (#117)

Open
wants to merge 24 commits into base: dev
Changes from 10 commits
Commits
24 commits
edcd2e9
starting off
mmcdermott Jun 22, 2024
785ccc9
Starting to add test contents; not yet working
mmcdermott Jun 22, 2024
fa6844e
Merge branch 'dev' into fix_114_typing_with_subject_IDs
mmcdermott Jun 22, 2024
cf67bcd
Added a more involved test for dataset construction.
mmcdermott Jun 22, 2024
0434ed4
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
98cafd5
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
48ba63d
Merge branch 'fix_114_typing_with_subject_IDs' of github.com:mmcdermo…
mmcdermott Jun 22, 2024
b134133
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
9acff54
Removed the unnecessary pandas conversion in the getting number of ev…
mmcdermott Jun 22, 2024
1b4f0d8
Things are partially improved, but other tests are still failing. Inv…
mmcdermott Jun 22, 2024
39ba674
This may have fixed it.
mmcdermott Jun 22, 2024
ace03dd
Added minor changes from another branch and more debug logs
pargaw Jun 22, 2024
b41b014
Added assertion to check _filter_col_inclusion
pargaw Jun 22, 2024
aa0fa66
Some more corrections to subject ID typing.
mmcdermott Jun 22, 2024
413dda5
Fixed pytorch dataset issue (maybe)
mmcdermott Jun 22, 2024
9d703b7
Tests still having challenges with shards not overlapping.
mmcdermott Jun 22, 2024
7616598
fixed broken pytorch dataset test given string conventions.
mmcdermott Jun 23, 2024
674b50f
Other changes to try to get things working
mmcdermott Jun 23, 2024
9f3ce52
Added some more logging and dropping of nulls for safety
mmcdermott Jun 23, 2024
71ac9a8
Added flat rep changes from another branch
pargaw Jun 23, 2024
a813468
Added additional changes to fix flat_rep run from another branch
pargaw Jun 23, 2024
87a3874
Set assertion error to warning log
pargaw Jun 23, 2024
5150e05
Added functionality when do_update is False
pargaw Jun 30, 2024
532c3dd
Added debug statements for caching flat reps
pargaw Jul 6, 2024
28 changes: 22 additions & 6 deletions EventStream/data/dataset_polars.py
@@ -690,6 +690,7 @@ def _agg_by_time(self):
        )

    def _update_subject_event_properties(self):
+        self.subject_id_dtype = self.events_df.schema["subject_id"]

Ensure events_df is not None before accessing its schema.

The assignment of self.subject_id_dtype should be inside the conditional check to ensure events_df is not None before accessing its schema.

-        self.subject_id_dtype = self.events_df.schema["subject_id"]
-        if self.events_df is not None:
+        if self.events_df is not None:
+            self.subject_id_dtype = self.events_df.schema["subject_id"]
Suggested change
        if self.events_df is not None:
            self.subject_id_dtype = self.events_df.schema["subject_id"]

        if self.events_df is not None:
            logger.debug("Collecting event types")
            self.event_types = (
@@ -699,15 +700,19 @@ def _update_subject_event_properties(self):
                .to_list()
            )

-            n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
-            self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
+            n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
+            # here we cast to str to avoid issues with the subject_id column being various other types as we
+            # will eventually JSON serialize it.
+            n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
+            self.n_events_per_subject = {
+                subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
+            }

Refactor subject event properties calculation for clarity and efficiency.

The calculation of n_events_per_subject can be streamlined by constructing the dictionary directly from the n_events dataframe, which avoids the need to zip and iterate over the Series.

-            n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
-            self.n_events_per_subject = {
-                subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
-            }
+            self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False)

Suggested change
            n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
            # here we cast to str to avoid issues with the subject_id column being various other types as we
            # will eventually JSON serialize it.
            self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False)
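
For reference, a minimal sketch (not part of the PR; the toy frame is hypothetical) contrasting what the two constructions return — polars' to_dict(as_series=False) yields column-wise lists rather than the {subject_id: count} mapping built by the zip-based comprehension:

import polars as pl

# Hypothetical stand-in for self.events_df.
events_df = pl.DataFrame({"subject_id": [1, 1, 2]})

n_events = events_df.group_by("subject_id").agg(pl.len().alias("count"))
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))

# PR version: a per-subject mapping, e.g. {"1": 2, "2": 1}.
n_events_per_subject = dict(zip(n_events["subject_id"], n_events["count"]))

# Suggested version: column-wise lists, e.g. {"subject_id": ["1", "2"], "count": [2, 1]}.
as_columns = n_events.to_dict(as_series=False)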

            self.subject_ids = set(self.n_events_per_subject.keys())

        if self.subjects_df is not None:
            logger.debug("Collecting subject event counts")
-            subjects_with_no_events = (
-                set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
-            )
+            subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8))
+            subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids
            for sid in subjects_with_no_events:
                self.n_events_per_subject[sid] = 0
            self.subject_ids.update(subjects_with_no_events)
@@ -723,7 +728,18 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
                    filter_exprs.append(pl.col(col).is_null())
                case _:
                    try:
-                        incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
+                        logger.debug(
+                            f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
+                            f"{col} to {df.schema[col]}"
+                        )
+                        if isinstance(list(incl_targets)[0], str):
+                            incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
+                        else:
+                            incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
+
+                        incl_list = incl_list.cast(df.schema[col])
+
+                        logger.debug(f"Converted to Series of type {incl_list.dtype}")
                    except TypeError as e:
                        incl_targets_by_type = defaultdict(list)
                        for t in incl_targets:
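
A minimal sketch (not part of the PR; the frame and values are illustrative) of the cast path added to _filter_col_inclusion above — string inclusion targets are first built as a Utf8 Series and then cast to the column's dtype before filtering:

import polars as pl

# Hypothetical frame whose subject_id column is numeric while the inclusion
# targets arrive as strings (e.g. keys round-tripped through JSON).
df = pl.DataFrame({"subject_id": [1, 2, 3], "value": [10, 20, 30]})
incl_targets = ["1", "3"]

# Build as Utf8 first, then cast to the column dtype, mirroring the change above.
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8).cast(df.schema["subject_id"])

# Keep only the rows whose subject_id is in the (now correctly typed) target list.
filtered = df.filter(pl.col("subject_id").is_in(incl_list))
# filtered retains the rows for subject_ids 1 and 3.
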
2 changes: 1 addition & 1 deletion sample_data/dataset.yaml
@@ -10,7 +10,7 @@ subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw/"
save_dir: "./sample_data/processed/${cohort_name}"

-DL_chunk_size: null
+DL_chunk_size: 25

inputs:
  subjects:
69 changes: 69 additions & 0 deletions sample_data/dataset_parquet.yaml
@@ -0,0 +1,69 @@
defaults:
  - dataset_base
  - _self_

# So that it can be run multiple times without issue.
do_overwrite: True

cohort_name: "sample"
subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw_parquet"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: 25

inputs:
  subjects:
    input_df: "${raw_data_dir}/subjects.parquet"
  admissions:
    input_df: "${raw_data_dir}/admit_vitals.parquet"
    start_ts_col: "admit_date"
    end_ts_col: "disch_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
    event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"]
  vitals:
    input_df: "${raw_data_dir}/admit_vitals.parquet"
    ts_col: "vitals_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
  labs:
    input_df: "${raw_data_dir}/labs.parquet"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
  medications:
    input_df: "${raw_data_dir}/medications.parquet"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
    columns: {"name": "medication"}

measurements:
  static:
    single_label_classification:
      subjects: ["eye_color"]
  functional_time_dependent:
    age:
      functor: AgeFunctor
      necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] }
      kwargs: { dob_col: "dob" }
  dynamic:
    multi_label_classification:
      admissions: ["department"]
      medications:
        - name: medication
          modifiers:
            - [dose, "float"]
            - [frequency, "categorical"]
            - [duration, "categorical"]
            - [generic_name, "categorical"]
    univariate_regression:
      vitals: ["HR", "temp"]
    multivariate_regression:
      labs: [["lab_name", "lab_value"]]

outlier_detector_config:
  stddev_cutoff: 1.5
min_valid_vocab_element_observations: 5
min_valid_column_observations: 5
min_true_float_frequency: 0.1
min_unique_numerical_observations: 20
min_events_per_subject: 3
agg_by_time_scale: "1h"
Binary file added sample_data/raw_parquet/admit_vitals.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/labs.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/medications.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/subjects.parquet
Binary file not shown.
58 changes: 58 additions & 0 deletions tests/test_e2e_runs.py
@@ -2,6 +2,7 @@

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

import json

Reorder module-level imports to the top of the file.

To comply with PEP 8, module-level imports should be at the top of the file.

+import json

Committable suggestion was skipped due to low confidence.

Tools
Ruff

5-5: Module level import not at top of file (E402)

import os
import subprocess
import unittest
@@ -10,6 +11,8 @@
from tempfile import TemporaryDirectory
from typing import Any

import polars as pl
mmcdermott marked this conversation as resolved.

from tests.utils import MLTypeEqualityCheckableMixin


@@ -32,6 +35,7 @@ def setUp(self):
        self.paths = {}
        for n in (
            "dataset",
            "dataset_from_parquet",
            "esds",
            "pretraining/CI",
            "pretraining/NA",
@@ -45,6 +49,49 @@ def tearDown(self):
        for o in self.dir_objs.values():
            o.cleanup()

    def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path):
        DL_save_dir = dataset_save_dir / "DL_reps"

        train_files = list((DL_save_dir / "train").glob("*.parquet"))
        tuning_files = list((DL_save_dir / "tuning").glob("*.parquet"))
        held_out_files = list((DL_save_dir / "held_out").glob("*.parquet"))

        self.assertTrue(len(train_files) > 0)
        self.assertTrue(len(tuning_files) > 0)
        self.assertTrue(len(held_out_files) > 0)

        train_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in train_files])
        tuning_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in tuning_files])
        held_out_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in held_out_files])

        DL_shards = json.loads((dataset_save_dir / "DL_shards.json").read_text())

        ESD_subjects = pl.read_parquet(dataset_save_dir / "subjects_df.parquet", use_pyarrow=False)

        # Check that the DL shards are correctly partitioned.
        all_subjects = set(ESD_subjects["subject_id"].unique().to_list())

        self.assertEqual(len(all_subjects), len(ESD_subjects))

        all_subj_in_DL_shards = set().union(*DL_shards.values())

        self.assertEqual(all_subjects, all_subj_in_DL_shards)

        train_DL_subjects = set(train_DL_reps["subject_id"].unique().to_list())
        tuning_DL_subjects = set(tuning_DL_reps["subject_id"].unique().to_list())
        held_out_DL_subjects = set(held_out_DL_reps["subject_id"].unique().to_list())

        all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects

        self.assertEqual(all_DL_subjects, all_subjects)

        self.assertEqual(len(train_DL_subjects & tuning_DL_subjects), 0)
        self.assertEqual(len(train_DL_subjects & held_out_DL_subjects), 0)
        self.assertEqual(len(tuning_DL_subjects & held_out_DL_subjects), 0)

        self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects))
        self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects))

mmcdermott marked this conversation as resolved.
    def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True):
        if use_subtest:
            with self.subTest(case_name):
@@ -71,6 +118,17 @@ def build_dataset(self):
f"save_dir={self.paths['dataset']}",
]
self._test_command(command_parts, "Build Dataset", use_subtest=False)
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset"])

command_parts = [
"./scripts/build_dataset.py",
f"--config-path='{(root / 'sample_data').resolve()}'",
"--config-name=dataset_parquet",
'"hydra.searchpath=[./configs]"',
f"save_dir={self.paths['dataset_from_parquet']}",
]
self._test_command(command_parts, "Build Dataset from Parquet", use_subtest=False)
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset_from_parquet"])

    def build_ESDS_dataset(self):
        command_parts = [