diff --git a/tractolearn/tractoio/dataset_fetch.py b/tractolearn/tractoio/dataset_fetch.py index 605e9b4..608b609 100644 --- a/tractolearn/tractoio/dataset_fetch.py +++ b/tractolearn/tractoio/dataset_fetch.py @@ -329,8 +329,7 @@ def fetcher(): return fetcher -fetch_bundle_label_config = ( - "fetch_bundle_label_config", +bundle_label_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["rbx_atlas_v10.json"], ["rbx_atlas_v10.json"], @@ -341,8 +340,7 @@ def fetcher(): False, ) -fetch_contrastive_ae_weights = ( - "fetch_contrastive_ae_weights", +contrastive_ae_weights = ( TRACTOLEARN_DATASETS_URL + "7562790/files/", ["best_model_contrastive_tractoinferno_hcp.pt"], ["best_model_contrastive_tractoinferno_hcp.pt"], @@ -353,8 +351,7 @@ def fetcher(): False, ) -fetch_mni2009cnonlinsymm_anat = ( - "fetch_mni2009cnonlinsymm_anat", +mni2009cnonlinsymm_anat = ( TRACTOLEARN_DATASETS_URL + "7562790/files/", ["mni_masked.nii.gz"], ["mni_masked.nii.gz"], @@ -365,8 +362,7 @@ def fetcher(): False, ) -fetch_generative_loa_cone_config = ( - "fetch_generative_loa_cone_config", +generative_loa_cone_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["degree.json"], ["degree.json"], @@ -377,8 +373,7 @@ def fetcher(): False, ) -fetch_generative_seed_streamline_ratio_config = ( - "fetch_generative_seed_streamline_ratio_config", +generative_seed_streamline_ratio_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["ratio.json"], ["ratio.json"], @@ -389,8 +384,7 @@ def fetcher(): False, ) -fetch_generative_streamline_max_count_config = ( - "fetch_generative_streamline_max_count_config", +generative_streamline_max_count_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["max_total_sampling.json"], ["max_total_sampling.json"], @@ -401,8 +395,7 @@ def fetcher(): False, ) -fetch_generative_streamline_req_count_config = ( - "fetch_generative_streamline_req_count_config", +generative_streamline_req_count_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["number_rejection_sampling.json"], ["number_rejection_sampling.json"], @@ -413,8 +406,7 @@ def fetcher(): False, ) -fetch_generative_wm_tisue_criterion_config = ( - "fetch_generative_wm_tisue_criterion_config", +generative_wm_tisue_criterion_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["white_matter_mask.json"], ["white_matter_mask.json"], @@ -425,8 +417,7 @@ def fetcher(): False, ) -fetch_recobundlesx_atlas = ( - "fetch_recobundlesx_atlas", +recobundlesx_atlas = ( TRACTOLEARN_DATASETS_URL + "7562635/files/", ["atlas.zip"], ["atlas.zip"], @@ -437,8 +428,7 @@ def fetcher(): True, ) -fetch_recobundlesx_config = ( - "fetch_recobundlesx_config", +recobundlesx_config = ( TRACTOLEARN_DATASETS_URL + "7562635/files/", ["config.zip"], ["config.zip"], @@ -449,8 +439,7 @@ def fetcher(): True, ) -fetch_tractoinferno_hcp_contrastive_threshold_config = ( - "fetch_tractoinferno_hcp_contrastive_threshold_config", +tractoinferno_hcp_contrastive_threshold_config = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["thresholds_contrastive_tractoinferno_hcp.json"], ["thresholds_contrastive_tractoinferno_hcp.json"], @@ -461,8 +450,7 @@ def fetcher(): False, ) -fetch_tractoinferno_hcp_ref_tractography = ( - "fetch_tractoinferno_hcp_ref_tractography", +tractoinferno_hcp_ref_tractography = ( TRACTOLEARN_DATASETS_URL + "/7562790/files/", ["data_tractoinferno_hcp_qbx.hdf5"], ["data_tractoinferno_hcp_qbx.hdf5"], @@ -474,43 +462,59 @@ def fetcher(): ) -def get_fetcher_method(name): - """Provide the fetcher method corresponding to the method name. +def _get_fetcher_data(name): + """Provide the fetcher method parameters corresponding to the method name. Returns ------- - callable - Fetcher method. + Tuple + Fetcher method parameters. """ if name == Dataset.BUNDLE_LABEL_CONFIG.name: - return fetch_bundle_label_config + return bundle_label_config elif name == Dataset.CONTRASTIVE_AUTOENCODER_WEIGHTS.name: - return fetch_contrastive_ae_weights + return contrastive_ae_weights elif name == Dataset.MNI2009CNONLINSYMM_ANAT.name: - return fetch_mni2009cnonlinsymm_anat + return mni2009cnonlinsymm_anat elif name == Dataset.GENERATIVE_LOA_CONE_CONFIG.name: - return fetch_generative_loa_cone_config + return generative_loa_cone_config elif name == Dataset.GENERATIVE_SEED_STRML_RATIO_CONFIG.name: - return fetch_generative_seed_streamline_ratio_config + return generative_seed_streamline_ratio_config elif name == Dataset.GENERATIVE_STRML_MAX_COUNT_CONFIG.name: - return fetch_generative_streamline_max_count_config + return generative_streamline_max_count_config elif name == Dataset.GENERATIVE_STRML_RQ_COUNT_CONFIG.name: - return fetch_generative_streamline_req_count_config + return generative_streamline_req_count_config elif name == Dataset.GENERATIVE_WM_TISSUE_CRITERION_CONFIG.name: - return fetch_generative_wm_tisue_criterion_config + return generative_wm_tisue_criterion_config elif name == Dataset.RECOBUNDLESX_ATLAS.name: - return fetch_recobundlesx_atlas + return recobundlesx_atlas elif name == Dataset.RECOBUNDLESX_CONFIG.name: - return fetch_recobundlesx_config + return recobundlesx_config elif name == Dataset.TRACTOINFERNO_HCP_CONTRASTIVE_THR_CONFIG.name: - return fetch_tractoinferno_hcp_contrastive_threshold_config + return tractoinferno_hcp_contrastive_threshold_config elif name == Dataset.TRACTOINFERNO_HCP_REF_TRACTOGRAPHY.name: - return fetch_tractoinferno_hcp_ref_tractography + return tractoinferno_hcp_ref_tractography else: raise DatasetError(_unknown_dataset_msg(name)) +def _compose_fetcher_name(name): + """Compose a name for the fetcher given the dataset name. + + Parameters + ---------- + name : string + Dataset name. + Returns + ------- + string + Fetcher name for dataset. + """ + + return "fetcher_" + Dataset[name].value + + def provide_dataset_description(): """Provide the description of the available datasets. @@ -525,7 +529,7 @@ def provide_dataset_description(): descr = list() for elem in list(Dataset): - params = get_fetcher_method(elem.name) + params = _get_fetcher_data(elem.name) descr.append( elem.value + ": " @@ -556,8 +560,9 @@ def retrieve_dataset(name, path): logger.info(f"\nDataset: {name}") - params = get_fetcher_method(name) - files, folder = _make_fetcher(path, *params)() + params = _get_fetcher_data(name) + fetcher_name = _compose_fetcher_name(name) + files, folder = _make_fetcher(path, fetcher_name, *params)() file_basename = list(files.keys())[0]