Skip to content

Commit

Permalink
Update code to latest BiaPy version and prepare for releasing 1.1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
danifranco committed Nov 28, 2024
1 parent a052a24 commit 105258b
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 57 deletions.
58 changes: 41 additions & 17 deletions biapy/biapy_check_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,8 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"vit",
"mae",
"unext_v1",
], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1']"
"unext_v2",
], "MODEL.ARCHITECTURE not in ['unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'simple_cnn', 'efficientnet_b[0-7]', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2']"
if (
model_arch
not in [
Expand All @@ -1253,6 +1254,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"vit",
"mae",
"unext_v1",
"unext_v2",
]
and cfg.PROBLEM.NDIM == "3D"
and cfg.PROBLEM.TYPE != "CLASSIFICATION"
Expand All @@ -1271,6 +1273,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"vit",
"mae",
"unext_v1",
"unext_v2",
]
)
)
Expand All @@ -1288,10 +1291,11 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"multiresunet",
"unetr",
"unext_v1",
"unext_v2",
]
):
raise ValueError(
"'MODEL.N_CLASSES' > 2 can only be used with 'MODEL.ARCHITECTURE' in ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unetr', 'unext_v1']"
"'MODEL.N_CLASSES' > 2 can only be used with 'MODEL.ARCHITECTURE' in ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unetr', 'unext_v1', 'unext_v2']"
)

assert len(cfg.MODEL.FEATURE_MAPS) > 2, "'MODEL.FEATURE_MAPS' needs to have at least 3 values"
Expand Down Expand Up @@ -1372,10 +1376,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"resunet_se",
"unetr",
"multiresunet",
"unext_v1",
"unext_v1","unext_v2",
]:
raise ValueError(
"Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1']".format(
"Architectures available for {} are: ['unet', 'resunet', 'resunet++', 'seunet', 'attention_unet', 'resunet_se', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2']".format(
cfg.PROBLEM.TYPE
)
)
Expand All @@ -1393,9 +1397,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"attention_unet",
"multiresunet",
"unext_v1",
"unext_v2",
]:
raise ValueError(
"Architectures available for 2D 'SUPER_RESOLUTION' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1']"
"Architectures available for 2D 'SUPER_RESOLUTION' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2']"
)
elif cfg.PROBLEM.NDIM == "3D":
if model_arch not in [
Expand All @@ -1406,9 +1411,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"attention_unet",
"multiresunet",
"unext_v1",
"unext_v2",
]:
raise ValueError(
"Architectures available for 3D 'SUPER_RESOLUTION' are: ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1']"
"Architectures available for 3D 'SUPER_RESOLUTION' are: ['unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2']"
)
assert cfg.MODEL.UNET_SR_UPSAMPLE_POSITION in [
"pre",
Expand All @@ -1429,9 +1435,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"unetr",
"multiresunet",
"unext_v1",
"unext_v2",
]:
raise ValueError(
"Architectures available for 'IMAGE_TO_IMAGE' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'resunet_se', 'seunet', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1']"
"Architectures available for 'IMAGE_TO_IMAGE' are: ['edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'resunet_se', 'seunet', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2']"
)
elif cfg.PROBLEM.TYPE == "SELF_SUPERVISED":
if model_arch not in [
Expand All @@ -1444,6 +1451,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"resunet_se",
"unetr",
"unext_v1",
"unext_v2",
"edsr",
"rcan",
"dfcan",
Expand Down Expand Up @@ -1482,6 +1490,7 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"attention_unet",
"multiresunet",
"unext_v1",
"unext_v2",
]:
z_size = cfg.DATA.PATCH_SIZE[0]
sizes = cfg.DATA.PATCH_SIZE[1:-1]
Expand Down Expand Up @@ -1684,10 +1693,10 @@ def check_configuration(cfg, jobname, check_data_paths=True):
"'TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS' needs to be set when 'TEST.POST_PROCESSING.REMOVE_CLOSE_POINTS' is True"
)

