Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjusted join in flat reps to account for different timestamps with t… #107

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
19 changes: 14 additions & 5 deletions EventStream/baseline/FT_task_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def load_flat_rep(
do_update_if_missing: bool = True,
task_df_name: str | None = None,
do_cache_filtered_task: bool = True,
overwrite_cache_filtered_task: bool = False,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure proper documentation for the new parameter.

The new parameter overwrite_cache_filtered_task should be included in the function's docstring to maintain comprehensive documentation.

+        overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, the cached file will be loaded if exists.
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
overwrite_cache_filtered_task: bool = False,
overwrite_cache_filtered_task: bool = False,

subjects_included: dict[str, set[int]] | None = None,
) -> dict[str, pl.LazyFrame]:
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints.
Expand All @@ -67,14 +68,16 @@ def load_flat_rep(
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
try to update the stored flat representations to reflect these. If `False`, if information is
missing, it will raise a `FileNotFoundError` instead.
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
task_df_name: If specified, the flat representations loaded will be joined against the task
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
specified path for this task, then the data will be loaded from there, rather than from the base
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
the cached file will be loaded if exists.
Comment on lines +79 to +80
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tip

Codebase Verification

Ensure all calls to load_flat_rep include the new parameter overwrite_cache_filtered_task.

The following calls to load_flat_rep in sample_data/examine_synthetic_data.ipynb do not include the new parameter and should be updated:

  • flat_reps = load_flat_rep(ESD, window_sizes=['7d'])
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL'], feature_inclusion_frequency=0.001)
  • flat_reps = load_flat_rep(ESD, window_sizes=['1d', '7d', 'FULL'])
  • flat_reps = load_flat_rep(ESD, window_sizes=['2d'], do_update_if_missing=False)
  • flat_reps = load_flat_rep(ESD, window_sizes=['1d', '7d', 'FULL'], subjects_included={'train': {0, 1, 2}})
  • flat_reps = load_flat_rep(ESD, window_sizes=['7d'], task_df_name='single_label_binary_classification')
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL', '1d'], task_df_name='multi_class_classification', do_cache_filtered_task=False)
  • flat_reps = load_flat_rep(ESD, window_sizes=['FULL', '1d'], task_df_name='single_label_binary_classification', subjects_included={'train': {0, 1, 2}})

Please update these calls to include the overwrite_cache_filtered_task parameter as needed.

Analysis chain

LGTM! But verify the new parameter usage in the codebase.

The code changes are approved.

However, ensure that all function calls to load_flat_rep are updated to use the new parameter if necessary.

Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type python $'load_flat_rep'

Length of output: 63


Script:

#!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.

# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg 'load_flat_rep'

Length of output: 1215


Script:

#!/bin/bash
# Description: Extract lines around each call to `load_flat_rep` to verify the usage of the new parameter.

# Extract lines around the calls in the Jupyter notebook
rg -A 3 -B 3 'load_flat_rep' sample_data/examine_synthetic_data.ipynb

# Extract lines around the calls in the Python file
rg -A 3 -B 3 'load_flat_rep' EventStream/baseline/FT_task_baseline.py

Length of output: 3273

subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.

Expand Down Expand Up @@ -152,6 +155,7 @@ def load_flat_rep(

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
static_df = static_df.cast({"subject_id": sp_join_df.select('subject_id').dtypes[0]})
static_df = static_df.join(sp_join_df.select("subject_id").unique(), on="subject_id", how="inner")

dfs = []
Expand All @@ -170,7 +174,7 @@ def load_flat_rep(
if task_df_name is not None:
fn = fp.parts[-1]
cached_fp = task_window_dir / fn
if cached_fp.is_file():
if cached_fp.is_file() and not overwrite_cache_filtered_task:
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
Expand All @@ -181,8 +185,13 @@ def load_flat_rep(
df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))

df = df.join(filter_join_df, on=join_keys, how="inner")
df = df.cast({"subject_id": filter_join_df.select('subject_id').dtypes[0]})
df = filter_join_df.join_asof(
df,
by="subject_id",
on="timestamp",
strategy="forward" if "-" in window_size else "backward",
)

if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)
Expand All @@ -195,7 +204,7 @@ def load_flat_rep(

window_dfs.append(df)

dfs.append(pl.concat(window_dfs, how="vertical"))
dfs.append(pl.concat(window_dfs, how="vertical_relaxed"))

joined_df = dfs[0]
for jdf in dfs[1:]:
Expand Down
17 changes: 12 additions & 5 deletions EventStream/data/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,21 @@ def build_event_and_measurement_dfs(
all_events_and_measurements = []
event_types = []

for df, schemas in schemas_by_df.items():
for df_name, schemas in schemas_by_df.items():
all_columns = []

all_columns.extend(itertools.chain.from_iterable(s.columns_to_load for s in schemas))

try:
df = cls._load_input_df(df, all_columns, subject_id_col, subject_ids_map, subject_id_dtype)
df = cls._load_input_df(
df_name, all_columns, subject_id_col, subject_ids_map, subject_id_dtype
)
except Exception as e:
raise ValueError(f"Errored while loading {df}") from e
raise ValueError(f"Errored while loading {df_name}") from e

for schema in schemas:
for schema in tqdm(
schemas, desc=f"Processing events and measurements df for {df_name.split('/')[-1]}"
):
if schema.filter_on:
df = cls._filter_col_inclusion(schema.filter_on)
match schema.type:
Expand Down Expand Up @@ -266,7 +270,10 @@ def build_event_and_measurement_dfs(

all_events, all_measurements = [], []
running_event_id_max = 0
for event_type, (events, measurements) in zip(event_types, all_events_and_measurements):
for event_type, (events, measurements) in tqdm(
zip(event_types, all_events_and_measurements),
desc="Incrementing and combining events and measurements",
):
try:
new_events = cls._inc_df_col(events, "event_id", running_event_id_max)
except Exception as e:
Expand Down
19 changes: 12 additions & 7 deletions EventStream/data/dataset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ def _update_subject_event_properties(self):
)

n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
self.n_events_per_subject = n_events_pd.set_index("subject_id")["counts"].to_dict()
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
self.subject_ids = set(self.n_events_per_subject.keys())

if self.subjects_df is not None:
Expand Down Expand Up @@ -1105,7 +1105,7 @@ def _fit_vocabulary(self, measure: str, config: MeasurementConfig, source_df: DF
try:
value_counts = observations.value_counts()
vocab_elements = value_counts.get_column(measure).to_list()
el_counts = value_counts.get_column("counts")
el_counts = value_counts.get_column("count")
return Vocabulary(vocabulary=vocab_elements, obs_frequencies=el_counts)
except AssertionError as e:
raise AssertionError(f"Failed to build vocabulary for {measure}") from e
Expand Down Expand Up @@ -1417,7 +1417,10 @@ def _summarize_static_measurements(
if include_only_subjects is None:
df = self.subjects_df
else:
df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.subjects_df.filter(
pl.col("subject_id").is_in([str(id) for id in include_only_subjects])
)
Comment on lines +1420 to +1423
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure consistent data type handling for subject_id across methods.

It's a good practice to maintain consistency in how data types are handled across different methods. Here, the subject_id is cast to string in several filtering operations. Consider creating a helper function to perform this casting and filtering operation, which can then be reused across these methods, enhancing code reusability and maintainability.

- self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
- df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))
+ df = self._filter_subjects_by_id(include_only_subjects)

And add a new method in the class:

def _filter_subjects_by_id(self, include_only_subjects):
    self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
    return self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

Also applies to: 1483-1484, 1547-1551


valid_measures = {}
for feat_col in feature_columns:
Expand Down Expand Up @@ -1477,7 +1480,8 @@ def _summarize_time_dependent_measurements(
if include_only_subjects is None:
df = self.events_df
else:
df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects)))
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))

