Skip to content

Commit

Permalink
Added additional changes to fix flat_rep run from another branch
Browse files Browse the repository at this point in the history
  • Loading branch information
pargaw committed Jun 23, 2024
1 parent 71ac9a8 commit a813468
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions EventStream/data/dataset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,7 +1479,8 @@ def _summarize_static_measurements(
if include_only_subjects is None:
df = self.subjects_df
else:
df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

valid_measures = {}
for feat_col in feature_columns:
Expand Down Expand Up @@ -1525,9 +1526,9 @@ def _summarize_static_measurements(
)

remap_cols = [c for c in pivoted_df.columns if c not in ID_cols]
out_dfs[m] = pivoted_df.lazy().select(
out_dfs[m] = pivoted_df.select(
*ID_cols, *[pl.col(c).alias(f"static/{m}/{c}/present").cast(pl.Boolean) for c in remap_cols]
)
).lazy()

return pl.concat(list(out_dfs.values()), how="align")

Expand All @@ -1539,7 +1540,8 @@ def _summarize_time_dependent_measurements(
if include_only_subjects is None:
df = self.events_df
else:
df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

valid_measures = {}
for feat_col in feature_columns:
Expand Down Expand Up @@ -1584,13 +1586,13 @@ def _summarize_time_dependent_measurements(
)

remap_cols = [c for c in pivoted_df.columns if c not in ID_cols]
out_dfs[m] = pivoted_df.lazy().select(
out_dfs[m] = pivoted_df.select(
*ID_cols,
*[
pl.col(c).cast(pl.Boolean).alias(f"functional_time_dependent/{m}/{c}/present")
for c in remap_cols
],
)
).lazy()

return pl.concat(list(out_dfs.values()), how="align")

Expand All @@ -1602,8 +1604,9 @@ def _summarize_dynamic_measurements(
if include_only_subjects is None:
df = self.dynamic_measurements_df
else:
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.dynamic_measurements_df.join(
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select(
self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])).select(
"event_id"
),
on="event_id",
Expand Down Expand Up @@ -1711,10 +1714,10 @@ def _summarize_dynamic_measurements(
values=values_cols,
aggregate_function=None,
)
.lazy()
.drop("measurement_id")
.group_by("event_id")
.agg(*aggs)
.lazy()
)

return pl.concat(list(out_dfs.values()), how="align")
Expand Down Expand Up @@ -1799,11 +1802,9 @@ def _get_flat_ts_rep(
)
.drop("event_id")
.sort(by=["subject_id", "timestamp"])
.collect()
.lazy(),
[c for c in feature_columns if not c.startswith("static/")],
)
# The above .collect().lazy() shouldn't be necessary but it appears to be for some reason...

def _normalize_flat_rep_df_cols(
self, flat_df: DF_T, feature_columns: list[str] | None = None, set_count_0_to_null: bool = False
Expand Down

0 comments on commit a813468

Please sign in to comment.