Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test cases and (eventually) fixes for #114 #117

Open
wants to merge 24 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
edcd2e9
starting off
mmcdermott Jun 22, 2024
785ccc9
Starting to add test contents; not yet working
mmcdermott Jun 22, 2024
fa6844e
Merge branch 'dev' into fix_114_typing_with_subject_IDs
mmcdermott Jun 22, 2024
cf67bcd
Added a more involved test for dataset construction.
mmcdermott Jun 22, 2024
0434ed4
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
98cafd5
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
48ba63d
Merge branch 'fix_114_typing_with_subject_IDs' of github.com:mmcdermo…
mmcdermott Jun 22, 2024
b134133
Added test case that (correctly) fails
mmcdermott Jun 22, 2024
9acff54
Removed the unnecessary pandas conversion in the getting number of ev…
mmcdermott Jun 22, 2024
1b4f0d8
Things are partially improved, but other tests are still failing. Inv…
mmcdermott Jun 22, 2024
39ba674
This may have fixed it.
mmcdermott Jun 22, 2024
ace03dd
Added minor changes from another branch and more debug logs
pargaw Jun 22, 2024
b41b014
Added assertion to check _filter_col_inclusion
pargaw Jun 22, 2024
aa0fa66
Some more corrections to subject ID typing.
mmcdermott Jun 22, 2024
413dda5
Fixed pytorch dataset issue (maybe)
mmcdermott Jun 22, 2024
9d703b7
Tests still having challenges with shards not overlapping.
mmcdermott Jun 22, 2024
7616598
fixed broken pytorch dataset test given string conventions.
mmcdermott Jun 23, 2024
674b50f
Other changes to try to get things working
mmcdermott Jun 23, 2024
9f3ce52
Added some more logging and dropping of nulls for safety
mmcdermott Jun 23, 2024
71ac9a8
Added flat rep changes from another branch
pargaw Jun 23, 2024
a813468
Added additional changes to fix flat_rep run from another branch
pargaw Jun 23, 2024
87a3874
Set assertion error to warning log
pargaw Jun 23, 2024
5150e05
Added functionality when do_update is False
pargaw Jun 30, 2024
532c3dd
Added debug statements for caching flat reps
pargaw Jul 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions EventStream/baseline/FT_task_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,9 @@ def load_flat_rep(

by_split = {}
for sp, all_sp_subjects in ESD.split_subjects.items():
all_sp_subjects = pl.Series(list(all_sp_subjects)).cast(ESD.subject_id_dtype)
if task_df_name is not None:
sp_join_df = join_df.filter(pl.col("subject_id").is_in(list(all_sp_subjects)))
sp_join_df = join_df.filter(pl.col("subject_id").is_in(all_sp_subjects))

static_df = pl.scan_parquet(flat_dir / "static" / sp / "*.parquet")
if task_df_name is not None:
Expand All @@ -175,13 +176,17 @@ def load_flat_rep(
df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)
window_dfs.append(df)
continue

df = pl.scan_parquet(fp)
if task_df_name is not None:
filter_join_df = sp_join_df.select(join_keys).filter(pl.col("subject_id").is_in(subjects))
filter_join_df = sp_join_df.select(join_keys).filter(
pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype))
)

df = df.join(filter_join_df, on=join_keys, how="inner")