valid_measures = {}
for feat_col in feature_columns:
Expand Down Expand Up @@ -1540,10 +1544,11 @@ def _summarize_dynamic_measurements(
if include_only_subjects is None:
df = self.dynamic_measurements_df
else:
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
df = self.dynamic_measurements_df.join(
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select(
"event_id"
),
self.events_df.filter(
pl.col("subject_id").is_in([str(id) for id in include_only_subjects])
).select("event_id"),
on="event_id",
how="inner",
)
Expand Down
45 changes: 45 additions & 0 deletions scripts/build_flat_reps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#!/usr/bin/env python
"""Builds a flat representation dataset given a hydra config file."""

try:
import stackprinter

stackprinter.set_excepthook(style="darkbg2")
except ImportError:
pass # no need to fail because of missing dev dependency

from pathlib import Path

import hydra
from omegaconf import DictConfig

from EventStream.data.dataset_polars import Dataset


@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base")
def main(cfg: DictConfig):
cfg = hydra.utils.instantiate(cfg, _convert_="all")
save_dir = Path(cfg.pop("save_dir"))
window_sizes = cfg.pop("window_sizes")
subjects_per_output_file = (
cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None
)

# Build flat reps for specified task and window sizes
ESD = Dataset.load(save_dir)
feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params(
feature_inclusion_frequency=None, include_only_measurements=None
)
cache_kwargs = dict(
subjects_per_output_file=subjects_per_output_file,
feature_inclusion_frequency=feature_inclusion_frequency, # 0.1
window_sizes=window_sizes,
include_only_measurements=include_only_measurements,
do_overwrite=False,
do_update=True,
)
ESD.cache_flat_representation(**cache_kwargs)


if __name__ == "__main__":
main()
Loading