def compare_configurations_without_model(actual_cfg, old_cfg, header_message=""):
def compare_configurations_without_model(actual_cfg, old_cfg, header_message="", old_cfg_version=None):
"""
Compares two configurations and throws an error if they differ in some critical variables that change workflow behaviour. This
comparisdon does not take into account model specs.
comparisdon does not take into account model specs.
"""
print("Comparing configurations . . .")

Expand All @@ -1699,21 +1708,33 @@ def compare_configurations_without_model(actual_cfg, old_cfg, header_message="")
"PROBLEM.SELF_SUPERVISED.PRETEXT_TASK",
"PROBLEM.SUPER_RESOLUTION.UPSCALING",
"MODEL.N_CLASSES",
]
]

def get_attribute_recursive(var, attr):
att = attr.split(".")
if len(att) == 1:
return getattr(var, att[0])
else:
return get_attribute_recursive(getattr(var, att[0]), ".".join(att[1:]))


# Old configuration translation
dim_count = 2 if old_cfg.PROBLEM.NDIM == "2D" else 3
# BiaPy version less than 3.5.5
if old_cfg_version is None:
if isinstance(old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"], int):
old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"] = (old_cfg["PROBLEM"]["SUPER_RESOLUTION"]["UPSCALING"],) * dim_count

for var_to_compare in vars_to_compare:
if get_attribute_recursive(actual_cfg, var_to_compare) != get_attribute_recursive(old_cfg, var_to_compare):
raise ValueError(header_message+f"The '{var_to_compare}' value of the compared configurations does not match")

raise ValueError(
header_message + f"The '{var_to_compare}' value of the compared configurations does not match: " +\
f"{get_attribute_recursive(actual_cfg, var_to_compare)} (current configuration) vs {get_attribute_recursive(old_cfg, var_to_compare)} (from loaded configuration)"
)

print("Configurations seem to be compatible. Continuing . . .")



def get_checkpoint_path(cfg, jobname):
"""Get the checkpoint file path"""
checkpoint_dir = Path(cfg.PATHS.CHECKPOINT)
Expand Down Expand Up @@ -1942,7 +1963,10 @@ def check_torchvision_available_models(workflow, ndim):
return models, model_restrictions_description, model_restrictions

def convert_old_model_cfg_to_current_version(old_cfg):
# https://github.com/BiaPyX/BiaPy/compare/6aa291baa9bc5d7fb410454bfcea3a3da0c23604...v3.5.5
"""
Backward compatibility until commit 6aa291baa9bc5d7fb410454bfcea3a3da0c23604 (version 3.2.0)
Commit url: https://github.com/BiaPyX/BiaPy/commit/6aa291baa9bc5d7fb410454bfcea3a3da0c23604
"""
if "TEST" in old_cfg:
if "STATS" in old_cfg["TEST"]:
full_image = old_cfg["TEST"]["STATS"]["FULL_IMG"]
Expand Down Expand Up @@ -2015,7 +2039,7 @@ def convert_old_model_cfg_to_current_version(old_cfg):
old_cfg["DATA"]["TRAIN"]["FILTER_SAMPLES"] = {}
old_cfg["DATA"]["TRAIN"]["FILTER_SAMPLES"]["PROPS"] = [['foreground']]
old_cfg["DATA"]["TRAIN"]["FILTER_SAMPLES"]["VALUES"] = [[min_fore]]
old_cfg["DATA"]["TRAIN"]["FILTER_SAMPLES"]["SIGN"] = [['lt']]
old_cfg["DATA"]["TRAIN"]["FILTER_SAMPLES"]["SIGNS"] = [['lt']]
if "VAL" in old_cfg["DATA"]:
if "BINARY_MASKS" in old_cfg["DATA"]["VAL"]:
del old_cfg["DATA"]["VAL"]["BINARY_MASKS"]
Expand Down Expand Up @@ -2098,4 +2122,4 @@ def convert_old_model_cfg_to_current_version(old_cfg):
except:
pass

return old_cfg
return old_cfg
Loading

0 comments on commit 105258b

Please sign in to comment.