-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test cases and (eventually) fixes for #114 #117
base: dev
Are you sure you want to change the base?
Changes from 22 commits
edcd2e9
785ccc9
fa6844e
cf67bcd
0434ed4
98cafd5
48ba63d
b134133
9acff54
1b4f0d8
39ba674
ace03dd
b41b014
aa0fa66
413dda5
9d703b7
7616598
674b50f
9f3ce52
71ac9a8
a813468
87a3874
5150e05
532c3dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -690,6 +690,7 @@ def _agg_by_time(self): | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _update_subject_event_properties(self): | ||||||||||||||||||||||||||||||||||||||
self.subject_id_dtype = self.events_df.schema["subject_id"] | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure The assignment of - self.subject_id_dtype = self.events_df.schema["subject_id"]
- if self.events_df is not None:
+ if self.events_df is not None:
+ self.subject_id_dtype = self.events_df.schema["subject_id"] Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
if self.events_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting event types") | ||||||||||||||||||||||||||||||||||||||
self.event_types = ( | ||||||||||||||||||||||||||||||||||||||
|
@@ -699,15 +700,28 @@ def _update_subject_event_properties(self): | |||||||||||||||||||||||||||||||||||||
.to_list() | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting subject event counts") | ||||||||||||||||||||||||||||||||||||||
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count")) | ||||||||||||||||||||||||||||||||||||||
n_events = n_events.drop_nulls("subject_id") | ||||||||||||||||||||||||||||||||||||||
# here we cast to str to avoid issues with the subject_id column being various other types as we | ||||||||||||||||||||||||||||||||||||||
# will eventually JSON serialize it. | ||||||||||||||||||||||||||||||||||||||
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = { | ||||||||||||||||||||||||||||||||||||||
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"]) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
self.subject_ids = set(self.n_events_per_subject.keys()) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.subjects_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting subject event counts") | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = ( | ||||||||||||||||||||||||||||||||||||||
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
subjects_df_subjects = ( | ||||||||||||||||||||||||||||||||||||||
self.subjects_df | ||||||||||||||||||||||||||||||||||||||
.drop_nulls("subject_id") | ||||||||||||||||||||||||||||||||||||||
.select(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
subjects_df_subj_ids = set(subjects_df_subjects["subject_id"].to_list()) | ||||||||||||||||||||||||||||||||||||||
subj_no_in_df = self.subject_ids - subjects_df_subj_ids | ||||||||||||||||||||||||||||||||||||||
if len(subj_no_in_df) > 0: | ||||||||||||||||||||||||||||||||||||||
logger.warning(f"Found {len(subj_no_in_df)} subjects not in subject df!") | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = subjects_df_subj_ids - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
for sid in subjects_with_no_events: | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject[sid] = 0 | ||||||||||||||||||||||||||||||||||||||
self.subject_ids.update(subjects_with_no_events) | ||||||||||||||||||||||||||||||||||||||
|
@@ -723,7 +737,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | | |||||||||||||||||||||||||||||||||||||
filter_exprs.append(pl.col(col).is_null()) | ||||||||||||||||||||||||||||||||||||||
case _: | ||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for " | ||||||||||||||||||||||||||||||||||||||
f"{col} to {df.schema[col]}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
if isinstance(list(incl_targets)[0], str): | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
incl_list = incl_list.cast(df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+740
to
+753
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve type conversion and error handling in The conversion of inclusion targets to the appropriate data type can be simplified by using - logger.debug(
- f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
- f"{col} to {df.schema[col]}"
- )
- if isinstance(list(incl_targets)[0], str):
- incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
- else:
- incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
- incl_list = incl_list.cast(df.schema[col])
- logger.debug(
- f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
- )
+ try:
+ incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+ except TypeError as e:
+ raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
except TypeError as e: | ||||||||||||||||||||||||||||||||||||||
incl_targets_by_type = defaultdict(list) | ||||||||||||||||||||||||||||||||||||||
for t in incl_targets: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1358,6 +1385,8 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
# 1. Process subject data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids}) | ||||||||||||||||||||||||||||||||||||||
logger.warning( f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion " | ||||||||||||||||||||||||||||||||||||||
f"the size of subjects_df are {len(subjects_df)}") | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self.subjects_df | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1369,6 +1398,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
pl.col("index").alias("static_indices"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of static_data: {static_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 2. Process event data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1378,6 +1408,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
events_df = self.events_df | ||||||||||||||||||||||||||||||||||||||
event_ids = None | ||||||||||||||||||||||||||||||||||||||
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of event_data: {event_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 3. Process measurement data into the right base format: | ||||||||||||||||||||||||||||||||||||||
if event_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1392,6 +1423,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if do_sort_outputs: | ||||||||||||||||||||||||||||||||||||||
dynamic_data = dynamic_data.sort("event_id", "measurement_id") | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 4. Join dynamic and event data. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1444,7 +1476,8 @@ def _summarize_static_measurements( | |||||||||||||||||||||||||||||||||||||
if include_only_subjects is None: | ||||||||||||||||||||||||||||||||||||||
df = self.subjects_df | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) | ||||||||||||||||||||||||||||||||||||||
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
valid_measures = {} | ||||||||||||||||||||||||||||||||||||||
for feat_col in feature_columns: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1490,9 +1523,9 @@ def _summarize_static_measurements( | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
remap_cols = [c for c in pivoted_df.columns if c not in ID_cols] | ||||||||||||||||||||||||||||||||||||||
out_dfs[m] = pivoted_df.lazy().select( | ||||||||||||||||||||||||||||||||||||||
out_dfs[m] = pivoted_df.select( | ||||||||||||||||||||||||||||||||||||||
*ID_cols, *[pl.col(c).alias(f"static/{m}/{c}/present").cast(pl.Boolean) for c in remap_cols] | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
).lazy() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return pl.concat(list(out_dfs.values()), how="align") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1504,7 +1537,8 @@ def _summarize_time_dependent_measurements( | |||||||||||||||||||||||||||||||||||||
if include_only_subjects is None: | ||||||||||||||||||||||||||||||||||||||
df = self.events_df | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) | ||||||||||||||||||||||||||||||||||||||
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
valid_measures = {} | ||||||||||||||||||||||||||||||||||||||
for feat_col in feature_columns: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1549,13 +1583,13 @@ def _summarize_time_dependent_measurements( | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
remap_cols = [c for c in pivoted_df.columns if c not in ID_cols] | ||||||||||||||||||||||||||||||||||||||
out_dfs[m] = pivoted_df.lazy().select( | ||||||||||||||||||||||||||||||||||||||
out_dfs[m] = pivoted_df.select( | ||||||||||||||||||||||||||||||||||||||
*ID_cols, | ||||||||||||||||||||||||||||||||||||||
*[ | ||||||||||||||||||||||||||||||||||||||
pl.col(c).cast(pl.Boolean).alias(f"functional_time_dependent/{m}/{c}/present") | ||||||||||||||||||||||||||||||||||||||
for c in remap_cols | ||||||||||||||||||||||||||||||||||||||
], | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
).lazy() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return pl.concat(list(out_dfs.values()), how="align") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1567,8 +1601,9 @@ def _summarize_dynamic_measurements( | |||||||||||||||||||||||||||||||||||||
if include_only_subjects is None: | ||||||||||||||||||||||||||||||||||||||
df = self.dynamic_measurements_df | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
df = self.dynamic_measurements_df.join( | ||||||||||||||||||||||||||||||||||||||
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select( | ||||||||||||||||||||||||||||||||||||||
self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select( | ||||||||||||||||||||||||||||||||||||||
"event_id" | ||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||
on="event_id", | ||||||||||||||||||||||||||||||||||||||
|
@@ -1676,10 +1711,10 @@ def _summarize_dynamic_measurements( | |||||||||||||||||||||||||||||||||||||
values=values_cols, | ||||||||||||||||||||||||||||||||||||||
aggregate_function=None, | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
.lazy() | ||||||||||||||||||||||||||||||||||||||
.drop("measurement_id") | ||||||||||||||||||||||||||||||||||||||
.group_by("event_id") | ||||||||||||||||||||||||||||||||||||||
.agg(*aggs) | ||||||||||||||||||||||||||||||||||||||
.lazy() | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
return pl.concat(list(out_dfs.values()), how="align") | ||||||||||||||||||||||||||||||||||||||
|
@@ -1764,11 +1799,9 @@ def _get_flat_ts_rep( | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
.drop("event_id") | ||||||||||||||||||||||||||||||||||||||
.sort(by=["subject_id", "timestamp"]) | ||||||||||||||||||||||||||||||||||||||
.collect() | ||||||||||||||||||||||||||||||||||||||
.lazy(), | ||||||||||||||||||||||||||||||||||||||
[c for c in feature_columns if not c.startswith("static/")], | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
# The above .collect().lazy() shouldn't be necessary but it appears to be for some reason... | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _normalize_flat_rep_df_cols( | ||||||||||||||||||||||||||||||||||||||
self, flat_df: DF_T, feature_columns: list[str] | None = None, set_count_0_to_null: bool = False | ||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
|
||||||
subject_id, st, end = self.index[idx] | ||||||
|
||||||
shard = self.subj_map[subject_id] | ||||||
if str(subject_id) not in self.subj_map: | ||||||
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"] | ||||||
|
||||||
if len(self.subj_map) < 10: | ||||||
err_str.append("Subject IDs in map:") | ||||||
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items()) | ||||||
|
||||||
raise ValueError("\n".join(err_str)) | ||||||
|
||||||
shard = self.subj_map[str(subject_id)] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct the handling of subject ID. The - shard = self.subj_map[str(subject_id)]
+ shard = self.subj_map[subject_id] # Assuming subject_id is consistently used in its native type across the codebase. Committable suggestion
Suggested change
|
||||||
subject_idx = self.subj_indices[subject_id] | ||||||
static_row = self.static_dfs[shard][subject_idx].to_dict() | ||||||
|
||||||
|
@@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
} | ||||||
|
||||||
if self.config.do_include_subject_id: | ||||||
out["subject_id"] = subject_id | ||||||
out["subject_id"] = static_row["subject_id"].item() | ||||||
|
||||||
seq_len = end - st | ||||||
if seq_len > self.max_seq_len: | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
json.load()
instead ofjson.loads()
withread_text()
.The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using
json.load()
, which streams the file content.Committable suggestion