Test cases and (eventually) fixes for #114 #117

Open
wants to merge 24 commits into base: dev

Changes from 20 commits (24 commits total)

Commits
edcd2e9 - starting off (mmcdermott, Jun 22, 2024)
785ccc9 - Starting to add test contents; not yet working (mmcdermott, Jun 22, 2024)
fa6844e - Merge branch 'dev' into fix_114_typing_with_subject_IDs (mmcdermott, Jun 22, 2024)
cf67bcd - Added a more involved test for dataset construction. (mmcdermott, Jun 22, 2024)
0434ed4 - Added test case that (correctly) fails (mmcdermott, Jun 22, 2024)
98cafd5 - Added test case that (correctly) fails (mmcdermott, Jun 22, 2024)
48ba63d - Merge branch 'fix_114_typing_with_subject_IDs' of github.com:mmcdermo… (mmcdermott, Jun 22, 2024)
b134133 - Added test case that (correctly) fails (mmcdermott, Jun 22, 2024)
9acff54 - Removed the unnecessary pandas conversion in the getting number of ev… (mmcdermott, Jun 22, 2024)
1b4f0d8 - Things are partially improved, but other tests are still failing. Inv… (mmcdermott, Jun 22, 2024)
39ba674 - This may have fixed it. (mmcdermott, Jun 22, 2024)
ace03dd - Added minor changes from another branch and more debug logs (pargaw, Jun 22, 2024)
b41b014 - Added assertion to check _filter_col_inclusion (pargaw, Jun 22, 2024)
aa0fa66 - Some more corrections to subject ID typing. (mmcdermott, Jun 22, 2024)
413dda5 - Fixed pytorch dataset issue (maybe) (mmcdermott, Jun 22, 2024)
9d703b7 - Tests still having challenges with shards not overlapping. (mmcdermott, Jun 22, 2024)
7616598 - fixed broken pytorch dataset test given string conventions. (mmcdermott, Jun 23, 2024)
674b50f - Other changes to try to get things working (mmcdermott, Jun 23, 2024)
9f3ce52 - Added some more logging and dropping of nulls for safety (mmcdermott, Jun 23, 2024)
71ac9a8 - Added flat rep changes from another branch (pargaw, Jun 23, 2024)
a813468 - Added additional changes to fix flat_rep run from another branch (pargaw, Jun 23, 2024)
87a3874 - Set assertion error to warning log (pargaw, Jun 23, 2024)
5150e05 - Added functionality when do_update is False (pargaw, Jun 30, 2024)
532c3dd - Added debug statements for caching flat reps (pargaw, Jul 6, 2024)
27 changes: 20 additions & 7 deletions EventStream/baseline/FT_task_baseline.py
@@ -47,6 +47,7 @@ def load_flat_rep(
do_update_if_missing: bool = True,
task_df_name: str | None = None,
do_cache_filtered_task: bool = True,
overwrite_cache_filtered_task: bool = False,
subjects_included: dict[str, set[int]] | None = None,
) -> dict[str, pl.LazyFrame]:
"""Loads a set of flat representations from a passed dataset that satisfy the given constraints.
@@ -68,14 +69,16 @@ def load_flat_rep(
do_update_if_missing: If `True`, then if any window sizes or features are missing, the function will
try to update the stored flat representations to reflect these. If `False`, if information is
missing, it will raise a `FileNotFoundError` instead.
task_df_name: If specified, the flat representations loaded will be (inner) joined against the task
task_df_name: If specified, the flat representations loaded will be joined against the task
dataframe of this name on the columns ``"subject_id"`` and ``"end_time"`` (which will be renamed
to ``"timestamp"``). This is to avoid needing to load the full dataset in flattened form into
memory. This is also used as a cache key; if a pre-filtered dataset is written to disk at a
specified path for this task, then the data will be loaded from there, rather than from the base
dataset.
do_cache_filtered_task: If `True`, the flat representations will, after being filtered to just the
relevant rows for the task, be cached to disk for faster re-use.
overwrite_cache_filtered_task: If `True`, the flat representations will be regenerated. If `False`,
the cached file will be loaded if it exists.
subjects_included: A dictionary by split of the subjects to include in the task. Omitted splits are
used wholesale.

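For orientation, a minimal usage sketch of the updated keyword arguments (not part of the diff; the leading arguments ESD and window_sizes, the task name, the split key, and the paths are assumptions, since only the tail of the signature is shown above):

from pathlib import Path

from EventStream.baseline.FT_task_baseline import load_flat_rep
from EventStream.data.dataset_polars import Dataset

# Hypothetical dataset path, task name, and window sizes -- for illustration only.
ESD = Dataset.load(Path("./sample_data/processed/sample"))
flat_reps = load_flat_rep(
    ESD,
    window_sizes=["1d", "7d", "FULL"],
    task_df_name="readmission_30d",        # also serves as the cache key for the filtered rep
    do_cache_filtered_task=True,           # write the task-filtered rows to disk for re-use
    overwrite_cache_filtered_task=False,   # reuse an existing cache rather than regenerate it
)
train_df = flat_reps["train"].collect()    # returned values are polars LazyFrames keyed by split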
@@ -148,8 +151,9 @@ def load_flat_rep(

by_split = {}
for sp, all_sp_subjects in ESD.split_subjects.items():
all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype)
if task_df_name is not None:
sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects)))
sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects))

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
@@ -171,19 +175,28 @@
if task_df_name is not None:
fn = fp.parts[-1]
cached_fp = task_window_dir / fn
if cached_fp.is_file():
if cached_fp.is_file() and not overwrite_cache_filtered_task:
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)
window_dfs.append(df)
continue

df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
filter_join_df = sp_join_df.select(join_keys).filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)

df = df.join(filter_join_df, on=join_keys, how="inner")
df = filter_join_df.join_asof(
df,
by="subject_id",
on="timestamp",
strategy="forward" if "-" in window_size else "backward",
)

if do_cache_filtered_task:
cached_fp.parent.mkdir(exist_ok=True, parents=True)
@@ -193,7 +206,7 @@
df = df.select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)))

window_dfs.append(df)

9 changes: 3 additions & 6 deletions EventStream/data/dataset_base.py
@@ -1367,27 +1367,24 @@ def cache_deep_learning_representation(
NRT_dir = self.config.save_dir / "NRT_reps"

shards_fp = self.config.save_dir / "DL_shards.json"
if shards_fp.exists():
if shards_fp.exists() and not do_overwrite:
Review comment:
Consider using json.load() instead of json.loads() with read_text().

The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using json.load(), which streams the file content.

- shards = json.loads(shards_fp.read_text())
+ with open(shards_fp, 'r') as file:
+     shards = json.load(file)
shards = json.loads(shards_fp.read_text())
else:
shards = {}

if subjects_per_output_file is None:
subject_chunks = [self.subject_ids]
else:
subjects = np.random.permutation(list(self.subject_ids))
subjects = np.random.permutation(list(set(self.subject_ids)))
subject_chunks = np.array_split(
subjects,
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)

subject_chunks = [[int(x) for x in c] for c in subject_chunks]

for chunk_idx, subjects_list in enumerate(subject_chunks):
for split, subjects in self.split_subjects.items():
shard_key = f"{split}/{chunk_idx}"
included_subjects = set(subjects_list).intersection({int(x) for x in subjects})
shards[shard_key] = list(included_subjects)
shards[shard_key] = list(set(subjects_list).intersection(subjects))

shards_fp.write_text(json.dumps(shards))

47 changes: 41 additions & 6 deletions EventStream/data/dataset_polars.py
@@ -690,6 +690,7 @@ def _agg_by_time(self):
)

def _update_subject_event_properties(self):
self.subject_id_dtype = self.events_df.schema["subject_id"]
Review comment:

Ensure events_df is not None before accessing its schema.

The assignment of self.subject_id_dtype should be inside the conditional check to ensure events_df is not None before accessing its schema.

-        self.subject_id_dtype = self.events_df.schema["subject_id"]
-        if self.events_df is not None:
+        if self.events_df is not None:
+            self.subject_id_dtype = self.events_df.schema["subject_id"]

if self.events_df is not None:
logger.debug("Collecting event types")
self.event_types = (
@@ -699,15 +700,28 @@ def _update_subject_event_properties(self):
.to_list()
)

n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
logger.debug("Collecting subject event counts")
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
n_events = n_events.drop_nulls("subject_id")
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
self.n_events_per_subject = {
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
self.subject_ids = set(self.n_events_per_subject.keys())

if self.subjects_df is not None:
logger.debug("Collecting subject event counts")
subjects_with_no_events = (
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
subjects_df_subjects = (
self.subjects_df
.drop_nulls("subject_id")
.select(pl.col("subject_id").cast(pl.Utf8))
)
subjects_df_subj_ids = set(subjects_df_subjects["subject_id"].to_list())
subj_no_in_df = self.subject_ids - subjects_df_subj_ids
if len(subj_no_in_df) > 0:
logger.warning(f"Found {len(subj_no_in_df)} subjects not in subject df!")
subjects_with_no_events = subjects_df_subj_ids - self.subject_ids
for sid in subjects_with_no_events:
self.n_events_per_subject[sid] = 0
self.subject_ids.update(subjects_with_no_events)
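As an aside on the Utf8 cast a few lines above (not part of the diff): JSON object keys are always strings, so integer subject IDs come back as strings after any JSON round trip (e.g. DL_shards.json), and casting to Utf8 up front keeps the in-memory mapping consistent with that. A minimal illustration:

import json

# Integer dict keys are serialized as JSON strings, so they deserialize as str.
counts = {1234: 5, 5678: 2}
restored = json.loads(json.dumps(counts))
print(restored)  # {'1234': 5, '5678': 2}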
@@ -723,7 +737,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
filter_exprs.append(pl.col(col).is_null())
case _:
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
logger.debug(
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
f"{col} to {df.schema[col]}"
)
if isinstance(list(incl_targets)[0], str):
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])

incl_list = incl_list.cast(df.schema[col])

logger.debug(
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
)
Review comment on lines +740 to +753:
Improve type conversion and error handling in _filter_col_inclusion.

The conversion of inclusion targets to the appropriate data type can be simplified by using polars built-in functions, and the error message can be enhanced for better clarity.

-                        logger.debug(
-                            f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
-                            f"{col} to {df.schema[col]}"
-                        )
-                        if isinstance(list(incl_targets)[0], str):
-                            incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
-                        else:
-                            incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
-                        incl_list = incl_list.cast(df.schema[col])
-                        logger.debug(
-                            f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
-                        )
+                        try:
+                            incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+                        except TypeError as e:
+                            raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e

except TypeError as e:
incl_targets_by_type = defaultdict(list)
for t in incl_targets:
@@ -1358,6 +1385,11 @@ def build_DL_cached_representation(
# 1. Process subject data into the right format.
if subject_ids:
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
if len(subject_ids) != len(subjects_df):
raise ValueError(
f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion "
f"the size of subjects_df are {len(subjects_df)}"
)
else:
subjects_df = self.subjects_df

@@ -1369,6 +1401,7 @@
pl.col("index").alias("static_indices"),
)
)
logger.debug(f"Size of static_data: {static_data.shape[0]}")

# 2. Process event data into the right format.
if subject_ids:
@@ -1378,6 +1411,7 @@
events_df = self.events_df
event_ids = None
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures)
logger.debug(f"Size of event_data: {event_data.shape[0]}")

# 3. Process measurement data into the right base format:
if event_ids:
@@ -1392,6 +1426,7 @@

if do_sort_outputs:
dynamic_data = dynamic_data.sort("event_id", "measurement_id")
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}")

# 4. Join dynamic and event data.

13 changes: 11 additions & 2 deletions EventStream/data/pytorch_dataset.py
@@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:

subject_id, st, end = self.index[idx]

shard = self.subj_map[subject_id]
if str(subject_id) not in self.subj_map:
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"]

if len(self.subj_map) < 10:
err_str.append("Subject IDs in map:")
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items())

raise ValueError("\n".join(err_str))

shard = self.subj_map[str(subject_id)]
Review comment:
Correct the handling of subject ID.

The subject_id is being accessed as a string, which might not be consistent with other parts of the code where it could be treated as an integer or another type. This inconsistency can lead to errors or unexpected behavior.

-        shard = self.subj_map[str(subject_id)]
+        shard = self.subj_map[subject_id]  # Assuming subject_id is consistently used in its native type across the codebase.

subject_idx = self.subj_indices[subject_id]
static_row = self.static_dfs[shard][subject_idx].to_dict()

@@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:
}

if self.config.do_include_subject_id:
out["subject_id"] = subject_id
out["subject_id"] = static_row["subject_id"].item()

seq_len = end - st
if seq_len > self.max_seq_len:
2 changes: 1 addition & 1 deletion configs/dataset_base.yaml
@@ -19,7 +19,7 @@ center_and_scale: True

hydra:
job:
name: build_${cohort_name}
name: build_dataset
run:
dir: ${save_dir}/.logs
sweep:
2 changes: 1 addition & 1 deletion sample_data/dataset.yaml
@@ -10,7 +10,7 @@ subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw/"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: null
DL_chunk_size: 25

inputs:
subjects:
69 changes: 69 additions & 0 deletions sample_data/dataset_parquet.yaml
@@ -0,0 +1,69 @@
defaults:
- dataset_base
- _self_

# So that it can be run multiple times without issue.
do_overwrite: True

cohort_name: "sample"
subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw_parquet"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: 25

inputs:
subjects:
input_df: "${raw_data_dir}/subjects.parquet"
admissions:
input_df: "${raw_data_dir}/admit_vitals.parquet"
start_ts_col: "admit_date"
end_ts_col: "disch_date"
ts_format: "%m/%d/%Y, %H:%M:%S"
event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"]
vitals:
input_df: "${raw_data_dir}/admit_vitals.parquet"
ts_col: "vitals_date"
ts_format: "%m/%d/%Y, %H:%M:%S"
labs:
input_df: "${raw_data_dir}/labs.parquet"
ts_col: "timestamp"
ts_format: "%H:%M:%S-%Y-%m-%d"
medications:
input_df: "${raw_data_dir}/medications.parquet"
ts_col: "timestamp"
ts_format: "%H:%M:%S-%Y-%m-%d"
columns: {"name": "medication"}

measurements:
static:
single_label_classification:
subjects: ["eye_color"]
functional_time_dependent:
age:
functor: AgeFunctor
necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] }
kwargs: { dob_col: "dob" }
dynamic:
multi_label_classification:
admissions: ["department"]
medications:
- name: medication
modifiers:
- [dose, "float"]
- [frequency, "categorical"]
- [duration, "categorical"]
- [generic_name, "categorical"]
univariate_regression:
vitals: ["HR", "temp"]
multivariate_regression:
labs: [["lab_name", "lab_value"]]

outlier_detector_config:
stddev_cutoff: 1.5
min_valid_vocab_element_observations: 5
min_valid_column_observations: 5
min_true_float_frequency: 0.1
min_unique_numerical_observations: 20
min_events_per_subject: 3
agg_by_time_scale: "1h"
Binary file added sample_data/raw_parquet/admit_vitals.parquet
Binary file added sample_data/raw_parquet/labs.parquet
Binary file added sample_data/raw_parquet/medications.parquet
Binary file added sample_data/raw_parquet/subjects.parquet
45 changes: 45 additions & 0 deletions scripts/build_flat_reps.py
@@ -0,0 +1,45 @@
#!/usr/bin/env python
"""Builds a flat representation dataset given a hydra config file."""

try:
import stackprinter

stackprinter.set_excepthook(style="darkbg2")
except ImportError:
pass # no need to fail because of missing dev dependency

from pathlib import Path

import hydra
from omegaconf import DictConfig

from EventStream.data.dataset_polars import Dataset


@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base")
def main(cfg: DictConfig):
Review comment on lines +19 to +20:
Add documentation for the main function.

The new main function lacks a docstring. Adding a docstring would help other developers understand the purpose and usage of this function.

+    """
+    Main function to build flat representation datasets.
+    Args:
+        cfg (DictConfig): Configuration object containing parameters for dataset processing.
+    """

cfg = hydra.utils.instantiate(cfg, _convert_="all")
save_dir = Path(cfg.pop("save_dir"))
window_sizes = cfg.pop("window_sizes")
subjects_per_output_file = (
cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None
)

# Build flat reps for specified task and window sizes
ESD = Dataset.load(save_dir)
feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params(
feature_inclusion_frequency=None, include_only_measurements=None
)
cache_kwargs = dict(
subjects_per_output_file=subjects_per_output_file,
feature_inclusion_frequency=feature_inclusion_frequency, # 0.1
window_sizes=window_sizes,
include_only_measurements=include_only_measurements,
do_overwrite=False,
do_update=True,
)
ESD.cache_flat_representation(**cache_kwargs)


if __name__ == "__main__":
main()
Review comment:
Ensure proper error handling in the main function.

The function uses several pop operations on the configuration dictionary without checking if the keys exist, which could lead to KeyError if they are missing. Consider adding default values or error handling to prevent runtime errors.

-    save_dir = Path(cfg.pop("save_dir"))
-    window_sizes = cfg.pop("window_sizes")
+    save_dir = Path(cfg.get("save_dir", "default_directory"))
+    window_sizes = cfg.get("window_sizes", [])

2 changes: 1 addition & 1 deletion tests/data/test_pytorch_dataset.py
@@ -320,7 +320,7 @@ def setUp(self):

shards_fp = self.path / "DL_shards.json"
shards = {
f"{self.split}/0": list(set(DL_REP_DF["subject_id"].to_list())),
f"{self.split}/0": [str(x) for x in set(DL_REP_DF["subject_id"].to_list())],
}
shards_fp.write_text(json.dumps(shards))