Expand All @@ -193,7 +198,7 @@ def load_flat_rep(
df = df.select("subject_id", "timestamp", *window_features)
if subjects_included.get(sp, None) is not None:
subjects = list(set(subjects).intersection(subjects_included[sp]))
df = df.filter(pl.col("subject_id").is_in(subjects))
df = df.filter(pl.col("subject_id").is_in(pl.Series(subjects).cast(ESD.subject_id_dtype)))

window_dfs.append(df)

Expand Down
9 changes: 3 additions & 6 deletions EventStream/data/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,27 +1367,24 @@ def cache_deep_learning_representation(
NRT_dir = self.config.save_dir / "NRT_reps"

shards_fp = self.config.save_dir / "DL_shards.json"
if shards_fp.exists():
if shards_fp.exists() and not do_overwrite:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using json.load() instead of json.loads() with read_text().

The current implementation reads the entire file content into memory and then parses it. This can be optimized by directly loading the JSON content using json.load(), which streams the file content.

- shards = json.loads(shards_fp.read_text())
+ with open(shards_fp, 'r') as file:
+     shards = json.load(file)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if shards_fp.exists() and not do_overwrite:
if shards_fp.exists() and not do_overwrite:
with open(shards_fp, 'r') as file:
shards = json.load(file)

shards = json.loads(shards_fp.read_text())
else:
shards = {}

if subjects_per_output_file is None:
subject_chunks = [self.subject_ids]
else:
subjects = np.random.permutation(list(self.subject_ids))
subjects = np.random.permutation(list(set(self.subject_ids)))
subject_chunks = np.array_split(
subjects,
np.arange(subjects_per_output_file, len(subjects), subjects_per_output_file),
)

subject_chunks = [[int(x) for x in c] for c in subject_chunks]

for chunk_idx, subjects_list in enumerate(subject_chunks):
for split, subjects in self.split_subjects.items():
shard_key = f"{split}/{chunk_idx}"
included_subjects = set(subjects_list).intersection({int(x) for x in subjects})
shards[shard_key] = list(included_subjects)
shards[shard_key] = list(set(subjects_list).intersection(subjects))

shards_fp.write_text(json.dumps(shards))

Expand Down
38 changes: 32 additions & 6 deletions EventStream/data/dataset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,7 @@ def _agg_by_time(self):
)

def _update_subject_event_properties(self):
self.subject_id_dtype = self.events_df.schema["subject_id"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensure events_df is not None before accessing its schema.

The assignment of self.subject_id_dtype should be inside the conditional check to ensure events_df is not None before accessing its schema.

-        self.subject_id_dtype = self.events_df.schema["subject_id"]
-        if self.events_df is not None:
+        if self.events_df is not None:
+            self.subject_id_dtype = self.events_df.schema["subject_id"]
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
self.subject_id_dtype = self.events_df.schema["subject_id"]
if self.events_df is not None:
self.subject_id_dtype = self.events_df.schema["subject_id"]

if self.events_df is not None:
logger.debug("Collecting event types")
self.event_types = (
Expand All @@ -699,15 +700,19 @@ def _update_subject_event_properties(self):
.to_list()
)

n_events_pd = self.events_df.get_column("subject_id").value_counts(sort=False).to_pandas()
self.n_events_per_subject = n_events_pd.set_index("subject_id")["count"].to_dict()
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
self.n_events_per_subject = {
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor subject event properties calculation for clarity and efficiency.

The calculation of n_events_per_subject can be streamlined by using direct dictionary construction from n_events dataframe, which avoids the need for zipping and iterating over the series.

-            n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
-            self.n_events_per_subject = {
-                subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
-            }
+            self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
n_events = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8))
self.n_events_per_subject = {
subject_id: count for subject_id, count in zip(n_events["subject_id"], n_events["count"])
}
n_events = self.events_df.group_by("subject_id").agg(pl.len().alias("count"))
# here we cast to str to avoid issues with the subject_id column being various other types as we
# will eventually JSON serialize it.
self.n_events_per_subject = n_events.with_columns(pl.col("subject_id").cast(pl.Utf8)).to_dict(as_series=False)

self.subject_ids = set(self.n_events_per_subject.keys())

if self.subjects_df is not None:
logger.debug("Collecting subject event counts")
subjects_with_no_events = (
set(self.subjects_df.get_column("subject_id").to_list()) - self.subject_ids
)
subjects_df_subjects = self.subjects_df.select(pl.col("subject_id").cast(pl.Utf8))
subjects_with_no_events = set(subjects_df_subjects["subject_id"].to_list()) - self.subject_ids
for sid in subjects_with_no_events:
self.n_events_per_subject[sid] = 0
self.subject_ids.update(subjects_with_no_events)
Expand All @@ -723,7 +728,20 @@ def _filter_col_inclusion(cls, df: DF_T, col_inclusion_targets: dict[str, bool |
filter_exprs.append(pl.col(col).is_null())
case _:
try:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
logger.debug(
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
f"{col} to {df.schema[col]}"
)
if isinstance(list(incl_targets)[0], str):
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])

incl_list = incl_list.cast(df.schema[col])

logger.debug(
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
)
Comment on lines +740 to +753
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve type conversion and error handling in _filter_col_inclusion.

The conversion of inclusion targets to the appropriate data type can be simplified by using polars built-in functions, and the error message can be enhanced for better clarity.

-                        logger.debug(
-                            f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
-                            f"{col} to {df.schema[col]}"
-                        )
-                        if isinstance(list(incl_targets)[0], str):
-                            incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
-                        else:
-                            incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
-                        incl_list = incl_list.cast(df.schema[col])
-                        logger.debug(
-                            f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
-                        )
+                        try:
+                            incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
+                        except TypeError as e:
+                            raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
logger.debug(
f"Converting inclusion targets of type {type(list(incl_targets)[0])} for "
f"{col} to {df.schema[col]}"
)
if isinstance(list(incl_targets)[0], str):
incl_list = pl.Series(list(incl_targets), dtype=pl.Utf8)
else:
incl_list = pl.Series(list(incl_targets), dtype=df.schema[col])
incl_list = incl_list.cast(df.schema[col])
logger.debug(
f"Converted to Series of type {incl_list.dtype} with size {len(incl_list)}"
)
try:
incl_list = pl.Series(list(incl_targets)).cast(df.schema[col])
except TypeError as e:
raise TypeError(f"Failed to cast inclusion targets to column '{col}' schema type.") from e

except TypeError as e:
incl_targets_by_type = defaultdict(list)
for t in incl_targets:
Expand Down Expand Up @@ -1358,6 +1376,11 @@ def build_DL_cached_representation(
# 1. Process subject data into the right format.
if subject_ids:
subjects_df = self._filter_col_inclusion(self.subjects_df, {"subject_id": subject_ids})
if len(subject_ids) != len(subjects_df):
raise ValueError(
f"Size of given subject_ids are {len(subject_ids)}, but after _filter_col_inclusion "
f"the size of subjects_df are {len(subjects_df)}"
)
else:
subjects_df = self.subjects_df

Expand All @@ -1369,6 +1392,7 @@ def build_DL_cached_representation(
pl.col("index").alias("static_indices"),
)
)
logger.debug(f"Size of static_data: {static_data.shape[0]}")

# 2. Process event data into the right format.
if subject_ids:
Expand All @@ -1378,6 +1402,7 @@ def build_DL_cached_representation(
events_df = self.events_df
event_ids = None
event_data = self._melt_df(events_df, ["subject_id", "timestamp", "event_id"], event_measures)
logger.debug(f"Size of event_data: {event_data.shape[0]}")

# 3. Process measurement data into the right base format:
if event_ids:
Expand All @@ -1392,6 +1417,7 @@ def build_DL_cached_representation(

if do_sort_outputs:
dynamic_data = dynamic_data.sort("event_id", "measurement_id")
logger.debug(f"Size of dynamic_data: {dynamic_data.shape[0]}")

# 4. Join dynamic and event data.

Expand Down
13 changes: 11 additions & 2 deletions EventStream/data/pytorch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,16 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:

subject_id, st, end = self.index[idx]

shard = self.subj_map[subject_id]
if str(subject_id) not in self.subj_map:
err_str = [f"Subject {subject_id} ({type(subject_id)} -- as str) not found in the shard map!"]

if len(self.subj_map) < 10:
err_str.append("Subject IDs in map:")
err_str.extend(f" * {k} ({type(k)}): {v}" for k, v in self.subj_map.items())

raise ValueError("\n".join(err_str))

shard = self.subj_map[str(subject_id)]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct the handling of subject ID.

The subject_id is being accessed as a string, which might not be consistent with other parts of the code where it could be treated as an integer or another type. This inconsistency can lead to errors or unexpected behavior.

-        shard = self.subj_map[str(subject_id)]
+        shard = self.subj_map[subject_id]  # Assuming subject_id is consistently used in its native type across the codebase.
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
shard = self.subj_map[str(subject_id)]
shard = self.subj_map[subject_id] # Assuming subject_id is consistently used in its native type across the codebase.

subject_idx = self.subj_indices[subject_id]
static_row = self.static_dfs[shard][subject_idx].to_dict()

Expand All @@ -471,7 +480,7 @@ def _seeded_getitem(self, idx: int) -> dict[str, list[float]]:
}

if self.config.do_include_subject_id:
out["subject_id"] = subject_id
out["subject_id"] = static_row["subject_id"].item()

seq_len = end - st
if seq_len > self.max_seq_len:
Expand Down
2 changes: 1 addition & 1 deletion configs/dataset_base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ center_and_scale: True

hydra:
job:
name: build_${cohort_name}
name: build_dataset
run:
dir: ${save_dir}/.logs
sweep:
Expand Down
2 changes: 1 addition & 1 deletion sample_data/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw/"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: null
DL_chunk_size: 25

inputs:
subjects:
Expand Down
69 changes: 69 additions & 0 deletions sample_data/dataset_parquet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
defaults:
- dataset_base
- _self_

# So that it can be run multiple times without issue.
do_overwrite: True

cohort_name: "sample"
subject_id_col: "MRN"
raw_data_dir: "./sample_data/raw_parquet"
save_dir: "./sample_data/processed/${cohort_name}"

DL_chunk_size: 25

inputs:
subjects:
input_df: "${raw_data_dir}/subjects.parquet"
admissions:
input_df: "${raw_data_dir}/admit_vitals.parquet"
start_ts_col: "admit_date"
end_ts_col: "disch_date"
ts_format: "%m/%d/%Y, %H:%M:%S"
event_type: ["OUTPATIENT_VISIT", "ADMISSION", "DISCHARGE"]
vitals:
input_df: "${raw_data_dir}/admit_vitals.parquet"
ts_col: "vitals_date"
ts_format: "%m/%d/%Y, %H:%M:%S"
labs:
input_df: "${raw_data_dir}/labs.parquet"
ts_col: "timestamp"
ts_format: "%H:%M:%S-%Y-%m-%d"
medications:
input_df: "${raw_data_dir}/medications.parquet"
ts_col: "timestamp"
ts_format: "%H:%M:%S-%Y-%m-%d"
columns: {"name": "medication"}

measurements:
static:
single_label_classification:
subjects: ["eye_color"]
functional_time_dependent:
age:
functor: AgeFunctor
necessary_static_measurements: { "dob": ["timestamp", "%m/%d/%Y"] }
kwargs: { dob_col: "dob" }
dynamic:
multi_label_classification:
admissions: ["department"]
medications:
- name: medication
modifiers:
- [dose, "float"]
- [frequency, "categorical"]
- [duration, "categorical"]
- [generic_name, "categorical"]
univariate_regression:
vitals: ["HR", "temp"]
multivariate_regression:
labs: [["lab_name", "lab_value"]]

outlier_detector_config:
stddev_cutoff: 1.5
min_valid_vocab_element_observations: 5
min_valid_column_observations: 5
min_true_float_frequency: 0.1
min_unique_numerical_observations: 20
min_events_per_subject: 3
agg_by_time_scale: "1h"
Binary file added sample_data/raw_parquet/admit_vitals.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/labs.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/medications.parquet
Binary file not shown.
Binary file added sample_data/raw_parquet/subjects.parquet
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/data/test_pytorch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def setUp(self):

shards_fp = self.path / "DL_shards.json"
shards = {
f"{self.split}/0": list(set(DL_REP_DF["subject_id"].to_list())),
f"{self.split}/0": [str(x) for x in set(DL_REP_DF["subject_id"].to_list())],
}
shards_fp.write_text(json.dumps(shards))

Expand Down
76 changes: 76 additions & 0 deletions tests/test_e2e_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

import json
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reorder module-level imports to the top of the file.

To comply with PEP 8, module-level imports should be at the top of the file.

+import json

Committable suggestion was skipped due to low confidence.

Tools
Ruff

5-5: Module level import not at top of file (E402)

import os
import subprocess
import unittest
Expand All @@ -10,6 +11,8 @@
from tempfile import TemporaryDirectory
from typing import Any

import polars as pl
mmcdermott marked this conversation as resolved.
Show resolved Hide resolved

from tests.utils import MLTypeEqualityCheckableMixin


Expand All @@ -32,6 +35,7 @@ def setUp(self):
self.paths = {}
for n in (
"dataset",
"dataset_from_parquet",
"esds",
"pretraining/CI",
"pretraining/NA",
Expand All @@ -45,6 +49,67 @@ def tearDown(self):
for o in self.dir_objs.values():
o.cleanup()

def _test_dataset_output(self, raw_data_root: Path, dataset_save_dir: Path):
DL_save_dir = dataset_save_dir / "DL_reps"

train_files = list((DL_save_dir / "train").glob("*.parquet"))
tuning_files = list((DL_save_dir / "tuning").glob("*.parquet"))
held_out_files = list((DL_save_dir / "held_out").glob("*.parquet"))

assert len(set(train_files) & set(tuning_files)) == 0
assert len(set(train_files) & set(held_out_files)) == 0
assert len(set(tuning_files) & set(held_out_files)) == 0

self.assertTrue(len(train_files) > 0)
self.assertTrue(len(tuning_files) > 0)
self.assertTrue(len(held_out_files) > 0)

train_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in train_files])
tuning_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in tuning_files])
held_out_DL_reps = pl.concat([pl.read_parquet(f, use_pyarrow=False) for f in held_out_files])

