Test cases and (eventually) fixes for #114 #117

Open · wants to merge 24 commits into base: dev
Changes from 22 of 24 commits

Commits
edcd2e9
starting off
mmcdermott Jun 22, 2024
785ccc9
Starting to add test contents; not yet working
mmcdermott Jun 22, 2024
fa6844e
Merge branch 'dev' into fix_114_typing_with_subject_IDs
mmcdermott Jun 22, 2024
cf67bcd
Added a more involved test for dataset construction.
mmcdermott Jun 22, 2024
0434ed4
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
98cafd5
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
48ba63d
Merge branch 'fix_114_typing_with_subject_IDs' of github.com:mmcdermo…
mmcdermott Jun 22, 2024
b134133
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
9acff54
Removed the unnecessary pandas conversion in the getting number of ev…
mmcdermott Jun 22, 2024
1b4f0d8
Things are partially improved, but other tests are still failing. Inv…
mmcdermott Jun 22, 2024
39ba674
This may have fixed it.
mmcdermott Jun 22, 2024
ace03dd
Added minor changes from another branch and more debug logs
pargaw Jun 22, 2024
b41b014
Added assertion to check _filter_col_inclusion
pargaw Jun 22, 2024
aa0fa66
Some more corrections to subject ID typing.
mmcdermott Jun 22, 2024
413dda5
Fixed pytorch dataset issue (maybe)
mmcdermott Jun 22, 2024
9d703b7
Tests still having challenges with shards not overlapping.
mmcdermott Jun 22, 2024
7616598
fixed broken pytorch dataset test given string conventions.
mmcdermott Jun 23, 2024
674b50f
Other changes to try to get things working
mmcdermott Jun 23, 2024
9f3ce52
Added some more logging and dropping of nulls for safety
mmcdermott Jun 23, 2024
71ac9a8
Added flat rep changes from another branch
pargaw Jun 23, 2024
a813468
Added additional changes to fix flat_rep run from another branch
pargaw Jun 23, 2024
87a3874
Set assertion error to warning log
pargaw Jun 23, 2024
5150e05
Added functionality when do_update is False
pargaw Jun 30, 2024
532c3dd
Added debug statements for caching flat reps
pargaw Jul 6, 2024
27 changes: 20 additions & 7 deletions EventStream/baseline/FT_task_baseline.py
@@ -47,6 +47,7 @@ def load_flat_rep(
do_update_if_missing: bool = True,
task_df_name: str | None = None,
do_cache_filtered_task: bool = True,
overwrite_cache_filtered_task: bool = False,
subjects_included: dict[str, set[int]] | None = None,
) -> dict[str, pl.LazyFrame]:
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -68,14 +69,16 @@ def load_flat_rep(
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
try to update the stored flat representations to reflect these. If `False`, if information is
missing, it will raise a `FileNotFoundError` instead.
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
task_df_name: If specified, the flat representations loaded will be joined against the task
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
specified path for this task, then the data will be loaded from there, rather than from the base
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated even if a
cached file exists. If `False`, the cached file will be loaded if it exists.
subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.

@@ -148,8 +151,9 @@ def load_flat_rep(

by_split = {}
for sp, all_sp_subjects in ESD.split_subjects.items():
all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype)
if task_df_name is not None:
sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects)))
sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects))

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
@@ -171,19 +175,28 @@
if task_df_name is not None:
fn = fp.parts[-1]
cached_fp = task_window_dir / fn
if cached_fp.is_file():
if cached_fp.is_file() and not overwrite_cache_filtered_task:
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)
window_dfs.append(df)
continue

df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
filter_join_df = sp_join_df.select(join_keys).filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)

df = df.join(filter_join_df, on=join_keys, how="inner")
df = filter_join_df.join_asof(
df,
by="subject_id",
on="timestamp",
strategy="forward" if "-" in window_size else "backward",
)

if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)
@@ -193,7 +206,7 @@
df = df.select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)))

window_dfs.append(df)

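The recurring fix in this file is dtype alignment before filtering: the Python-side collection of subject IDs is wrapped in a pl.Series and cast to ESD.subject_id_dtype before calling is_in, instead of passing a raw Python list. A minimal standalone sketch of that pattern follows; the frame, the UInt32 dtype, and the IDs are made up for illustration.

    import polars as pl

    # Made-up frame standing in for a flat representation shard; subject_id stored as UInt32.
    df = pl.DataFrame(
        {"subject_id": [1, 2, 3, 4], "value": [0.1, 0.2, 0.3, 0.4]},
        schema={"subject_id": pl.UInt32, "value": pl.Float64},
    )
    subject_id_dtype = df.schema["subject_id"]

    # Python-side IDs often arrive as plain ints (or strings after a JSON round trip);
    # cast them to the column's dtype so is_in compares like with like.
    wanted = pl.Series([2, 4]).cast(subject_id_dtype)
    print(df.filter(pl.col("subject_id").is_in(wanted)))
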
9 changes: 3 additions & 6 deletions EventStream/data/dataset_base.py
@@ -1367,27 +1367,24 @@ def cache_deep_learning_representation(
NRT_dir = self.config.save_dir / "NRT_reps"

shards_fp = self.config.save_dir / "DL_shards.json"
if shards_fp.exists():
if shards_fp.exists() and not do_overwrite:

Consider using json.load() instead of json.loads() with read_text().

The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using json.load(), which streams the file content.

- shards = json.loads(shards_fp.read_text())
+ with open(shards_fp, 'r') as file:
+     shards = json.load(file)

shards = json.loads(shards_fp.read_text())
else:
shards = {}

if subjects_per_output_file is None:
subject_chunks = [self.subject_ids]
else:
subjects = np.random.permutation(list(self.subject_ids))
subjects = np.random.permutation(list(set(self.subject_ids)))
subject_chunks = np.array_split(
subjects,
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)

subject_chunks = [[int(x) for x in c] for c in subject_chunks]

for chunk_idx, subjects_list in enumerate(subject_chunks):
for split, subjects in self.split_subjects.items():
shard_key = f"{split}/{chunk_idx}"
included_subjects = set(subjects_list).intersection({int(x) for x in subjects})
shards[shard_key] = list(included_subjects)
shards[shard_key] = list(set(subjects_list).intersection(subjects))

shards_fp.write_text(json.dumps(shards))

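For reference, the sharding step this hunk adjusts can be reproduced standalone roughly as below; the subject IDs, splits, and chunk size are invented, and the real method persists the result to DL_shards.json next to the dataset rather than printing it.

    import json
    import numpy as np

    # Hypothetical stand-ins for self.subject_ids, self.split_subjects, subjects_per_output_file.
    subject_ids = {1, 2, 3, 4, 5, 6, 7}
    split_subjects = {"train": {1, 2, 3, 4, 5}, "tuning": {6, 7}}
    subjects_per_output_file = 3

    # Deduplicate, shuffle, and split into chunks of at most subjects_per_output_file subjects.
    subjects = np.random.permutation(list(set(subject_ids)))
    subject_chunks = np.array_split(
        subjects,
        np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
    )
    subject_chunks = [[int(x) for x in c] for c in subject_chunks]

    # Each shard is the intersection of a chunk with a split, keyed "split/chunk_idx".
    shards = {}
    for chunk_idx, subjects_list in enumerate(subject_chunks):
        for split, subjects_in_split in split_subjects.items():
            shards[f"{split}/{chunk_idx}"] = sorted(set(subjects_list) & set(subjects_in_split))

    print(json.dumps(shards, indent=2))
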
65 changes: 49 additions & 16 deletions EventStream/data/dataset_polars.py
@@ -690,6 +690,7 @@ def _agg_by_time(self):
)

def _update_subject_event_properties(self):
self.subject_id_dtype = self.events_df.schema["subject_id"]

Ensure events_df is not None before accessing its schema.

The assignment of self.subject_id_dtype should be inside the conditional check to ensure events_df is not None before accessing its schema.

-        self.subject_id_dtype = self.events_df.schema["subject_id"]
-        if self.events_df is not None:
+        if self.events_df is not None:
+            self.subject_id_dtype = self.events_df.schema["subject_id"]

if self.events_df is not None:
logger.debug("Collecting event types")
self.event_types = (
@@ -699,15 +700,28 @@ def _update_subject_event_properties(self):
.to_list()
)

n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
logger.debug("Collecting subject event counts")
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
n_events = n_events.drop_nulls("subject_id")
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
self.n_events_per_subject = {
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
self.subject_ids = set(self.n_events_per_subject.keys())

if self.subjects_df is not None:
logger.debug("Collecting subject event counts")
subjects_with_no_events = (
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
subjects_df_subjects = (
self.subjects_df
.drop_nulls("subject_id")
.select(pl.col("subject_id").cast(pl.Utf8))
)
subjects_df_subj_ids = set(subjects_df_subjects["subject_id"].to_list())
subj_no_in_df = self.subject_ids - subjects_df_subj_ids
if len(subj_no_in_df) > 0:
logger.warning(f"Found {len(subj_no_in_df)} subjects not in subject df!")
subjects_with_no_events = subjects_df_subj_ids - self.subject_ids
for sid in subjects_with_no_events:
self.n_events_per_subject[sid] = 0
self.subject_ids.update(subjects_with_no_events)
@@ -723,7 +737,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
filter_exprs.append(pl.col(col).is_null())
case _:
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
logger.debug(
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
f"{col} to {df.schema[col]}"
)
if isinstance(list(incl_targets)[0], str):
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])

incl_list = incl_list.cast(df.schema[col])

logger.debug(
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
)
Comment on lines +740 to +753

Improve type conversion and error handling in _filter_col_inclusion.

The conversion of inclusion targets to the appropriate data type can be simplified by using polars built-in functions, and the error message can be enhanced for better clarity.

-                        logger.debug(
-                            f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
-                            f"{col} to {df.schema[col]}"
-                        )
-                        if isinstance(list(incl_targets)[0], str):
-                            incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
-                        else:
-                            incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
-                        incl_list = incl_list.cast(df.schema[col])
-                        logger.debug(
-                            f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
-                        )
+                        try:
+                            incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+                        except TypeError as e:
+                            raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e

except TypeError as e:
incl_targets_by_type = defaultdict(list)
for t in incl_targets:
@@ -1358,6 +1385,8 @@ def build_DL_cached_representation(
# 1. Process subject data into the right format.
if subject_ids:
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
logger.warning(f"Size of given subject_ids is {len(subject_ids)}, but after _filter_col_inclusion "
f"the size of subjects_df is {len(subjects_df)}")
else:
subjects_df = self.subjects_df

@@ -1369,6 +1398,7 @@
pl.col("index").alias("static_indices"),
)
)
logger.debug(f"Size of static_data: {static_data.shape[0]}")

# 2. Process event data into the right format.
if subject_ids:
@@ -1378,6 +1408,7 @@
events_df = self.events_df
event_ids = None
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures)
logger.debug(f"Size of event_data: {event_data.shape[0]}")

# 3. Process measurement data into the right base format:
if event_ids:
@@ -1392,6 +1423,7 @@

if do_sort_outputs:
dynamic_data = dynamic_data.sort("event_id", "measurement_id")
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}")

# 4. Join dynamic and event data.

@@ -1444,7 +1476,8 @@ def _summarize_static_measurements(
if include_only_subjects is None:
df = self.subjects_df
else:
df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

valid_measures = {}
for feat_col in feature_columns:
@@ -1490,9 +1523,9 @@ def _summarize_static_measurements(
)

remap_cols = [c for c in pivoted_df.columns if c not in ID_cols]
out_dfs[m] = pivoted_df.lazy().select(
out_dfs[m] = pivoted_df.select(
*ID_cols, *[pl.col(c).alias(f"static/{m}/{c}/present").cast(pl.Boolean) for c in remap_cols]
)
).lazy()

return pl.concat(list(out_dfs.values()), how="align")

@@ -1504,7 +1537,8 @@ def _summarize_time_dependent_measurements(
if include_only_subjects is None:
df = self.events_df
else:
df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

valid_measures = {}
for feat_col in feature_columns:
@@ -1549,13 +1583,13 @@ def _summarize_time_dependent_measurements(
)

remap_cols = [c for c in pivoted_df.columns if c not in ID_cols]
out_dfs[m] = pivoted_df.lazy().select(
out_dfs[m] = pivoted_df.select(
*ID_cols,
*[
pl.col(c).cast(pl.Boolean).alias(f"functional_time_dependent/{m}/{c}/present")
for c in remap_cols
],
)
).lazy()

return pl.concat(list(out_dfs.values()), how="align")

@@ -1567,8 +1601,9 @@ def _summarize_dynamic_measurements(
if include_only_subjects is None:
df = self.dynamic_measurements_df
else:
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.dynamic_measurements_df.join(
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select(
self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select(
"event_id"
),
on="event_id",
@@ -1676,10 +1711,10 @@ def _summarize_dynamic_measurements(
values=values_cols,
aggregate_function=None,
)
.lazy()
.drop("measurement_id")
.group_by("event_id")
.agg(*aggs)
.lazy()
)

return pl.concat(list(out_dfs.values()), how="align")
@@ -1764,11 +1799,9 @@ def _get_flat_ts_rep(
)
.drop("event_id")
.sort(by=["subject_id", "timestamp"])
.collect()
.lazy(),
[c for c in feature_columns if not c.startswith("static/")],
)
# The above .collect().lazy() shouldn't be necessary but it appears to be for some reason...

def _normalize_flat_rep_df_cols(
self, flat_df: DF_T, feature_columns: list[str] | None = None, set_count_0_to_null: bool = False
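The type-coercion branch added to _filter_col_inclusion can be exercised in isolation as below. The toy frame, its UInt32 subject_id column, and the string targets are hypothetical, but the path mirrors the hunk: build a Utf8 Series when the targets are strings, then cast to the column's schema dtype before filtering.

    import polars as pl

    # Toy frame: subject_id stored as UInt32, while the targets arrive as strings (e.g. from JSON keys).
    df = pl.DataFrame({"subject_id": [10, 11, 12]}, schema={"subject_id": pl.UInt32})
    incl_targets = {"11", "12"}

    first = next(iter(incl_targets))
    if isinstance(first, str):
        incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
    else:
        incl_list = pl.Series(list(incl_targets), dtype=df.schema["subject_id"])
    # Cast to the column's dtype so the comparison in is_in is well-typed.
    incl_list = incl_list.cast(df.schema["subject_id"])

    print(df.filter(pl.col("subject_id").is_in(incl_list)))
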
13 changes: 11 additions & 2 deletions EventStream/data/pytorch_dataset.py
@@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:

subject_id, st, end = self.index[idx]

shard = self.subj_map[subject_id]
if str(subject_id) not in self.subj_map:
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"]

if len(self.subj_map) < 10:
err_str.append("Subject IDs in map:")
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items())

raise ValueError("\n".join(err_str))

shard = self.subj_map[str(subject_id)]

Correct the handling of subject ID.

The subject_id is being accessed as a string, which might not be consistent with other parts of the code where it could be treated as an integer or another type. This inconsistency can lead to errors or unexpected behavior.

-        shard = self.subj_map[str(subject_id)]
+        shard = self.subj_map[subject_id]  # Assuming subject_id is consistently used in its native type across the codebase.

subject_idx = self.subj_indices[subject_id]
static_row = self.static_dfs[shard][subject_idx].to_dict()

@@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:
}

if self.config.do_include_subject_id:
out["subject_id"] = subject_id
out["subject_id"] = static_row["subject_id"].item()

seq_len = end - st
if seq_len > self.max_seq_len:
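The string-keyed lookup added to _seeded_getitem can be illustrated with a toy subject-to-shard map. The map below is invented; the str() coercion mirrors the PR's fix, which indexes subj_map with str(subject_id), plausibly because the IDs pass through JSON serialization (where object keys are always strings). If so, the reviewer's suggestion to index with the native type would fail after a round trip through disk.

    # Invented subject-to-shard map with string keys, mirroring how the PR indexes self.subj_map.
    subj_map = {"1": "train/0", "2": "train/0", "3": "train/1"}

    def lookup_shard(subject_id) -> str:
        """Mirror of the lookup above: coerce the ID to str and fail with a descriptive error if missing."""
        key = str(subject_id)
        if key not in subj_map:
            err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"]
            if len(subj_map) < 10:
                err_str.append("Subject IDs in map:")
                err_str.extend(f"  * {k} ({type(k)}): {v}" for k, v in subj_map.items())
            raise ValueError("\n".join(err_str))
        return subj_map[key]

    print(lookup_shard(2))  # int 2 resolves to "train/0" via str()
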
2 changes: 1 addition & 1 deletion configs/dataset_base.yaml
@@ -19,7 +19,7 @@ center_and_scale: True

hydra:
job:
name: build_${cohort_name}
name: build_dataset
run:
dir: ${save_dir}/.logs
sweep:
2 changes: 1 addition & 1 deletion sample_data/dataset.yaml
@@ -10,7 +10,7 @@ subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw/"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: null
DL_chunk_size: 25

inputs:
subjects: