Test cases and (eventually) fixes for #114 #117
base: dev
```diff
@@ -690,6 +690,7 @@ def _agg_by_time(self):
         )

     def _update_subject_event_properties(self):
+        self.subject_id_dtype = self.events_df.schema["subject_id"]
         if self.events_df is not None:
             logger.debug("Collecting event types")
             self.event_types = (
```
```diff
@@ -699,15 +700,19 @@ def _update_subject_event_properties(self):
                 .to_list()
             )

-            n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
-            self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
+            n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
+            # here we cast to str to avoid issues with the subject_id column being various other types as we
+            # will eventually JSON serialize it.
+            n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
+            self.n_events_per_subject = {
+                subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
+            }
             self.subject_ids = set(self.n_events_per_subject.keys())

         if self.subjects_df is not None:
             logger.debug("Collecting subject event counts")
-            subjects_with_no_events = (
-                set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
-            )
+            subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8))
+            subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids
             for sid in subjects_with_no_events:
                 self.n_events_per_subject[sid] = 0
             self.subject_ids.update(subjects_with_no_events)
```

Review comment: Refactor subject event properties calculation for clarity and efficiency.

Suggested change:

```diff
-            n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
-            self.n_events_per_subject = {
-                subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
-            }
+            self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False)
```
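As a hypothetical, self-contained illustration of the new counting logic (the events frame below is invented; the polars calls mirror the added lines, assuming a polars version that provides `pl.len()`, which the diff itself relies on):

```python
import polars as pl

events_df = pl.DataFrame({"subject_id": [1, 1, 2, 3, 3, 3]})

n_events = events_df.group_by("subject_id").agg(pl.len().alias("count"))
# Cast subject_id to str so the keys survive JSON serialization regardless of their dtype.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))

n_events_per_subject = {
    subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
print(n_events_per_subject)  # e.g. {'1': 2, '2': 1, '3': 3} (group_by order may vary)

# Note: DataFrame.to_dict(as_series=False), as in the suggested change above, returns a
# column-oriented dict ({'subject_id': [...], 'count': [...]}), not this per-subject mapping.
```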
```diff
@@ -723,7 +728,18 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
                     filter_exprs.append(pl.col(col).is_null())
                 case _:
                     try:
-                        incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
+                        logger.debug(
+                            f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
+                            f"{col} to {df.schema[col]}"
+                        )
+                        if isinstance(list(incl_targets)[0], str):
+                            incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
+                        else:
+                            incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
+
+                        incl_list = incl_list.cast(df.schema[col])
+
+                        logger.debug(f"Converted to Series of type {incl_list.dtype}")
                     except TypeError as e:
                         incl_targets_by_type = defaultdict(list)
                         for t in incl_targets:
```
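A hypothetical sketch of the two-step conversion this hunk introduces, using an invented frame with a Categorical column; the real method additionally handles non-string targets and catches `TypeError` for the remaining failure cases:

```python
import polars as pl

# Invented example: the column dtype is Categorical, but the inclusion targets
# arrive as plain Python strings.
df = pl.DataFrame({"event_type": ["ADMISSION", "DISCHARGE", "LAB"]}).with_columns(
    pl.col("event_type").cast(pl.Categorical)
)
incl_targets = {"ADMISSION", "DISCHARGE"}

# Build the Series as Utf8 first when the targets are strings, then cast to the
# column's dtype, mirroring the logic added in the diff above.
if isinstance(next(iter(incl_targets)), str):
    incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
    incl_list = pl.Series(list(incl_targets), dtype=df.schema["event_type"])

incl_list = incl_list.cast(df.schema["event_type"])
print(incl_list.dtype)  # Categorical
```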
@@ -0,0 +1,69 @@ (new file)

```yaml
defaults:
  - dataset_base
  - _self_

# So that it can be run multiple times without issue.
do_overwrite: True

cohort_name: "sample"
subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw_parquet"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: 25

inputs:
  subjects:
    input_df: "${raw_data_dir}/subjects.parquet"
  admissions:
    input_df: "${raw_data_dir}/admit_vitals.parquet"
    start_ts_col: "admit_date"
    end_ts_col: "disch_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
    event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"]
  vitals:
    input_df: "${raw_data_dir}/admit_vitals.parquet"
    ts_col: "vitals_date"
    ts_format: "%m/%d/%Y, %H:%M:%S"
  labs:
    input_df: "${raw_data_dir}/labs.parquet"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
  medications:
    input_df: "${raw_data_dir}/medications.parquet"
    ts_col: "timestamp"
    ts_format: "%H:%M:%S-%Y-%m-%d"
    columns: {"name": "medication"}

measurements:
  static:
    single_label_classification:
      subjects: ["eye_color"]
  functional_time_dependent:
    age:
      functor: AgeFunctor
      necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] }
      kwargs: { dob_col: "dob" }
  dynamic:
    multi_label_classification:
      admissions: ["department"]
      medications:
        - name: medication
          modifiers:
            - [dose, "float"]
            - [frequency, "categorical"]
            - [duration, "categorical"]
            - [generic_name, "categorical"]
    univariate_regression:
      vitals: ["HR", "temp"]
    multivariate_regression:
      labs: [["lab_name", "lab_value"]]

outlier_detector_config:
  stddev_cutoff: 1.5
min_valid_vocab_element_observations: 5
min_valid_column_observations: 5
min_true_float_frequency: 0.1
min_unique_numerical_observations: 20
min_events_per_subject: 3
agg_by_time_scale: "1h"
```
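The `${...}` references above are standard OmegaConf/Hydra interpolations. As a small, hypothetical sketch (not part of the PR) of how they resolve, using only keys copied from the config:

```python
from omegaconf import OmegaConf

# Only a subset of the config above, enough to show the interpolation behavior.
cfg = OmegaConf.create(
    {
        "cohort_name": "sample",
        "raw_data_dir": "./sample_data/raw_parquet",
        "save_dir": "./sample_data/processed/${cohort_name}",
        "inputs": {"labs": {"input_df": "${raw_data_dir}/labs.parquet"}},
    }
)

print(cfg.save_dir)              # ./sample_data/processed/sample
print(cfg.inputs.labs.input_df)  # ./sample_data/raw_parquet/labs.parquet
```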
```diff
@@ -2,6 +2,7 @@

 root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

+import json
 import os
 import subprocess
 import unittest
```

Review comment (Ruff): Reorder module-level imports to the top of the file. To comply with PEP 8, module-level imports should be at the top of the file.

```diff
+import json
```
```diff
@@ -10,6 +11,8 @@
 from tempfile import TemporaryDirectory
 from typing import Any

+import polars as pl
+
 from tests.utils import MLTypeEqualityCheckableMixin


```
```diff
@@ -32,6 +35,7 @@ def setUp(self):
         self.paths = {}
         for n in (
             "dataset",
+            "dataset_from_parquet",
             "esds",
             "pretraining/CI",
             "pretraining/NA",
```
```diff
@@ -45,6 +49,49 @@ def tearDown(self):
         for o in self.dir_objs.values():
             o.cleanup()

+    def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path):
+        DL_save_dir = dataset_save_dir / "DL_reps"
+
+        train_files = list((DL_save_dir / "train").glob("*.parquet"))
+        tuning_files = list((DL_save_dir / "tuning").glob("*.parquet"))
+        held_out_files = list((DL_save_dir / "held_out").glob("*.parquet"))
+
+        self.assertTrue(len(train_files) > 0)
+        self.assertTrue(len(tuning_files) > 0)
+        self.assertTrue(len(held_out_files) > 0)
+
+        train_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in train_files])
+        tuning_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in tuning_files])
+        held_out_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in held_out_files])
+
+        DL_shards = json.loads((dataset_save_dir / "DL_shards.json").read_text())
+
+        ESD_subjects = pl.read_parquet(dataset_save_dir / "subjects_df.parquet", use_pyarrow=False)
+
+        # Check that the DL shards are correctly partitioned.
+        all_subjects = set(ESD_subjects["subject_id"].unique().to_list())
+
+        self.assertEqual(len(all_subjects), len(ESD_subjects))
+
+        all_subj_in_DL_shards = set().union(*DL_shards.values())
+
+        self.assertEqual(all_subjects, all_subj_in_DL_shards)
+
+        train_DL_subjects = set(train_DL_reps["subject_id"].unique().to_list())
+        tuning_DL_subjects = set(tuning_DL_reps["subject_id"].unique().to_list())
+        held_out_DL_subjects = set(held_out_DL_reps["subject_id"].unique().to_list())
+
+        all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects
+
+        self.assertEqual(all_DL_subjects, all_subjects)
+
+        self.assertEqual(len(train_DL_subjects & tuning_DL_subjects), 0)
+        self.assertEqual(len(train_DL_subjects & held_out_DL_subjects), 0)
+        self.assertEqual(len(tuning_DL_subjects & held_out_DL_subjects), 0)
+
+        self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects))
+        self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects))
+
     def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True):
         if use_subtest:
             with self.subTest(case_name):
```
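The coverage and disjointness properties asserted by `_test_dataset_output` amount to checking that the splits partition the subject set; a toy illustration with made-up subject IDs:

```python
# Made-up subject IDs standing in for the values read from the DL_reps parquet files.
train, tuning, held_out = {"1", "2", "3", "4"}, {"5"}, {"6"}
all_subjects = {"1", "2", "3", "4", "5", "6"}

assert train | tuning | held_out == all_subjects  # every subject is covered
assert not (train & tuning) and not (train & held_out) and not (tuning & held_out)  # no overlap
assert len(train) > len(tuning) and len(train) > len(held_out)  # train is the largest split
```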
```diff
@@ -71,6 +118,17 @@ def build_dataset(self):
             f"save_dir={self.paths['dataset']}",
         ]
         self._test_command(command_parts, "Build Dataset", use_subtest=False)
+        self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset"])
+
+        command_parts = [
+            "./scripts/build_dataset.py",
+            f"--config-path='{(root / 'sample_data').resolve()}'",
+            "--config-name=dataset_parquet",
+            '"hydra.searchpath=[./configs]"',
+            f"save_dir={self.paths['dataset_from_parquet']}",
+        ]
+        self._test_command(command_parts, "Build Dataset from Parquet", use_subtest=False)
+        self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset_from_parquet"])

     def build_ESDS_dataset(self):
         command_parts = [
```
Review comment: Ensure `events_df` is not `None` before accessing its schema. The assignment of `self.subject_id_dtype` should sit inside the conditional check so that the schema is only read once `events_df` is known to exist.
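A minimal, hypothetical sketch of the fix this comment asks for, assuming the method structure shown in the first hunk above (the committed suggestion itself is not included in this excerpt):

```python
import polars as pl


class _DatasetSketch:
    """Reduced stand-in for the dataset class, for illustration only."""

    def __init__(self, events_df: pl.DataFrame | None):
        self.events_df = events_df
        self.subject_id_dtype = None

    def _update_subject_event_properties(self):
        if self.events_df is not None:
            # Access the schema only after confirming events_df exists.
            self.subject_id_dtype = self.events_df.schema["subject_id"]


_DatasetSketch(None)._update_subject_event_properties()  # safe: the schema is never touched
_DatasetSketch(pl.DataFrame({"subject_id": [1]}))._update_subject_event_properties()
```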