Skip to content

Commit

Permalink
ENH: Refactor the dataset parameters tuples
Browse files Browse the repository at this point in the history
Refactor the dataset parameters tuples:
- The dataset tuples do not provide an effective fetcher, but only the
parameters needed to make the fetcher. Thus, the `fetcher_` prefix is
removed.
- Rename the method that provides the fetcher parameters accordingly.
- The fetcher names are all built in the same way, and thus this is put
into a method for the sake of best coding practices. Also, prefer naming
the fetchers directly using the Dataset enum values.
  • Loading branch information
jhlegarreta committed Feb 14, 2023
1 parent 28e166f commit cbef0bb
Showing 1 changed file with 48 additions and 43 deletions.
91 changes: 48 additions & 43 deletions tractolearn/tractoio/dataset_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,7 @@ def fetcher():
return fetcher


fetch_bundle_label_config = (
"fetch_bundle_label_config",
bundle_label_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["rbx_atlas_v10.json"],
["rbx_atlas_v10.json"],
Expand All @@ -334,8 +333,7 @@ def fetcher():
False,
)

fetch_contrastive_ae_weights = (
"fetch_contrastive_ae_weights",
contrastive_ae_weights = (
TRACTOLEARN_DATASETS_URL + "7562790/files/",
["best_model_contrastive_tractoinferno_hcp.pt"],
["best_model_contrastive_tractoinferno_hcp.pt"],
Expand All @@ -346,8 +344,7 @@ def fetcher():
False,
)

fetch_mni2009cnonlinsymm_anat = (
"fetch_mni2009cnonlinsymm_anat",
mni2009cnonlinsymm_anat = (
TRACTOLEARN_DATASETS_URL + "7562790/files/",
["mni_masked.nii.gz"],
["mni_masked.nii.gz"],
Expand All @@ -358,8 +355,7 @@ def fetcher():
False,
)

fetch_generative_loa_cone_config = (
"fetch_generative_loa_cone_config",
generative_loa_cone_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["degree.json"],
["degree.json"],
Expand All @@ -370,8 +366,7 @@ def fetcher():
False,
)

fetch_generative_seed_streamline_ratio_config = (
"fetch_generative_seed_streamline_ratio_config",
generative_seed_streamline_ratio_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["ratio.json"],
["ratio.json"],
Expand All @@ -382,8 +377,7 @@ def fetcher():
False,
)

fetch_generative_streamline_max_count_config = (
"fetch_generative_streamline_max_count_config",
generative_streamline_max_count_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["max_total_sampling.json"],
["max_total_sampling.json"],
Expand All @@ -394,8 +388,7 @@ def fetcher():
False,
)

fetch_generative_streamline_req_count_config = (
"fetch_generative_streamline_req_count_config",
generative_streamline_req_count_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["number_rejection_sampling.json"],
["number_rejection_sampling.json"],
Expand All @@ -406,8 +399,7 @@ def fetcher():
False,
)

fetch_generative_wm_tisue_criterion_config = (
"fetch_generative_wm_tisue_criterion_config",
generative_wm_tisue_criterion_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["white_matter_mask.json"],
["white_matter_mask.json"],
Expand All @@ -418,8 +410,7 @@ def fetcher():
False,
)

fetch_recobundlesx_atlas = (
"fetch_recobundlesx_atlas",
recobundlesx_atlas = (
TRACTOLEARN_DATASETS_URL + "7562635/files/",
["atlas.zip"],
["atlas.zip"],
Expand All @@ -430,8 +421,7 @@ def fetcher():
True,
)

fetch_recobundlesx_config = (
"fetch_recobundlesx_config",
recobundlesx_config = (
TRACTOLEARN_DATASETS_URL + "7562635/files/",
["config.zip"],
["config.zip"],
Expand All @@ -442,8 +432,7 @@ def fetcher():
True,
)

fetch_tractoinferno_hcp_contrastive_threshold_config = (
"fetch_tractoinferno_hcp_contrastive_threshold_config",
tractoinferno_hcp_contrastive_threshold_config = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["thresholds_contrastive_tractoinferno_hcp.json"],
["thresholds_contrastive_tractoinferno_hcp.json"],
Expand All @@ -454,8 +443,7 @@ def fetcher():
False,
)

fetch_tractoinferno_hcp_ref_tractography = (
"fetch_tractoinferno_hcp_ref_tractography",
tractoinferno_hcp_ref_tractography = (
TRACTOLEARN_DATASETS_URL + "/7562790/files/",
["data_tractoinferno_hcp_qbx.hdf5"],
["data_tractoinferno_hcp_qbx.hdf5"],
Expand All @@ -467,43 +455,59 @@ def fetcher():
)


def get_fetcher_method(name):
"""Provide the fetcher method corresponding to the method name.
def _get_fetcher_data(name):
"""Provide the fetcher method parameters corresponding to the method name.
Returns
-------
callable
Fetcher method.
Tuple
Fetcher method parameters.
"""

if name == Dataset.BUNDLE_LABEL_CONFIG.name:
return fetch_bundle_label_config
return bundle_label_config
elif name == Dataset.CONTRASTIVE_AUTOENCODER_WEIGHTS.name:
return fetch_contrastive_ae_weights
return contrastive_ae_weights
elif name == Dataset.MNI2009CNONLINSYMM_ANAT.name:
return fetch_mni2009cnonlinsymm_anat
return mni2009cnonlinsymm_anat
elif name == Dataset.GENERATIVE_LOA_CONE_CONFIG.name:
return fetch_generative_loa_cone_config
return generative_loa_cone_config
elif name == Dataset.GENERATIVE_SEED_STRML_RATIO_CONFIG.name:
return fetch_generative_seed_streamline_ratio_config
return generative_seed_streamline_ratio_config
elif name == Dataset.GENERATIVE_STRML_MAX_COUNT_CONFIG.name:
return fetch_generative_streamline_max_count_config
return generative_streamline_max_count_config
elif name == Dataset.GENERATIVE_STRML_RQ_COUNT_CONFIG.name:
return fetch_generative_streamline_req_count_config
return generative_streamline_req_count_config
elif name == Dataset.GENERATIVE_WM_TISSUE_CRITERION_CONFIG.name:
return fetch_generative_wm_tisue_criterion_config
return generative_wm_tisue_criterion_config
elif name == Dataset.RECOBUNDLESX_ATLAS.name:
return fetch_recobundlesx_atlas
return recobundlesx_atlas
elif name == Dataset.RECOBUNDLESX_CONFIG.name:
return fetch_recobundlesx_config
return recobundlesx_config
elif name == Dataset.TRACTOINFERNO_HCP_CONTRASTIVE_THR_CONFIG.name:
return fetch_tractoinferno_hcp_contrastive_threshold_config
return tractoinferno_hcp_contrastive_threshold_config
elif name == Dataset.TRACTOINFERNO_HCP_REF_TRACTOGRAPHY.name:
return fetch_tractoinferno_hcp_ref_tractography
return tractoinferno_hcp_ref_tractography
else:
raise DatasetError(_unknown_dataset_msg(name))


def _compose_fetcher_name(name):
"""Compose a name for the fetcher given the dataset name.
Parameters
----------
name : string
Dataset name.
Returns
-------
string
Fetcher name for dataset.
"""

return "fetcher_" + Dataset[name].value


def provide_dataset_description():
"""Provide the description of the available datasets.
Expand All @@ -518,7 +522,7 @@ def provide_dataset_description():
descr = list()

for elem in list(Dataset):
params = get_fetcher_method(elem.name)
params = _get_fetcher_data(elem.name)
descr.append(elem.value + ": " + params[descr_idx] + ": " + params[url_idx] + "\n")

return descr
Expand All @@ -542,8 +546,9 @@ def retrieve_dataset(name, path):

logger.info(f"\nDataset: {name}")

params = get_fetcher_method(name)
files, folder = _make_fetcher(path, *params)()
params = _get_fetcher_data(name)
fetcher_name = _compose_fetcher_name(name)
files, folder = _make_fetcher(path, fetcher_name, *params)()

file_basename = list(files.keys())[0]

Expand Down

0 comments on commit cbef0bb

Please sign in to comment.