Commit ace03dd
Added minor changes from another branch and more debug logs
pargaw committed Jun 22, 2024
1 parent 39ba674 commit ace03dd
Showing 3 changed files with 8 additions and 3 deletions.
EventStream/data/dataset_base.py (1 addition, 1 deletion)
@@ -1367,7 +1367,7 @@ def cache_deep_learning_representation(
         NRT_dir = self.config.save_dir / "NRT_reps"

         shards_fp = self.config.save_dir / "DL_shards.json"
-        if shards_fp.exists():
+        if shards_fp.exists() and not do_overwrite:
             shards = json.loads(shards_fp.read_text())
         else:
             shards = {}
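
The new "and not do_overwrite" condition means the cached shard map in DL_shards.json is only reused when overwriting has not been requested; when overwriting is enabled, the shard map starts empty and is rebuilt. A minimal standalone sketch of that guard, assuming a boolean do_overwrite flag and a JSON shard file (the helper name load_shards is illustrative, not from the repository):

    import json
    from pathlib import Path

    def load_shards(shards_fp: Path, do_overwrite: bool) -> dict:
        # Reuse the cached shard map only if it exists and we are not overwriting.
        if shards_fp.exists() and not do_overwrite:
            return json.loads(shards_fp.read_text())
        # Otherwise start from an empty map so shards are rebuilt from scratch.
        return {}
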
EventStream/data/dataset_polars.py (6 additions, 1 deletion)
@@ -739,7 +739,7 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |

                 incl_list = incl_list.cast(df.schema[col])

-                logger.debug(f"Converted to Series of type {incl_list.dtype}")
+                logger.debug(f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}")
             except TypeError as e:
                 incl_targets_by_type = defaultdict(list)
                 for t in incl_targets:
@@ -1372,10 +1372,12 @@ def build_DL_cached_representation(
                 raise ValueError(f"Unknown temporality type {temporality} for {m}")

         # 1. Process subject data into the right format.
+        logger.debug(f'Size of subject_ids: {len(subject_ids)} and self.subjects_df: {self.subjects_df.shape[0]}')
         if subject_ids:
             subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
         else:
             subjects_df = self.subjects_df
+        logger.debug(f'Size of subjects_df after _filter_col_inclusion: {len(subjects_df)}')

         static_data = (
             self._melt_df(subjects_df, ["subject_id"], subject_measures)
@@ -1385,6 +1387,7 @@
                 pl.col("index").alias("static_indices"),
             )
         )
+        logger.debug(f'Size of static_data: {static_data.shape[0]}')

         # 2. Process event data into the right format.
         if subject_ids:
@@ -1394,6 +1397,7 @@
             events_df = self.events_df
             event_ids = None
         event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures)
+        logger.debug(f'Size of event_data: {event_data.shape[0]}')

         # 3. Process measurement data into the right base format:
         if event_ids:
@@ -1408,6 +1412,7 @@

         if do_sort_outputs:
             dynamic_data = dynamic_data.sort("event_id", "measurement_id")
+        logger.debug(f'Size of dynamic_data: {dynamic_data.shape[0]}')

         # 4. Join dynamic and event data.

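The added logger.debug calls report the row counts of the intermediate frames (subjects_df, static_data, event_data, dynamic_data), so unexpected shrinkage after filtering or melting shows up in the debug log. A small sketch of the same pattern on a toy polars frame (the frame, the subject_ids list, and the logging setup are illustrative only):

    import logging
    import polars as pl

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    subjects_df = pl.DataFrame({"subject_id": [1, 2, 3], "age": [40, 55, 63]})
    subject_ids = [1, 3]

    # Filter to the requested subjects, then log how many rows survived.
    filtered = subjects_df.filter(pl.col("subject_id").is_in(subject_ids))
    logger.debug(f"Size of subjects_df after filtering: {filtered.shape[0]}")
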
configs/dataset_base.yaml (1 addition, 1 deletion)
@@ -19,7 +19,7 @@ center_and_scale: True

 hydra:
   job:
-    name: build_${cohort_name}
+    name: build_dataset
   run:
     dir: ${save_dir}/.logs
   sweep:
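
Dropping the ${cohort_name} interpolation fixes the Hydra job name to build_dataset for every run, so the name no longer varies with the cohort being built. A minimal sketch of how that job name would surface at runtime, assuming an illustrative driver script that loads configs/dataset_base.yaml (the script is not part of the commit):

    import hydra
    from hydra.core.hydra_config import HydraConfig
    from omegaconf import DictConfig

    @hydra.main(version_base=None, config_path="configs", config_name="dataset_base")
    def main(cfg: DictConfig) -> None:
        # With the change above, this is always "build_dataset" rather than "build_<cohort_name>".
        print(HydraConfig.get().job.name)

    if __name__ == "__main__":
        main()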
