diff --git a/tests/utils/gdl_current_test.pth.tar b/tests/utils/gdl_current_test.pth.tar
new file mode 100644
index 00000000..298be2ea
Binary files /dev/null and b/tests/utils/gdl_current_test.pth.tar differ
diff --git a/tests/utils/gdl_pre20_test.pth.tar b/tests/utils/gdl_pre20_test.pth.tar
new file mode 100644
index 00000000..46e97dd2
Binary files /dev/null and b/tests/utils/gdl_pre20_test.pth.tar differ
diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py
index 2307da60..610718ab 100644
--- a/tests/utils/test_utils.py
+++ b/tests/utils/test_utils.py
@@ -1,7 +1,9 @@
 import pytest
 from torchgeo.datasets.utils import extract_archive
-from utils.utils import read_csv
+from models.model_choice import read_checkpoint
+from utils.utils import read_csv, is_inference_compatible, update_gdl_checkpoint
+
 
 
 class TestUtils(object):
     def test_wrong_seperation(self) -> None:
@@ -17,4 +19,30 @@ def test_with_header_in_csv(self) -> None:
         with pytest.raises(TypeError):
             data = read_csv("tests/tiling/header.csv")
         ##for row in data:
-        ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])
\ No newline at end of file
+        ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])
+
+    def test_is_current_config(self) -> None:
+        ckpt = "tests/utils/gdl_current_test.pth.tar"
+        ckpt_dict = read_checkpoint(ckpt, update=False)
+        assert is_inference_compatible(ckpt_dict)
+
+    def test_update_gdl_checkpoint(self) -> None:
+        ckpt = "tests/utils/gdl_pre20_test.pth.tar"
+        ckpt_dict = read_checkpoint(ckpt, update=False)
+        assert not is_inference_compatible(ckpt_dict)
+        ckpt_updated = update_gdl_checkpoint(ckpt_dict)
+        assert is_inference_compatible(ckpt_updated)
+
+        # grouped to emphasize the before/after result of updating
+        assert ckpt_dict['params']['global']['number_of_bands'] == 4
+        assert ckpt_updated['params']['dataset']['bands'] == ['red', 'green', 'blue', 'nir']
+
+        assert ckpt_dict['params']['global']['num_classes'] == 1
+        assert ckpt_updated['params']['dataset']['classes_dict'] == {'class1': 1}
+
+        means = [0.0950882, 0.13039997, 0.12815733, 0.25175254]
+        assert ckpt_dict['params']['training']['normalization']['mean'] == means
+        assert ckpt_updated['params']['augmentation']['normalization']['mean'] == means
+
+        assert ckpt_dict['params']['training']['augmentation']['clahe_enhance'] is True
+        assert ckpt_updated['params']['augmentation']['clahe_enhance_clip_limit'] == 0.1
diff --git a/utils/utils.py b/utils/utils.py
index 9411b701..e2f66171 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -522,23 +522,35 @@ def print_config(
     rich.print(tree, file=fp)
 
 
-def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
-    """
-    Utility to update model checkpoints from older versions of GDL to current version
-    @param checkpoint_params:
+def is_inference_compatible(cfg: Union[dict, DictConfig]) -> bool:
+    """Checks whether a configuration dictionary contains a config structure compatible with the current inference script"""
+    try:
+        # don't update if already a recent checkpoint:
+        # check that the major keys of the current config exist, especially those that have changed over time
+        cfg['params']['augmentation']
+        cfg['params']['dataset']['classes_dict']
+        cfg['params']['dataset']['bands']
+        cfg['params']['model']['_target_']
+
+        # model state dict
+        cfg['model_state_dict']
+        return True
+    except KeyError as e:
+        logging.debug(e)
+        return False
+
+
+def update_gdl_checkpoint(checkpoint: Union[dict, DictConfig]) -> Dict:
+    """
+    Utility to update model checkpoints from older versions of GDL to the current version.
+    NB: The purpose of this utility is ONLY to allow the use of "old" models in the current inference script.
+    Mostly inference-relevant parameters are updated.
+    @param checkpoint:
         Dictionary containing weights, optimizer state and saved configuration params from training
 
     @return:
     """
-    # covers gdl checkpoints from version <= 2.0.1
-    if 'model' in checkpoint_params.keys():
-        checkpoint_params['model_state_dict'] = checkpoint_params['model']
-        del checkpoint_params['model']
-    if 'optimizer' in checkpoint_params.keys():
-        checkpoint_params['optimizer_state_dict'] = checkpoint_params['optimizer']
-        del checkpoint_params['optimizer']
-    # covers gdl checkpoints pre-hydra (<=2.0.0)
-    bands = ['R', 'G', 'B', 'N']
+    bands = {'red': 'R', 'green': 'G', 'blue': 'B', 'nir': 'N'}
     old2new = {
         'manet_pretrained': {
             '_target_': 'segmentation_models_pytorch.MAnet', 'encoder_name': 'resnext50_32x4d',
@@ -567,16 +579,19 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
             'encoder_weights': 'imagenet'
         },
     }
-    try:
-        # don't update if already a recent checkpoint
-        get_key_def('classes_dict', checkpoint_params['params']['dataset'], expected_type=(dict, DictConfig))
-        get_key_def('modalities', checkpoint_params['params']['dataset'], expected_type=Sequence)
-        get_key_def('model', checkpoint_params['params'], expected_type=(dict, DictConfig))
-        return checkpoint_params
-    except KeyError:
-        num_classes_ckpt = get_key_def('num_classes', checkpoint_params['params']['global'], expected_type=int)
-        num_bands_ckpt = get_key_def('number_of_bands', checkpoint_params['params']['global'], expected_type=int)
-        model_name = get_key_def('model_name', checkpoint_params['params']['global'], expected_type=str)
+    if not is_inference_compatible(checkpoint):
+        # covers gdl checkpoints from version <= 2.0.1
+        if 'model' in checkpoint.keys():
+            checkpoint['model_state_dict'] = checkpoint['model']
+            del checkpoint['model']
+        try:
+            num_classes_ckpt = get_key_def('num_classes', checkpoint['params']['global'], expected_type=int)
+            num_bands_ckpt = get_key_def('number_of_bands', checkpoint['params']['global'], expected_type=int)
+            model_name = get_key_def('model_name', checkpoint['params']['global'], expected_type=str)
+        except KeyError as e:
+            logging.critical(f"\nCouldn't update checkpoint parameters"
+                             f"\nError {type(e)}: {e}")
+            raise e
         try:
             model_ckpt = old2new[model_name]
         except KeyError as e:
@@ -585,17 +600,39 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
                              f"\nError {type(e)}: {e}")
             raise e
-        # For GDL pre-v2.0.2
-        #bands_ckpt = ''
-        #bands_ckpt = bands_ckpt.join([bands[i] for i in range(num_bands_ckpt)])
-        checkpoint_params['params'].update({
+        # Move transformation/augmentation hyperparameters
+        if "augmentation" not in checkpoint["params"]:
+            checkpoint["params"]["augmentation"] = {
+                'normalization': {'mean': [], 'std': []},
+                'clahe_enhance_clip_limit': None
+            }
+        try:
+            means_ckpt = checkpoint['params']['training']['normalization']['mean']
+            stds_ckpt = checkpoint['params']['training']['normalization']['std']
+            scale_ckpt = checkpoint['params']['global']['scale_data']
+            # clahe_enhance was never officially added to GDL, so it will default to None if not present
+            clahe_enhance = get_key_def('clahe_enhance', checkpoint['params']['training']['augmentation'], default=None)
+        except KeyError as e:  # if KeyError on old keys, assume the checkpoint is already up to date
+            logging.debug(e)
+            return checkpoint
+
+        checkpoint["params"]["augmentation"]["normalization"]["mean"] = means_ckpt
+        checkpoint["params"]["augmentation"]["normalization"]["std"] = stds_ckpt
+        checkpoint["params"]["augmentation"]["scale_data"] = scale_ckpt
+        checkpoint["params"]["augmentation"]["clahe_enhance_clip_limit"] = 0.1 if clahe_enhance is True else None
+
+        checkpoint['params'].update({'model': model_ckpt})
+
+        checkpoint['params'].update({
             'dataset': {
-                'modalities': [bands[i] for i in range(num_bands_ckpt)], #bands_ckpt,
-                #"classes_dict": {f"BUIL": 1}
+                'bands': [list(bands.keys())[i] for i in range(num_bands_ckpt)],
                 "classes_dict": {f"class{i + 1}": i + 1 for i in range(num_classes_ckpt)}
+                # Some manual updates may be necessary when using old models, e.g.:
+                # 'bands': ['nir', 'red', 'green'],
+                # "classes_dict": {"FORE": 1},
             }
         })
-    checkpoint_params['params'].update({'model': model_ckpt})
-    return checkpoint_params
+
+    return checkpoint
 
 
 def map_wrapper(x):