diff --git a/openfl-workspace/tf_2dunet/src/nii_reader.py b/openfl-workspace/tf_2dunet/src/nii_reader.py index ba90a644b1..a9c91b5909 100644 --- a/openfl-workspace/tf_2dunet/src/nii_reader.py +++ b/openfl-workspace/tf_2dunet/src/nii_reader.py @@ -212,24 +212,29 @@ def nii_reader(brain_path, task, channels_last=True, # check that all appropriate files are present file_root = brain_path.split('/')[-1] + '_' - extension = '.nii.gz' + # check for all possible extensions + extensions = ['.nii.gz', '.nii'] # record files needed # needed mask files are currntly independent of task - need_files_oneof = list_files(file_root, extension, msk_names) - if normalization != 'modes_together': - need_files_all = list_files(file_root, extension, task_to_img_modes[task]) - else: - need_files_all = list_files(file_root, extension, img_modes) - - correct_files = np.all([ - (reqd in files) - for reqd in need_files_all - ]) and np.sum([ - (reqd in files) - for reqd in need_files_oneof - ]) == 1 - + need_files_oneof = None + need_files_all = None + for extension in extensions: + need_files_oneof = list_files(file_root, extension, msk_names) + if normalization != 'modes_together': + need_files_all = list_files(file_root, extension, task_to_img_modes[task]) + else: + need_files_all = list_files(file_root, extension, img_modes) + + correct_files = np.all([ + (reqd in files) + for reqd in need_files_all + ]) and np.sum([ + (reqd in files) + for reqd in need_files_oneof + ]) == 1 + if correct_files: + break if not correct_files: return None, None