-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adjusted join in flat reps to account for different timestamps with t… #107
base: main
Are you sure you want to change the base?
Changes from 7 commits
720e6cb
b006195
9e0acf7
22dca2d
435d968
5c6cb4b
69b99ce
bf453e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,7 @@ def load_flat_rep( | |
do_update_if_missing: bool = True, | ||
task_df_name: str | None = None, | ||
do_cache_filtered_task: bool = True, | ||
overwrite_cache_filtered_task: bool = False, | ||
subjects_included: dict[str, set[int]] | None = None, | ||
) -> dict[str, pl.LazyFrame]: | ||
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints. | ||
|
@@ -67,14 +68,16 @@ def load_flat_rep( | |
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will | ||
try to update the stored flat representations to reflect these. If `False`, if information is | ||
missing, it will raise a `FileNotFoundError` instead. | ||
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task | ||
task_df_name: If specified, the flat representations loaded will be joined against the task | ||
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed | ||
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into | ||
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a | ||
specified path for this task, then the data will be loaded from there, rather than from the base | ||
dataset. | ||
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the | ||
relevant rows for the task, be cached to disk for faster re-use. | ||
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, | ||
the cached file will be loaded if exists. | ||
Comment on lines
+79
to
+80
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tip Codebase Verification Ensure all calls to The following calls to
Please update these calls to include the Analysis chainLGTM! But verify the new parameter usage in the codebase. The code changes are approved. However, ensure that all function calls to Scripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg --type python $'load_flat_rep'
Length of output: 63 Script: #!/bin/bash
# Description: Verify all function calls to `load_flat_rep` match the new signature.
# Test: Search for the function usage. Expect: Only occurrences of the new signature.
rg 'load_flat_rep'
Length of output: 1215 Script: #!/bin/bash
# Description: Extract lines around each call to `load_flat_rep` to verify the usage of the new parameter.
# Extract lines around the calls in the Jupyter notebook
rg -A 3 -B 3 'load_flat_rep' sample_data/examine_synthetic_data.ipynb
# Extract lines around the calls in the Python file
rg -A 3 -B 3 'load_flat_rep' EventStream/baseline/FT_task_baseline.py
Length of output: 3273 |
||
subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are | ||
used wholesale. | ||
|
||
|
@@ -152,6 +155,7 @@ def load_flat_rep( | |
|
||
static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet") | ||
if task_df_name is not None: | ||
static_df = static_df.cast({"subject_id": sp_join_df.select('subject_id').dtypes[0]}) | ||
static_df = static_df.join(sp_join_df.select("subject_id").unique(), on="subject_id", how="inner") | ||
|
||
dfs = [] | ||
|
@@ -170,7 +174,7 @@ def load_flat_rep( | |
if task_df_name is not None: | ||
fn = fp.parts[-1] | ||
cached_fp = task_window_dir / fn | ||
if cached_fp.is_file(): | ||
if cached_fp.is_file() and not overwrite_cache_filtered_task: | ||
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features) | ||
if subjects_included.get(sp, None) is not None: | ||
subjects = list(set(subjects).intersection(subjects_included[sp])) | ||
|
@@ -181,8 +185,13 @@ def load_flat_rep( | |
df = pl.scan_parquet(fp) | ||
if task_df_name is not None: | ||
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects)) | ||
|
||
df = df.join(filter_join_df, on=join_keys, how="inner") | ||
df = df.cast({"subject_id": filter_join_df.select('subject_id').dtypes[0]}) | ||
df = filter_join_df.join_asof( | ||
df, | ||
by="subject_id", | ||
on="timestamp", | ||
strategy="forward" if "-" in window_size else "backward", | ||
) | ||
|
||
if do_cache_filtered_task: | ||
cached_fp.parent.mkdir(exist_ok=True, parents=True) | ||
|
@@ -195,7 +204,7 @@ def load_flat_rep( | |
|
||
window_dfs.append(df) | ||
|
||
dfs.append(pl.concat(window_dfs, how="vertical")) | ||
dfs.append(pl.concat(window_dfs, how="vertical_relaxed")) | ||
|
||
joined_df = dfs[0] | ||
for jdf in dfs[1:]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -705,7 +705,7 @@ def _update_subject_event_properties(self): | |
) | ||
|
||
n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() | ||
self.n_events_per_subject = n_events_pd.set_index("subject_id")["counts"].to_dict() | ||
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() | ||
self.subject_ids = set(self.n_events_per_subject.keys()) | ||
|
||
if self.subjects_df is not None: | ||
|
@@ -1105,7 +1105,7 @@ def _fit_vocabulary(self, measure: str, config: MeasurementConfig, source_df: DF | |
try: | ||
value_counts = observations.value_counts() | ||
vocab_elements = value_counts.get_column(measure).to_list() | ||
el_counts = value_counts.get_column("counts") | ||
el_counts = value_counts.get_column("count") | ||
return Vocabulary(vocabulary=vocab_elements, obs_frequencies=el_counts) | ||
except AssertionError as e: | ||
raise AssertionError(f"Failed to build vocabulary for {measure}") from e | ||
|
@@ -1417,7 +1417,10 @@ def _summarize_static_measurements( | |
if include_only_subjects is None: | ||
df = self.subjects_df | ||
else: | ||
df = self.subjects_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) | ||
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||
df = self.subjects_df.filter( | ||
pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) | ||
) | ||
Comment on lines
+1420
to
+1423
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure consistent data type handling for It's a good practice to maintain consistency in how data types are handled across different methods. Here, the - self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
- df = self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects]))
+ df = self._filter_subjects_by_id(include_only_subjects) And add a new method in the class: def _filter_subjects_by_id(self, include_only_subjects):
self.subjects_df = self.subjects_df.with_columns(pl.col("subject_id").cast(pl.Utf8))
return self.subjects_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) Also applies to: 1483-1484, 1547-1551 |
||
|
||
valid_measures = {} | ||
for feat_col in feature_columns: | ||
|
@@ -1477,7 +1480,8 @@ def _summarize_time_dependent_measurements( | |
if include_only_subjects is None: | ||
df = self.events_df | ||
else: | ||
df = self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))) | ||
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||
df = self.events_df.filter(pl.col("subject_id").is_in([str(id) for id in include_only_subjects])) | ||
|
||
valid_measures = {} | ||
for feat_col in feature_columns: | ||
|
@@ -1540,10 +1544,11 @@ def _summarize_dynamic_measurements( | |
if include_only_subjects is None: | ||
df = self.dynamic_measurements_df | ||
else: | ||
self.events_df = self.events_df.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||
df = self.dynamic_measurements_df.join( | ||
self.events_df.filter(pl.col("subject_id").is_in(list(include_only_subjects))).select( | ||
"event_id" | ||
), | ||
self.events_df.filter( | ||
pl.col("subject_id").is_in([str(id) for id in include_only_subjects]) | ||
).select("event_id"), | ||
on="event_id", | ||
how="inner", | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#!/usr/bin/env python | ||
"""Builds a flat representation dataset given a hydra config file.""" | ||
|
||
try: | ||
import stackprinter | ||
|
||
stackprinter.set_excepthook(style="darkbg2") | ||
except ImportError: | ||
pass # no need to fail because of missing dev dependency | ||
|
||
from pathlib import Path | ||
|
||
import hydra | ||
from omegaconf import DictConfig | ||
|
||
from EventStream.data.dataset_polars import Dataset | ||
|
||
|
||
@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") | ||
def main(cfg: DictConfig): | ||
cfg = hydra.utils.instantiate(cfg, _convert_="all") | ||
save_dir = Path(cfg.pop("save_dir")) | ||
window_sizes = cfg.pop("window_sizes") | ||
subjects_per_output_file = ( | ||
cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None | ||
) | ||
|
||
# Build flat reps for specified task and window sizes | ||
ESD = Dataset.load(save_dir) | ||
feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( | ||
feature_inclusion_frequency=None, include_only_measurements=None | ||
) | ||
cache_kwargs = dict( | ||
subjects_per_output_file=subjects_per_output_file, | ||
feature_inclusion_frequency=feature_inclusion_frequency, # 0.1 | ||
window_sizes=window_sizes, | ||
include_only_measurements=include_only_measurements, | ||
do_overwrite=False, | ||
do_update=True, | ||
) | ||
ESD.cache_flat_representation(**cache_kwargs) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ensure proper documentation for the new parameter.
The new parameter
overwrite_cache_filtered_task
should be included in the function's docstring to maintain comprehensive documentation.+ overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`, the cached file will be loaded if exists.
Committable suggestion