From 7647a0a590e1d403479cad4c91e952e87851dfbb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Asbj=C3=B8rn=20Munk?= <9844416+asbjrnmunk@users.noreply.github.com>
Date: Wed, 23 Oct 2024 17:11:12 +0200
Subject: [PATCH] Extend YuccaTestPreprocessedDataset to handle .npy files

The dataset previously hard-coded `.pt` files. `suffix` now defaults to
".pt" and is stored on the instance, and `__getitem__` dispatches on it:
`.pt` files go through `torch.load`, `.npy` files are read with `np.load`
and converted to tensors so both branches yield the same type, and any
other suffix raises a `ValueError` instead of leaving `data` unbound.
The `# noqa U100` on the `suffix` parameter is dropped since the argument
is now used.
---
 pyproject.toml                              |  2 +-
 yucca/modules/data/datasets/YuccaDataset.py | 14 +++++++++++---
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7e5dea97..d4422163 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "yucca"
-version = "2.2.1"
+version = "2.2.2"
 authors = [
     { name="Sebastian Llambias", email="llambias@live.com" },
     { name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" },
diff --git a/yucca/modules/data/datasets/YuccaDataset.py b/yucca/modules/data/datasets/YuccaDataset.py
index 54ffcf49..6d917f29 100644
--- a/yucca/modules/data/datasets/YuccaDataset.py
+++ b/yucca/modules/data/datasets/YuccaDataset.py
@@ -197,13 +197,13 @@ def __init__(
         preprocessed_data_dir: str,
         pred_save_dir: str,
         overwrite_predictions: bool = False,
-        suffix: str = None,  # noqa U100
+        suffix: str = ".pt",
         pred_include_cases: list = None,
     ):
         self.data_path = preprocessed_data_dir
         self.pred_save_dir = pred_save_dir
         self.overwrite = overwrite_predictions
-        self.data_suffix = ".pt"
+        self.suffix = suffix
         self.prediction_suffix = ".nii.gz"
         self.pred_include_cases = pred_include_cases
         self.unique_cases = np.unique(
@@ -234,6 +234,14 @@ def __getitem__(self, idx):
         # will convert them to a list of tuples of strings and a tuple of a string.
         # i.e. ['path1', 'path2'] -> [('path1',), ('path2',)]
         case_id = self.unique_cases[idx]
-        data = torch.load(os.path.join(self.data_path, case_id + self.data_suffix), weights_only=False)
+        path = os.path.join(self.data_path, case_id + self.suffix)
+
+        if self.suffix == ".pt":
+            data = torch.load(path, weights_only=False)
+        elif self.suffix == ".npy":
+            data = torch.tensor(np.load(path, allow_pickle=True))
+        else:
+            raise ValueError(f"Unsupported suffix {self.suffix!r}; expected '.pt' or '.npy'.")
+
         data_properties = load_pickle(os.path.join(self.data_path, case_id + ".pkl"))
         return {"data": data, "data_properties": data_properties, "case_id": case_id}