-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Test cases and (eventually) fixes for #114 #117
base: dev
Are you sure you want to change the base?
Changes from 20 commits
edcd2e9
785ccc9
fa6844e
cf67bcd
0434ed4
98cafd5
48ba63d
b134133
9acff54
1b4f0d8
39ba674
ace03dd
b41b014
aa0fa66
413dda5
9d703b7
7616598
674b50f
9f3ce52
71ac9a8
a813468
87a3874
5150e05
532c3dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -690,6 +690,7 @@ def _agg_by_time(self): | |||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
def _update_subject_event_properties(self): | ||||||||||||||||||||||||||||||||||||||
self.subject_id_dtype = self.events_df.schema["subject_id"] | ||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure The assignment of - self.subject_id_dtype = self.events_df.schema["subject_id"]
- if self.events_df is not None:
+ if self.events_df is not None:
+ self.subject_id_dtype = self.events_df.schema["subject_id"] Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
if self.events_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting event types") | ||||||||||||||||||||||||||||||||||||||
self.event_types = ( | ||||||||||||||||||||||||||||||||||||||
|
@@ -699,15 +700,28 @@ def _update_subject_event_properties(self): | |||||||||||||||||||||||||||||||||||||
.to_list() | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas() | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict() | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting subject event counts") | ||||||||||||||||||||||||||||||||||||||
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count")) | ||||||||||||||||||||||||||||||||||||||
n_events = n_events.drop_nulls("subject_id") | ||||||||||||||||||||||||||||||||||||||
# here we cast to str to avoid issues with the subject_id column being various other types as we | ||||||||||||||||||||||||||||||||||||||
# will eventually JSON serialize it. | ||||||||||||||||||||||||||||||||||||||
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject = { | ||||||||||||||||||||||||||||||||||||||
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"]) | ||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||
self.subject_ids = set(self.n_events_per_subject.keys()) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if self.subjects_df is not None: | ||||||||||||||||||||||||||||||||||||||
logger.debug("Collecting subject event counts") | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = ( | ||||||||||||||||||||||||||||||||||||||
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
subjects_df_subjects = ( | ||||||||||||||||||||||||||||||||||||||
self.subjects_df | ||||||||||||||||||||||||||||||||||||||
.drop_nulls("subject_id") | ||||||||||||||||||||||||||||||||||||||
.select(pl.col("subject_id").cast(pl.Utf8)) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
subjects_df_subj_ids = set(subjects_df_subjects["subject_id"].to_list()) | ||||||||||||||||||||||||||||||||||||||
subj_no_in_df = self.subject_ids - subjects_df_subj_ids | ||||||||||||||||||||||||||||||||||||||
if len(subj_no_in_df) > 0: | ||||||||||||||||||||||||||||||||||||||
logger.warning(f"Found {len(subj_no_in_df)} subjects not in subject df!") | ||||||||||||||||||||||||||||||||||||||
subjects_with_no_events = subjects_df_subj_ids - self.subject_ids | ||||||||||||||||||||||||||||||||||||||
for sid in subjects_with_no_events: | ||||||||||||||||||||||||||||||||||||||
self.n_events_per_subject[sid] = 0 | ||||||||||||||||||||||||||||||||||||||
self.subject_ids.update(subjects_with_no_events) | ||||||||||||||||||||||||||||||||||||||
|
@@ -723,7 +737,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool | | |||||||||||||||||||||||||||||||||||||
filter_exprs.append(pl.col(col).is_null()) | ||||||||||||||||||||||||||||||||||||||
case _: | ||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for " | ||||||||||||||||||||||||||||||||||||||
f"{col} to {df.schema[col]}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
if isinstance(list(incl_targets)[0], str): | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
incl_list = incl_list.cast(df.schema[col]) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||||
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+740
to
+753
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve type conversion and error handling in The conversion of inclusion targets to the appropriate data type can be simplified by using - logger.debug(
- f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
- f"{col} to {df.schema[col]}"
- )
- if isinstance(list(incl_targets)[0], str):
- incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
- else:
- incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
- incl_list = incl_list.cast(df.schema[col])
- logger.debug(
- f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
- )
+ try:
+ incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+ except TypeError as e:
+ raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
except TypeError as e: | ||||||||||||||||||||||||||||||||||||||
incl_targets_by_type = defaultdict(list) | ||||||||||||||||||||||||||||||||||||||
for t in incl_targets: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1358,6 +1385,11 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
# 1. Process subject data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids}) | ||||||||||||||||||||||||||||||||||||||
if len(subject_ids) != len(subjects_df): | ||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||
f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion " | ||||||||||||||||||||||||||||||||||||||
f"the size of subjects_df are {len(subjects_df)}" | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
subjects_df = self.subjects_df | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -1369,6 +1401,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
pl.col("index").alias("static_indices"), | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of static_data: {static_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 2. Process event data into the right format. | ||||||||||||||||||||||||||||||||||||||
if subject_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1378,6 +1411,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
events_df = self.events_df | ||||||||||||||||||||||||||||||||||||||
event_ids = None | ||||||||||||||||||||||||||||||||||||||
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures) | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of event_data: {event_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 3. Process measurement data into the right base format: | ||||||||||||||||||||||||||||||||||||||
if event_ids: | ||||||||||||||||||||||||||||||||||||||
|
@@ -1392,6 +1426,7 @@ def build_DL_cached_representation( | |||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if do_sort_outputs: | ||||||||||||||||||||||||||||||||||||||
dynamic_data = dynamic_data.sort("event_id", "measurement_id") | ||||||||||||||||||||||||||||||||||||||
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}") | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# 4. Join dynamic and event data. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
|
||||||
subject_id, st, end = self.index[idx] | ||||||
|
||||||
shard = self.subj_map[subject_id] | ||||||
if str(subject_id) not in self.subj_map: | ||||||
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"] | ||||||
|
||||||
if len(self.subj_map) < 10: | ||||||
err_str.append("Subject IDs in map:") | ||||||
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items()) | ||||||
|
||||||
raise ValueError("\n".join(err_str)) | ||||||
|
||||||
shard = self.subj_map[str(subject_id)] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct the handling of subject ID. The - shard = self.subj_map[str(subject_id)]
+ shard = self.subj_map[subject_id] # Assuming subject_id is consistently used in its native type across the codebase. Committable suggestion
Suggested change
|
||||||
subject_idx = self.subj_indices[subject_id] | ||||||
static_row = self.static_dfs[shard][subject_idx].to_dict() | ||||||
|
||||||
|
@@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]: | |||||
} | ||||||
|
||||||
if self.config.do_include_subject_id: | ||||||
out["subject_id"] = subject_id | ||||||
out["subject_id"] = static_row["subject_id"].item() | ||||||
|
||||||
seq_len = end - st | ||||||
if seq_len > self.max_seq_len: | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
defaults: | ||
- dataset_base | ||
- _self_ | ||
|
||
# So that it can be run multiple times without issue. | ||
do_overwrite: True | ||
|
||
cohort_name: "sample" | ||
subject_id_col: "MRN" | ||
raw_data_dir: "./sample_data/raw_parquet" | ||
save_dir: "./sample_data/processed/${cohort_name}" | ||
|
||
DL_chunk_size: 25 | ||
|
||
inputs: | ||
subjects: | ||
input_df: "${raw_data_dir}/subjects.parquet" | ||
admissions: | ||
input_df: "${raw_data_dir}/admit_vitals.parquet" | ||
start_ts_col: "admit_date" | ||
end_ts_col: "disch_date" | ||
ts_format: "%m/%d/%Y, %H:%M:%S" | ||
event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"] | ||
vitals: | ||
input_df: "${raw_data_dir}/admit_vitals.parquet" | ||
ts_col: "vitals_date" | ||
ts_format: "%m/%d/%Y, %H:%M:%S" | ||
labs: | ||
input_df: "${raw_data_dir}/labs.parquet" | ||
ts_col: "timestamp" | ||
ts_format: "%H:%M:%S-%Y-%m-%d" | ||
medications: | ||
input_df: "${raw_data_dir}/medications.parquet" | ||
ts_col: "timestamp" | ||
ts_format: "%H:%M:%S-%Y-%m-%d" | ||
columns: {"name": "medication"} | ||
|
||
measurements: | ||
static: | ||
single_label_classification: | ||
subjects: ["eye_color"] | ||
functional_time_dependent: | ||
age: | ||
functor: AgeFunctor | ||
necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] } | ||
kwargs: { dob_col: "dob" } | ||
dynamic: | ||
multi_label_classification: | ||
admissions: ["department"] | ||
medications: | ||
- name: medication | ||
modifiers: | ||
- [dose, "float"] | ||
- [frequency, "categorical"] | ||
- [duration, "categorical"] | ||
- [generic_name, "categorical"] | ||
univariate_regression: | ||
vitals: ["HR", "temp"] | ||
multivariate_regression: | ||
labs: [["lab_name", "lab_value"]] | ||
|
||
outlier_detector_config: | ||
stddev_cutoff: 1.5 | ||
min_valid_vocab_element_observations: 5 | ||
min_valid_column_observations: 5 | ||
min_true_float_frequency: 0.1 | ||
min_unique_numerical_observations: 20 | ||
min_events_per_subject: 3 | ||
agg_by_time_scale: "1h" |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,45 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#!/usr/bin/env python | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"""Builds a flat representation dataset given a hydra config file.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import stackprinter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
stackprinter.set_excepthook(style="darkbg2") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
pass # no need to fail because of missing dev dependency | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import hydra | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from omegaconf import DictConfig | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from EventStream.data.dataset_polars import Dataset | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@hydra.main(version_base=None, config_path="../configs", config_name="dataset_base") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def main(cfg: DictConfig): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+19
to
+20
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add documentation for the main function. The new + """
+ Main function to build flat representation datasets.
+ Args:
+ cfg (DictConfig): Configuration object containing parameters for dataset processing.
+ """ Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cfg = hydra.utils.instantiate(cfg, _convert_="all") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
save_dir = Path(cfg.pop("save_dir")) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
window_sizes = cfg.pop("window_sizes") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
subjects_per_output_file = ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cfg.pop("subjects_per_output_file") if "subjects_per_output_file" in cfg else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Build flat reps for specified task and window sizes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ESD = Dataset.load(save_dir) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_inclusion_frequency, include_only_measurements = ESD._resolve_flat_rep_cache_params( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_inclusion_frequency=None, include_only_measurements=None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
cache_kwargs = dict( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
subjects_per_output_file=subjects_per_output_file, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
feature_inclusion_frequency=feature_inclusion_frequency, # 0.1 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
window_sizes=window_sizes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
include_only_measurements=include_only_measurements, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
do_overwrite=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
do_update=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ESD.cache_flat_representation(**cache_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
main() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure proper error handling in the main function. The function uses several - save_dir = Path(cfg.pop("save_dir"))
- window_sizes = cfg.pop("window_sizes")
+ save_dir = Path(cfg.get("save_dir", "default_directory"))
+ window_sizes = cfg.get("window_sizes", []) Committable suggestion
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
json.load()
instead ofjson.loads()
withread_text()
.The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using
json.load()
, which streams the file content.Committable suggestion