DL_shards = json.loads((dataset_save_dir / "DL_shards.json").read_text())

ESD_subjects = pl.read_parquet(dataset_save_dir / "subjects_df.parquet", use_pyarrow=False)

# Check that the DL shards are correctly partitioned.
all_subjects = set(ESD_subjects["subject_id"].unique().to_list())

self.assertEqual(len(all_subjects), len(ESD_subjects))

all_subj_in_DL_shards = set().union(*DL_shards.values())

all_subj_in_DL_shards = set(
pl.Series(list(all_subj_in_DL_shards)).cast(ESD_subjects["subject_id"].dtype).to_list()
)

self.assertEqual(all_subjects, all_subj_in_DL_shards)

all_train_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("train")))
all_tuning_DL_shard_subj = set().union(*(v for k, v in DL_shards.items() if k.startswith("tuning")))
all_held_out_DL_shard_subj = set().union(
*(v for k, v in DL_shards.items() if k.startswith("held_out"))
)

self.assertEqual(len(all_train_DL_shard_subj & all_tuning_DL_shard_subj), 0)
self.assertEqual(len(all_train_DL_shard_subj & all_held_out_DL_shard_subj), 0)
self.assertEqual(len(all_tuning_DL_shard_subj & all_held_out_DL_shard_subj), 0)

