-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test cases and (eventually) fixes for #114 #117
base: dev
Are you sure you want to change the base?
Changes from 18 commits
edcd2e9
785ccc9
fa6844e
cf67bcd
0434ed4
98cafd5
48ba63d
b134133
9acff54
1b4f0d8
39ba674
ace03dd
b41b014
aa0fa66
413dda5
9d703b7
7616598
674b50f
9f3ce52
71ac9a8
a813468
87a3874
5150e05
532c3dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -690,6 +690,7 @@ def _agg_by_time(self): | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _update_subject_event_properties(self): | ||||||||||||||||||||||||||||||||||||||
self.subject_id_dtype = self.events_df.schema["subject_id"] | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure The assignment of - self.subject_id_dtype = self.events_df.schema["subject_id"]
- if self.events_df is not None:
+ if self.events_df is not None:
+ self.subject_id_dtype = self.events_df.schema["subject_id"] Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
if self.events_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting event types") | ||||||||||||||||||||||||||||||||||||||
self.event_types = ( | ||||||||||||||||||||||||||||||||||||||
|
@@ -699,15 +700,19 @@ def _update_subject_event_properties(self): | |||||||||||||||||||||||||||||||||||||
.to_list() | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() | ||||||||||||||||||||||||||||||||||||||
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count")) | ||||||||||||||||||||||||||||||||||||||
# here we cast to str to avoid issues with the subject_id column being various other types as we | ||||||||||||||||||||||||||||||||||||||
# will eventually JSON serialize it. | ||||||||||||||||||||||||||||||||||||||
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = { | ||||||||||||||||||||||||||||||||||||||
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"]) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refactor subject event properties calculation for clarity and efficiency. The calculation of - n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
- self.n_events_per_subject = {
- subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
- }
+ self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False) Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
self.subject_ids = set(self.n_events_per_subject.keys()) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.subjects_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting subject event counts") | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = ( | ||||||||||||||||||||||||||||||||||||||
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
for sid in subjects_with_no_events: | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject[sid] = 0 | ||||||||||||||||||||||||||||||||||||||
self.subject_ids.update(subjects_with_no_events) | ||||||||||||||||||||||||||||||||||||||
|
@@ -723,7 +728,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | | |||||||||||||||||||||||||||||||||||||
filter_exprs.append(pl.col(col).is_null()) | ||||||||||||||||||||||||||||||||||||||
case _: | ||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for " | ||||||||||||||||||||||||||||||||||||||
f"{col} to {df.schema[col]}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
if isinstance(list(incl_targets)[0], str): | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
incl_list = incl_list.cast(df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+740
to
+753
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve type conversion and error handling in The conversion of inclusion targets to the appropriate data type can be simplified by using - logger.debug(
- f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
- f"{col} to {df.schema[col]}"
- )
- if isinstance(list(incl_targets)[0], str):
- incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
- else:
- incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
- incl_list = incl_list.cast(df.schema[col])
- logger.debug(
- f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
- )
+ try:
+ incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+ except TypeError as e:
+ raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
except TypeError as e: | ||||||||||||||||||||||||||||||||||||||
incl_targets_by_type = defaultdict(list) | ||||||||||||||||||||||||||||||||||||||
for t in incl_targets: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1358,6 +1376,11 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
# 1. Process subject data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids}) | ||||||||||||||||||||||||||||||||||||||
if len(subject_ids) != len(subjects_df): | ||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||
f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion " | ||||||||||||||||||||||||||||||||||||||
f"the size of subjects_df are {len(subjects_df)}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self.subjects_df | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1369,6 +1392,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
pl.col("index").alias("static_indices"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of static_data: {static_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 2. Process event data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1378,6 +1402,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
events_df = self.events_df | ||||||||||||||||||||||||||||||||||||||
event_ids = None | ||||||||||||||||||||||||||||||||||||||
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of event_data: {event_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 3. Process measurement data into the right base format: | ||||||||||||||||||||||||||||||||||||||
if event_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1392,6 +1417,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if do_sort_outputs: | ||||||||||||||||||||||||||||||||||||||
dynamic_data = dynamic_data.sort("event_id", "measurement_id") | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 4. Join dynamic and event data. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
|
||||||
subject_id, st, end = self.index[idx] | ||||||
|
||||||
shard = self.subj_map[subject_id] | ||||||
if str(subject_id) not in self.subj_map: | ||||||
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"] | ||||||
|
||||||
if len(self.subj_map) < 10: | ||||||
err_str.append("Subject IDs in map:") | ||||||
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items()) | ||||||
|
||||||
raise ValueError("\n".join(err_str)) | ||||||
|
||||||
shard = self.subj_map[str(subject_id)] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct the handling of subject ID. The - shard = self.subj_map[str(subject_id)]
+ shard = self.subj_map[subject_id] # Assuming subject_id is consistently used in its native type across the codebase. Committable suggestion
Suggested change
|
||||||
subject_idx = self.subj_indices[subject_id] | ||||||
static_row = self.static_dfs[shard][subject_idx].to_dict() | ||||||
|
||||||
|
@@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
} | ||||||
|
||||||
if self.config.do_include_subject_id: | ||||||
out["subject_id"] = subject_id | ||||||
out["subject_id"] = static_row["subject_id"].item() | ||||||
|
||||||
seq_len = end - st | ||||||
if seq_len > self.max_seq_len: | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
defaults: | ||
- dataset_base | ||
- _self_ | ||
|
||
# So that it can be run multiple times without issue. | ||
do_overwrite: True | ||
|
||
cohort_name: "sample" | ||
subject_id_col: "MRN" | ||
raw_data_dir: "./sample_data/raw_parquet" | ||
save_dir: "./sample_data/processed/${cohort_name}" | ||
|
||
DL_chunk_size: 25 | ||
|
||
inputs: | ||
subjects: | ||
input_df: "${raw_data_dir}/subjects.parquet" | ||
admissions: | ||
input_df: "${raw_data_dir}/admit_vitals.parquet" | ||
start_ts_col: "admit_date" | ||
end_ts_col: "disch_date" | ||
ts_format: "%m/%d/%Y, %H:%M:%S" | ||
event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"] | ||
vitals: | ||
input_df: "${raw_data_dir}/admit_vitals.parquet" | ||
ts_col: "vitals_date" | ||
ts_format: "%m/%d/%Y, %H:%M:%S" | ||
labs: | ||
input_df: "${raw_data_dir}/labs.parquet" | ||
ts_col: "timestamp" | ||
ts_format: "%H:%M:%S-%Y-%m-%d" | ||
medications: | ||
input_df: "${raw_data_dir}/medications.parquet" | ||
ts_col: "timestamp" | ||
ts_format: "%H:%M:%S-%Y-%m-%d" | ||
columns: {"name": "medication"} | ||
|
||
measurements: | ||
static: | ||
single_label_classification: | ||
subjects: ["eye_color"] | ||
functional_time_dependent: | ||
age: | ||
functor: AgeFunctor | ||
necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] } | ||
kwargs: { dob_col: "dob" } | ||
dynamic: | ||
multi_label_classification: | ||
admissions: ["department"] | ||
medications: | ||
- name: medication | ||
modifiers: | ||
- [dose, "float"] | ||
- [frequency, "categorical"] | ||
- [duration, "categorical"] | ||
- [generic_name, "categorical"] | ||
univariate_regression: | ||
vitals: ["HR", "temp"] | ||
multivariate_regression: | ||
labs: [["lab_name", "lab_value"]] | ||
|
||
outlier_detector_config: | ||
stddev_cutoff: 1.5 | ||
min_valid_vocab_element_observations: 5 | ||
min_valid_column_observations: 5 | ||
min_true_float_frequency: 0.1 | ||
min_unique_numerical_observations: 20 | ||
min_events_per_subject: 3 | ||
agg_by_time_scale: "1h" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True) | ||
|
||
import json | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reorder module-level imports to the top of the file. To comply with PEP 8, module-level imports should be at the top of the file. +import json
ToolsRuff
|
||
import os | ||
import subprocess | ||
import unittest | ||
|
@@ -10,6 +11,8 @@ | |
from tempfile import TemporaryDirectory | ||
from typing import Any | ||
|
||
import polars as pl | ||
mmcdermott marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
from tests.utils import MLTypeEqualityCheckableMixin | ||
|
||
|
||
|
@@ -32,6 +35,7 @@ def setUp(self): | |
self.paths = {} | ||
for n in ( | ||
"dataset", | ||
"dataset_from_parquet", | ||
"esds", | ||
"pretraining/CI", | ||
"pretraining/NA", | ||
|
@@ -45,6 +49,67 @@ def tearDown(self): | |
for o in self.dir_objs.values(): | ||
o.cleanup() | ||
|
||
def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path): | ||
DL_save_dir = dataset_save_dir / "DL_reps" | ||
|
||
train_files = list((DL_save_dir / "train").glob("*.parquet")) | ||
tuning_files = list((DL_save_dir / "tuning").glob("*.parquet")) | ||
held_out_files = list((DL_save_dir / "held_out").glob("*.parquet")) | ||
|
||
assert len(set(train_files) & set(tuning_files)) == 0 | ||
assert len(set(train_files) & set(held_out_files)) == 0 | ||
assert len(set(tuning_files) & set(held_out_files)) == 0 | ||
|
||
self.assertTrue(len(train_files) > 0) | ||
self.assertTrue(len(tuning_files) > 0) | ||
self.assertTrue(len(held_out_files) > 0) | ||
|
||
train_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in train_files]) | ||
tuning_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in tuning_files]) | ||
held_out_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in held_out_files]) | ||
|
||
DL_shards = json.loads((dataset_save_dir / "DL_shards.json").read_text()) | ||
|
||
ESD_subjects = pl.read_parquet(dataset_save_dir / "subjects_df.parquet", use_pyarrow=False) | ||
|
||
# Check that the DL shards are correctly partitioned. | ||
all_subjects = set(ESD_subjects["subject_id"].unique().to_list()) | ||
|
||
self.assertEqual(len(all_subjects), len(ESD_subjects)) | ||
|
||
all_subj_in_DL_shards = set().union(*DL_shards.values()) | ||
|
||
all_subj_in_DL_shards = set( | ||
pl.Series(list(all_subj_in_DL_shards)).cast(ESD_subjects["subject_id"].dtype).to_list() | ||
) | ||
|
||
self.assertEqual(all_subjects, all_subj_in_DL_shards) | ||
|
||
all_train_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("train"))) | ||
all_tuning_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("tuning"))) | ||
all_held_out_DL_shard_subj = set().union( | ||
*(v for k, v in DL_shards.items() if k.startswith("held_out")) | ||
) | ||
|
||
self.assertEqual(len(all_train_DL_shard_subj & all_tuning_DL_shard_subj), 0) | ||
self.assertEqual(len(all_train_DL_shard_subj & all_held_out_DL_shard_subj), 0) | ||
self.assertEqual(len(all_tuning_DL_shard_subj & all_held_out_DL_shard_subj), 0) | ||
|
||
train_DL_subjects = set(train_DL_reps["subject_id"].to_list()) | ||
tuning_DL_subjects = set(tuning_DL_reps["subject_id"].to_list()) | ||
held_out_DL_subjects = set(held_out_DL_reps["subject_id"].to_list()) | ||
|
||
self.assertEqual(all_train_DL_shard_subj, {str(x) for x in train_DL_subjects}) | ||
self.assertEqual(all_tuning_DL_shard_subj, {str(x) for x in tuning_DL_subjects}) | ||
self.assertEqual(all_held_out_DL_shard_subj, {str(x) for x in held_out_DL_subjects}) | ||
|
||
self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects)) | ||
self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects)) | ||
|
||
all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects | ||
|
||
self.assertEqual(all_DL_subjects, all_subjects) | ||
|
||
def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True): | ||
if use_subtest: | ||
with self.subTest(case_name): | ||
|
@@ -71,6 +136,17 @@ def build_dataset(self): | |
f"save_dir={self.paths['dataset']}", | ||
] | ||
self._test_command(command_parts, "Build Dataset", use_subtest=False) | ||
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset"]) | ||
|
||
command_parts = [ | ||
"./scripts/build_dataset.py", | ||
f"--config-path='{(root / 'sample_data').resolve()}'", | ||
"--config-name=dataset_parquet", | ||
'"hydra.searchpath=[./configs]"', | ||
f"save_dir={self.paths['dataset_from_parquet']}", | ||
] | ||
self._test_command(command_parts, "Build Dataset from Parquet", use_subtest=False) | ||
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset_from_parquet"]) | ||
|
||
def build_ESDS_dataset(self): | ||
command_parts = [ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
json.load()
instead ofjson.loads()
withread_text()
.The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using
json.load()
, which streams the file content.Committable suggestion