diff --git a/pyproject.toml b/pyproject.toml index 7e5dea97..d4422163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "yucca" -version = "2.2.1" +version = "2.2.2" authors = [ { name="Sebastian Llambias", email="llambias@live.com" }, { name="Asbjørn Munk", email="9844416+asbjrnmunk@users.noreply.github.com" }, diff --git a/yucca/documentation/templates/functional_inference.py b/yucca/documentation/templates/functional_inference.py index 424d1822..d70cf3d4 100644 --- a/yucca/documentation/templates/functional_inference.py +++ b/yucca/documentation/templates/functional_inference.py @@ -62,7 +62,7 @@ pred_save_dir=save_path, pred_data_dir=target_data_path, overwrite_predictions=True, - image_extension=".nii.gz", + image_extension="pt", test_dataset_class=YuccaTestPreprocessedDataset, ) diff --git a/yucca/modules/data/datasets/YuccaDataset.py b/yucca/modules/data/datasets/YuccaDataset.py index 54ffcf49..746364cc 100644 --- a/yucca/modules/data/datasets/YuccaDataset.py +++ b/yucca/modules/data/datasets/YuccaDataset.py @@ -197,13 +197,13 @@ def __init__( preprocessed_data_dir: str, pred_save_dir: str, overwrite_predictions: bool = False, - suffix: str = None, # noqa U100 + suffix: str = "pt", # noqa U100 pred_include_cases: list = None, ): self.data_path = preprocessed_data_dir self.pred_save_dir = pred_save_dir self.overwrite = overwrite_predictions - self.data_suffix = ".pt" + self.data_suffix = "." + suffix self.prediction_suffix = ".nii.gz" self.pred_include_cases = pred_include_cases self.unique_cases = np.unique( @@ -234,7 +234,13 @@ def __getitem__(self, idx): # will convert them to a list of tuples of strings and a tuple of a string. # i.e. ['path1', 'path2'] -> [('path1',), ('path2',)] case_id = self.unique_cases[idx] - data = torch.load(os.path.join(self.data_path, case_id + self.data_suffix), weights_only=False) + path = os.path.join(self.data_path, case_id + self.data_suffix) + + if self.data_suffix == ".pt": + data = torch.load(path, weights_only=False) + elif self.data_suffix == ".npy": + data = torch.tensor(np.load(path, allow_pickle=True)) + data_properties = load_pickle(os.path.join(self.data_path, case_id + ".pkl")) return {"data": data, "data_properties": data_properties, "case_id": case_id}