train_DL_subjects = set(train_DL_reps["subject_id"].to_list())
tuning_DL_subjects = set(tuning_DL_reps["subject_id"].to_list())
held_out_DL_subjects = set(held_out_DL_reps["subject_id"].to_list())

self.assertEqual(all_train_DL_shard_subj, {str(x) for x in train_DL_subjects})
self.assertEqual(all_tuning_DL_shard_subj, {str(x) for x in tuning_DL_subjects})
self.assertEqual(all_held_out_DL_shard_subj, {str(x) for x in held_out_DL_subjects})

self.assertTrue(len(train_DL_subjects) > len(tuning_DL_subjects))
self.assertTrue(len(train_DL_subjects) > len(held_out_DL_subjects))

all_DL_subjects = train_DL_subjects | tuning_DL_subjects | held_out_DL_subjects

self.assertEqual(all_DL_subjects, all_subjects)

def _test_command(self, command_parts: list[str], case_name: str, use_subtest: bool = True):
if use_subtest:
with self.subTest(case_name):
Expand All @@ -71,6 +136,17 @@ def build_dataset(self):
f"save_dir={self.paths['dataset']}",
]
self._test_command(command_parts, "Build Dataset", use_subtest=False)
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset"])

command_parts = [
"./scripts/build_dataset.py",
f"--config-path='{(root / 'sample_data').resolve()}'",
"--config-name=dataset_parquet",
'"hydra.searchpath=[./configs]"',
f"save_dir={self.paths['dataset_from_parquet']}",
]
self._test_command(command_parts, "Build Dataset from Parquet", use_subtest=False)
self._test_dataset_output((root / "sample_data" / "raw"), self.paths["dataset_from_parquet"])

def build_ESDS_dataset(self):
command_parts = [
Expand Down
Loading