Skip to content

Commit

Permalink
Extend YuccaTestPreprocessedDataset to handle .npy files
Browse files Browse the repository at this point in the history
  • Loading branch information
asbjrnmunk committed Oct 23, 2024
1 parent 5959e84 commit 7647a0a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "yucca"
version = "2.2.1"
version = "2.2.2"
authors = [
{ name="Sebastian Llambias", email="[email protected]" },
{ name="Asbjørn Munk", email="[email protected]" },
Expand Down
12 changes: 9 additions & 3 deletions yucca/modules/data/datasets/YuccaDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,13 @@ def __init__(
preprocessed_data_dir: str,
pred_save_dir: str,
overwrite_predictions: bool = False,
suffix: str = None, # noqa U100
suffix: str = ".pt", # noqa U100
pred_include_cases: list = None,
):
self.data_path = preprocessed_data_dir
self.pred_save_dir = pred_save_dir
self.overwrite = overwrite_predictions
self.data_suffix = ".pt"
self.suffix = suffix
self.prediction_suffix = ".nii.gz"
self.pred_include_cases = pred_include_cases
self.unique_cases = np.unique(
Expand Down Expand Up @@ -234,7 +234,13 @@ def __getitem__(self, idx):
# will convert them to a list of tuples of strings and a tuple of a string.
# i.e. ['path1', 'path2'] -> [('path1',), ('path2',)]
case_id = self.unique_cases[idx]
data = torch.load(os.path.join(self.data_path, case_id + self.data_suffix), weights_only=False)
path = os.path.join(self.data_path, case_id + self.data_suffix)

if self.data_suffix == ".pt":
data = torch.load(path, weights_only=False)
elif self.data_suffix == ".npy":
data = torch.tensor(np.load(path, allow_pickle=True))

data_properties = load_pickle(os.path.join(self.data_path, case_id + ".pkl"))
return {"data": data, "data_properties": data_properties, "case_id": case_id}

Expand Down

0 comments on commit 7647a0a

Please sign in to comment.