Skip to content

Commit

Permalink
test_utils.py: add tests for updating utility (#357)
Browse files Browse the repository at this point in the history
utils.py: apply update to more params
  • Loading branch information
remtav authored Oct 5, 2022
1 parent d17d840 commit cdea293
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 32 deletions.
Binary file added tests/utils/gdl_current_test.pth.tar
Binary file not shown.
Binary file added tests/utils/gdl_pre20_test.pth.tar
Binary file not shown.
32 changes: 30 additions & 2 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
from torchgeo.datasets.utils import extract_archive

from utils.utils import read_csv
from models.model_choice import read_checkpoint
from utils.utils import read_csv, is_inference_compatible, update_gdl_checkpoint


class TestUtils(object):
def test_wrong_seperation(self) -> None:
Expand All @@ -17,4 +19,30 @@ def test_with_header_in_csv(self) -> None:
with pytest.raises(TypeError):
data = read_csv("tests/tiling/header.csv")
##for row in data:
##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])
##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])

def test_is_current_config(self) -> None:
    """A checkpoint written by the current GDL version is recognized as inference-compatible."""
    checkpoint_path = "tests/utils/gdl_current_test.pth.tar"
    # update=False: load the raw dict as saved, without migrating its structure
    raw_checkpoint = read_checkpoint(checkpoint_path, update=False)
    assert is_inference_compatible(raw_checkpoint)

def test_update_gdl_checkpoint(self) -> None:
    """A pre-2.0 checkpoint is flagged as incompatible, then migrated to the current structure."""
    legacy_path = "tests/utils/gdl_pre20_test.pth.tar"
    legacy_ckpt = read_checkpoint(legacy_path, update=False)
    assert not is_inference_compatible(legacy_ckpt)
    migrated = update_gdl_checkpoint(legacy_ckpt)
    assert is_inference_compatible(migrated)

    # Each pair of assertions below contrasts a legacy key with its
    # migrated counterpart, to put emphasis on the before/after result.
    assert legacy_ckpt['params']['global']['number_of_bands'] == 4
    assert migrated['params']['dataset']['bands'] == ['red', 'green', 'blue', 'nir']

    assert legacy_ckpt['params']['global']['num_classes'] == 1
    assert migrated['params']['dataset']['classes_dict'] == {'class1': 1}

    expected_means = [0.0950882, 0.13039997, 0.12815733, 0.25175254]
    assert legacy_ckpt['params']['training']['normalization']['mean'] == expected_means
    assert migrated['params']['augmentation']['normalization']['mean'] == expected_means

    assert legacy_ckpt['params']['training']['augmentation']['clahe_enhance'] is True
    assert migrated['params']['augmentation']['clahe_enhance_clip_limit'] == 0.1
97 changes: 67 additions & 30 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,23 +522,35 @@ def print_config(
rich.print(tree, file=fp)


def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
"""
Utility to update model checkpoints from older versions of GDL to current version
@param checkpoint_params:
def is_inference_compatible(cfg: Union[dict, DictConfig]):
    """Checks whether a configuration dictionary contains a config structure compatible with current inference script"""
    try:
        # Probe the major sections whose presence distinguishes a
        # current-format checkpoint from a legacy one (these are the
        # keys that have changed over time).
        params = cfg['params']
        params['augmentation']
        dataset = params['dataset']
        dataset['classes_dict']
        dataset['bands']
        params['model']['_target_']

        # Model weights must be stored under the modern key name.
        cfg['model_state_dict']
    except KeyError as missing_key:
        logging.debug(missing_key)
        return False
    return True


def update_gdl_checkpoint(checkpoint: Union[dict, DictConfig]) -> Dict:
"""
Utility to update model checkpoints from older versions of GDL to current version.
NB: The purpose of this utility is ONLY to allow the use of an "old" model in the current inference script.
Mostly, inference-relevant parameters are updated.
@param checkpoint:
Dictionary containing weights, optimizer state and saved configuration params from training
@return:
"""
# covers gdl checkpoints from version <= 2.0.1
if 'model' in checkpoint_params.keys():
checkpoint_params['model_state_dict'] = checkpoint_params['model']
del checkpoint_params['model']
if 'optimizer' in checkpoint_params.keys():
checkpoint_params['optimizer_state_dict'] = checkpoint_params['optimizer']
del checkpoint_params['optimizer']

# covers gdl checkpoints pre-hydra (<=2.0.0)
bands = ['R', 'G', 'B', 'N']
bands = {'red': 'R', 'green': 'G', 'blue': 'B', 'nir': 'N'}
old2new = {
'manet_pretrained': {
'_target_': 'segmentation_models_pytorch.MAnet', 'encoder_name': 'resnext50_32x4d',
Expand Down Expand Up @@ -567,16 +579,19 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
'encoder_weights': 'imagenet'
},
}
try:
# don't update if already a recent checkpoint
get_key_def('classes_dict', checkpoint_params['params']['dataset'], expected_type=(dict, DictConfig))
get_key_def('modalities', checkpoint_params['params']['dataset'], expected_type=Sequence)
get_key_def('model', checkpoint_params['params'], expected_type=(dict, DictConfig))
return checkpoint_params
except KeyError:
num_classes_ckpt = get_key_def('num_classes', checkpoint_params['params']['global'], expected_type=int)
num_bands_ckpt = get_key_def('number_of_bands', checkpoint_params['params']['global'], expected_type=int)
model_name = get_key_def('model_name', checkpoint_params['params']['global'], expected_type=str)
if not is_inference_compatible(checkpoint):
# covers gdl checkpoints from version <= 2.0.1
if 'model' in checkpoint.keys():
checkpoint['model_state_dict'] = checkpoint['model']
del checkpoint['model']
try:
num_classes_ckpt = get_key_def('num_classes', checkpoint['params']['global'], expected_type=int)
num_bands_ckpt = get_key_def('number_of_bands', checkpoint['params']['global'], expected_type=int)
model_name = get_key_def('model_name', checkpoint['params']['global'], expected_type=str)
except KeyError as e:
logging.critical(f"\nCouldn't update checkpoint parameters"
f"\nError {type(e)}: {e}")
raise e
try:
model_ckpt = old2new[model_name]
except KeyError as e:
Expand All @@ -585,17 +600,39 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict:
f"\nError {type(e)}: {e}")
raise e
# For GDL pre-v2.0.2
#bands_ckpt = ''
#bands_ckpt = bands_ckpt.join([bands[i] for i in range(num_bands_ckpt)])
checkpoint_params['params'].update({
# Move transformation/augmentations hyperparameters
if not "augmentation" in checkpoint["params"].keys():
checkpoint["params"]["augmentation"] = {
'normalization': {'mean': [], 'std': []},
'clahe_enhance_clip_limit': None
}
try:
means_ckpt = checkpoint['params']['training']['normalization']['mean']
stds_ckpt = checkpoint['params']['training']['normalization']['std']
scale_ckpt = checkpoint['params']['global']['scale_data']
# clahe_enhance was never officially added to GDL, so will default to None if not present
clahe_enhance = get_key_def('clahe_enhance', checkpoint['params']['training']['augmentation'], default=None)
except KeyError as e: # if KeyError on old keys, then we'll assume we have an up-to-date checkpoint
logging.debug(e)
return checkpoint

checkpoint["params"]["augmentation"]["normalization"]["mean"] = means_ckpt
checkpoint["params"]["augmentation"]["normalization"]["std"] = stds_ckpt
checkpoint["params"]["augmentation"]["scale_data"] = scale_ckpt
checkpoint["params"]["augmentation"]["clahe_enhance_clip_limit"] = 0.1 if clahe_enhance is True else None

checkpoint['params'].update({'model': model_ckpt})

checkpoint['params'].update({
'dataset': {
'modalities': [bands[i] for i in range(num_bands_ckpt)], #bands_ckpt,
#"classes_dict": {f"BUIL": 1}
'bands': [list(bands.keys())[i] for i in range(num_bands_ckpt)],
"classes_dict": {f"class{i + 1}": i + 1 for i in range(num_classes_ckpt)}
# Some manually update may be necessary when using old models
# 'bands': ['nir', 'red', 'green'],
# "classes_dict": {f"FORE": 1},
}
})
checkpoint_params['params'].update({'model': model_ckpt})
return checkpoint_params
return checkpoint


def map_wrapper(x):
Expand Down

0 comments on commit cdea293

Please sign in to comment.