From 36c3d05300812134c9205f873d7a8c06cc90f721 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Tavon?= <34774759+remtav@users.noreply.github.com> Date: Wed, 21 Sep 2022 10:42:31 -0400 Subject: [PATCH 1/5] Add band selection and ordering from multiband source rasters (#341) * geoutils.py: add utility to create VRT from subset of bands of multiband raster aoi.py: add support for band selection from multiband raster test_aoi.py: - test band selection for multiband - test error if selection is done with letters, not integers - test error if too many bands requested test_geoutils.py: add tests for VRT creation utilities * remove legacy BGR_to_RGB parameter * small modifications --- .../dataset/test_ci_segmentation_binary.yaml | 2 +- .../test_ci_segmentation_multiclass.yaml | 2 +- dataset/README.md | 10 +++ dataset/aoi.py | 45 ++++++++---- inference_segmentation.py | 5 -- sampling_segmentation.py | 2 + tests/dataset/test_aoi.py | 70 ++++++++++++++++--- tests/test_verify_segmentation.py | 2 + tests/utils/test_geoutils.py | 24 +++++++ tests/utils/test_utils.py | 9 +-- train_segmentation.py | 7 -- utils/augmentation.py | 26 +------ utils/geoutils.py | 41 ++++++++++- utils/utils.py | 8 --- verify_segmentation.py | 4 +- 15 files changed, 173 insertions(+), 84 deletions(-) create mode 100644 tests/utils/test_geoutils.py diff --git a/config/dataset/test_ci_segmentation_binary.yaml b/config/dataset/test_ci_segmentation_binary.yaml index 33d3e9f2..ce6af4d9 100644 --- a/config/dataset/test_ci_segmentation_binary.yaml +++ b/config/dataset/test_ci_segmentation_binary.yaml @@ -11,7 +11,7 @@ dataset: download_data: False # imagery - bands: [R, G, B] + bands: [1,2,3] # ground truth attribute_field: properties/class diff --git a/config/dataset/test_ci_segmentation_multiclass.yaml b/config/dataset/test_ci_segmentation_multiclass.yaml index 324a079f..6105e59e 100644 --- a/config/dataset/test_ci_segmentation_multiclass.yaml +++ b/config/dataset/test_ci_segmentation_multiclass.yaml @@ -11,7 +11,7 @@ dataset: download_data: False # imagery - bands: [R, G, B] + bands: [1,2,3] # ground truth attribute_field: properties/Quatreclasses diff --git a/dataset/README.md b/dataset/README.md index a1b4d2f5..1ad0ad8f 100644 --- a/dataset/README.md +++ b/dataset/README.md @@ -27,6 +27,16 @@ To support both single-band and multi-band imagery, the path in the first column ### 1. Path to a multi-band image file: `my_dir/my_multiband_geofile.tif` +A particular order or subset of bands in multi-band file must be used by setting a list of specific indices: + +#### Example: + +`bands: [3, 2, 1]` + +Here, if the original multi-band raster had BGR bands, geo-deep-learning will reorder these bands to RGB order. + +The `bands` parameter is set in the [dataset config](../config/dataset/test_ci_segmentation_multiclass.yaml). + ### 2. Path to single-band image files, using only a common string A path to a list of single-band rasters can be inserted in the csv, but only a the string common to all single-band files should be considered. The "band specific" string in the file name must be in a [hydra-like interpolation format](https://hydra.cc/docs/1.0/advanced/override_grammar/basic/#primitives), with `${...}` notation. The interpolation string completed during execution by a dataset parameter with a list of desired band identifiers to help resolve the single-band filenames. diff --git a/dataset/aoi.py b/dataset/aoi.py index 315cfaa0..0b7946ba 100644 --- a/dataset/aoi.py +++ b/dataset/aoi.py @@ -18,7 +18,7 @@ from torchvision.datasets.utils import download_url from tqdm import tqdm -from utils.geoutils import stack_singlebands_vrt, is_stac_item, create_new_raster_from_base +from utils.geoutils import stack_singlebands_vrt, is_stac_item, create_new_raster_from_base, subset_multiband_vrt from utils.logger import get_logger from utils.utils import read_csv from utils.verifications import assert_crs_match, validate_raster, \ @@ -51,8 +51,8 @@ def __init__( self.item = item self._assets_by_common_name = None - if bands_requested is not None and len(bands_requested) == 0: - logging.warning(f"At least one band should be chosen if assets need to be reached") + if not bands_requested: + raise ValueError(f"At least one band should be chosen if assets need to be reached") # Create band inventory (all available bands) self.bands_all = [band for band in self.asset_by_common_name.keys()] @@ -183,7 +183,10 @@ def __init__(self, raster: Union[Path, str], self.raster_stac_item = None # If parsed result has more than a single file, then we're dealing with single-band files - self.raster_src_is_multiband = True if len(raster_parsed) == 1 else False + if len(raster_parsed) == 1 and rasterio.open(raster_parsed[0]).count > 1: + self.raster_src_is_multiband = True + else: + self.raster_src_is_multiband = False # Download assets if desired self.download_data = download_data @@ -203,8 +206,8 @@ def __init__(self, raster: Union[Path, str], self.raster_parsed = raster_parsed # if single band assets, build multiband VRT - self.raster_to_multiband(virtual=True) - self.raster_read() + self.src_raster_to_dest_multiband(virtual=True) + self.raster_open() self.raster_meta = self.raster.meta self.raster_meta['name'] = self.raster.name if self.raster_src_is_multiband: @@ -297,8 +300,8 @@ def __init__(self, raster: Union[Path, str], ) if len(self.label_gdf_filtered) == 0: logging.warning(f"\nNo features found for ground truth \"{self.label}\"," - f"\nfiltered by attribute field \"{self.attr_field_filter}\"" - f"\nwith values \"{self.attr_values_filter}\"") + f"\nfiltered by attribute field \"{self.attr_field_filter}\"" + f"\nwith values \"{self.attr_values_filter}\"") else: self.label_gdf_filtered = None @@ -347,7 +350,6 @@ def from_dict(cls, ) return new_aoi - # TODO: is this necessary if to_dict() is good enough? def __str__(self): return ( f"\nAOI ID: {self.aoi_id}" @@ -359,16 +361,29 @@ def __str__(self): f"\n\tAttribute values filter: {self.attr_values_filter}" ) - def raster_to_multiband(self, virtual=True): + def src_raster_to_dest_multiband(self, virtual=True): + """ + Outputs a multiband raster from multiple sources of input raster + E.g.: multiple singleband files, single multiband file with undesired bands, etc. + """ if not self.raster_src_is_multiband: if virtual: self.raster_multiband = stack_singlebands_vrt(self.raster_parsed) else: self.raster_multiband = self.write_multiband_from_singleband_rasters_as_vrt() + elif self.raster_src_is_multiband and self.raster_bands_request: + if not all([isinstance(band, int) for band in self.raster_bands_request]): + raise ValueError(f"Use only a list of integers to select bands from a multiband raster.\n" + f"Got {self.raster_bands_request}") + if len(self.raster_bands_request) > rasterio.open(self.raster_raw_input).count: + raise ValueError(f"Trying to subset more bands than actual number in source raster.\n" + f"Requested: {self.raster_bands_request}\n" + f"Available: {rasterio.open(self.raster_raw_input).count}") + self.raster_multiband = subset_multiband_vrt(self.raster_parsed[0], band_request=self.raster_bands_request) else: self.raster_multiband = self.raster_parsed[0] - def raster_read(self): + def raster_open(self): self.raster = _check_rasterio_im_load(self.raster_multiband) def to_dict(self, extended=True): @@ -509,8 +524,10 @@ def parse_input_raster( raster = [value['meta'].href for value in item.bands_requested.values()] return raster elif "${dataset.bands}" in csv_raster_str: - if not isinstance(raster_bands_requested, (List, ListConfig, tuple)) or len(raster_bands_requested) == 0: - raise TypeError(f"\nRequested bands should a list of bands. " + if not raster_bands_requested \ + or not isinstance(raster_bands_requested, (List, ListConfig, tuple)) \ + or len(raster_bands_requested) == 0: + raise TypeError(f"\nRequested bands should be a list of bands. " f"\nGot {raster_bands_requested} of type {type(raster_bands_requested)}") raster = [csv_raster_str.replace("${dataset.bands}", band) for band in raster_bands_requested] return raster @@ -593,7 +610,7 @@ def aois_from_csv( @param csv_path: path to csv file containing list of input data. See README for details on expected structure of csv. @param bands_requested: - List of bands to select from inputted imagery. Applies only to single-band input imagery. + List of bands to select from inputted imagery @param attr_values_filter: Attribute filed to filter features from @param attr_field_filter: diff --git a/inference_segmentation.py b/inference_segmentation.py index 52eeb33a..9dbfcec4 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -138,7 +138,6 @@ def segmentation(param, chunk_size: int, device, scale: List, - BGR_to_RGB: bool, tp_mem, debug=False, ): @@ -152,7 +151,6 @@ def segmentation(param, chunk_size: image tile size device: cuda/cpu device scale: scale range - BGR_to_RGB: True/False tp_mem: memory temp file for saving numpy array to disk debug: True/False @@ -192,7 +190,6 @@ def segmentation(param, sample['metadata'] = image_metadata totensor_transform = augmentation.compose_transforms(param, dataset="tst", - input_space=BGR_to_RGB, scale=scale, aug_type='totensor', print_log=print_log) @@ -341,7 +338,6 @@ def main(params: Union[DictConfig, dict]) -> None: # Default input directory based on default output directory raw_data_csv = get_key_def('raw_data_csv', params['inference'], default=working_folder, expected_type=str, to_path=True, validate_path_exists=True) - BGR_to_RGB = get_key_def('BGR_to_RGB', params['dataset'], expected_type=bool) # LOGGING PARAMETERS exper_name = get_key_def('project_name', params['general'], default='gdl-training') @@ -403,7 +399,6 @@ def main(params: Union[DictConfig, dict]) -> None: chunk_size=chunk_size, device=device, scale=scale, - BGR_to_RGB=BGR_to_RGB, tp_mem=temp_file, debug=debug) diff --git a/sampling_segmentation.py b/sampling_segmentation.py index 0ae546e1..5d2be78d 100644 --- a/sampling_segmentation.py +++ b/sampling_segmentation.py @@ -331,6 +331,8 @@ def main(cfg: DictConfig) -> None: # PARAMETERS num_classes = len(cfg.dataset.classes_dict.keys()) bands_requested = get_key_def('bands', cfg['dataset'], default=None, expected_type=Sequence) + if not bands_requested: + raise ValueError(f"") num_bands = len(bands_requested) debug = cfg.debug diff --git a/tests/dataset/test_aoi.py b/tests/dataset/test_aoi.py index 0ca9bac2..214617dd 100644 --- a/tests/dataset/test_aoi.py +++ b/tests/dataset/test_aoi.py @@ -3,6 +3,7 @@ from pathlib import Path import geopandas as gpd +import numpy as np import pytest import rasterio from rasterio import RasterioIOError @@ -18,10 +19,46 @@ def test_multiband_input(self): """Tests reading a multiband raster as input""" extract_archive(src="tests/data/spacenet.zip") data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") - for row in data: - aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) + row = data[0] + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) + src_count = rasterio.open(aoi.raster_raw_input).count + assert src_count == aoi.raster.count + aoi.close_raster() + + def test_multiband_input_band_selection(self): + """Tests reading a multiband raster as input with band selection""" + extract_archive(src="tests/data/spacenet.zip") + data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + row = data[0] + bands_request = [2, 1] + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=bands_request) + src_raster_subset = rasterio.open(aoi.raster_raw_input) + src_np_subset = src_raster_subset.read(bands_request) + dest_raster_subset = rasterio.open(aoi.raster_multiband) + assert src_np_subset.shape[0] == dest_raster_subset.count + dest_np_subset = dest_raster_subset.read() + assert np.all(src_np_subset == dest_np_subset) + aoi.close_raster() + + def test_multiband_input_band_selection_from_letters(self): + """Tests error when selecting bands from a multiband raster using letters, not integers""" + extract_archive(src="tests/data/spacenet.zip") + data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + row = data[0] + bands_request = ["R", "G"] + with pytest.raises(ValueError): + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=bands_request) aoi.close_raster() + def test_multiband_input_band_selection_too_many(self): + """Tests error when selecting too many bands from a multiband raster""" + extract_archive(src="tests/data/spacenet.zip") + data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + row = data[0] + bands_request = [1, 2, 3, 4, 5] + with pytest.raises(ValueError): + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=bands_request) + def test_singleband_input(self): """Tests reading a singleband raster as input with ${dataset.bands} pattern""" extract_archive(src="tests/data/spacenet.zip") @@ -49,7 +86,7 @@ def test_stac_url_input(self): raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R'], download_data=True, root_dir="data" ) - assert aoi.download_data == True + assert aoi.download_data is True assert Path("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif").is_file() aoi.close_raster() os.remove("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif") @@ -154,6 +191,18 @@ def test_stac_input_missing_band(self): raster_bands_request=['ru', 'gris', 'but']) aoi.close_raster() + def test_stac_input_empty_band_request(self): + """Tests error when band selection is required (stac item) but missing""" + extract_archive(src="tests/data/spacenet.zip") + extract_archive(src="tests/data/massachusetts_buildings_kaggle.zip") + raster_raw = ( + ("tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json", ""), + ("tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped_${dataset.bands}.tif", ""), + ) + for raster_raw, bands_requested in raster_raw: + with pytest.raises((ValueError, TypeError)): + AOI.parse_input_raster(csv_raster_str=raster_raw, raster_bands_requested=bands_requested) + def test_no_intersection(self) -> None: """Tests error testing no intersection between raster and label""" extract_archive(src="tests/data/spacenet.zip") @@ -164,16 +213,17 @@ def test_no_intersection(self) -> None: assert aoi.bounds_iou == 0 aoi.close_raster() + def test_write_multiband_from_single_band(self) -> None: """Tests the 'write_multiband' method""" extract_archive(src="tests/data/spacenet.zip") data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") - for row in data: - aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'], - write_multiband=True, root_dir="data") - assert Path("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-RGB.tif").is_file() - aoi.close_raster() - os.remove("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-RGB.tif") + row = data[0] + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'], + write_multiband=True, root_dir="data") + assert Path("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-RGB.tif").is_file() + aoi.close_raster() + os.remove("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-RGB.tif") def test_write_multiband_from_single_band_url(self) -> None: """Tests the 'write_multiband' method with singleband raster as URL""" @@ -274,6 +324,6 @@ def map_wrapper(x): def aoi_read_raster(aoi: AOI): """Function to package in multiprocessing""" - aoi.raster_read() + aoi.raster_open() return aoi.raster.meta diff --git a/tests/test_verify_segmentation.py b/tests/test_verify_segmentation.py index 16274c6a..21664c01 100644 --- a/tests/test_verify_segmentation.py +++ b/tests/test_verify_segmentation.py @@ -10,6 +10,7 @@ class TestVerify(object): def test_verify_per_aoi(self): + """Test stats outputs from an AOI""" extract_archive(src="tests/data/new_brunswick_aerial.zip") data = read_csv("tests/sampling/sampling_segmentation_multiclass_ci.csv") aoi = AOI(raster=data[0]['tif'], label=data[0]['gpkg'], split=data[0]['split']) @@ -25,6 +26,7 @@ def test_verify_per_aoi(self): assert aoi_dict['band_0_mean'] == 159.36075617930456 def test_verify_segmentation_parallel(self): + """Integration test to check verify mode without specific assert""" with initialize(config_path="../config", job_name="test_ci"): cfg = compose(config_name="gdl_config_template", overrides=[f"mode=verify", diff --git a/tests/utils/test_geoutils.py b/tests/utils/test_geoutils.py new file mode 100644 index 00000000..49f47f19 --- /dev/null +++ b/tests/utils/test_geoutils.py @@ -0,0 +1,24 @@ +import numpy as np +import rasterio +from torchgeo.datasets.utils import extract_archive + +from dataset.aoi import AOI +from utils.utils import read_csv + + +class TestGeoutils(object): + def test_multiband_vrt_from_single_band(self) -> None: + """Tests the 'stack_singlebands_vrt' utility""" + extract_archive(src="tests/data/spacenet.zip") + data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") + row = data[0] + bands_request = ['R', 'G', 'B'] + aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], + raster_bands_request=bands_request, root_dir="data") + assert aoi.raster.count == len(bands_request) + src_red = rasterio.open(aoi.raster_raw_input.replace("${dataset.bands}", "R")) + src_red_np = src_red.read() + dest_red_np = aoi.raster.read(1) + # make sure first band in multiband VRT is identical to source R band + assert np.all(src_red_np == dest_red_np) + aoi.close_raster() diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 29dda9f6..87f80949 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -1,16 +1,9 @@ -import multiprocessing -from pathlib import Path - -import geopandas as gpd import pytest -import rasterio -from shapely.geometry import box from torchgeo.datasets.utils import extract_archive -from dataset.aoi import AOI from utils.utils import read_csv -class Test_utils(object): +class TestUtils(object): def test_wrong_seperation(self) -> None: extract_archive(src="tests/data/spacenet.zip") with pytest.raises(TypeError): diff --git a/train_segmentation.py b/train_segmentation.py index 101d2b72..60d75288 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -47,7 +47,6 @@ def create_dataloader(samples_folder: Path, dontcare_val: int, crop_size: int, num_bands: int, - BGR_to_RGB: bool, scale: Sequence, cfg: DictConfig, dontcare2backgr: bool = False, @@ -61,7 +60,6 @@ def create_dataloader(samples_folder: Path, :param sample_size: (int) size of hdf5 samples (used to evaluate eval batch-size) :param dontcare_val: (int) value in label to be ignored during loss calculation :param num_bands: (int) number of bands in imagery - :param BGR_to_RGB: (bool) if True, BGR channels will be flipped to RGB :param scale: (List) imagery data will be scaled to this min and max value (ex.: 0 to 1) :param cfg: (dict) Parameters found in the yaml config file. :param dontcare2backgr: (bool) if True, all dontcare values in label will be replaced with 0 (background value) @@ -92,7 +90,6 @@ def create_dataloader(samples_folder: Path, crop_size=crop_size), totensor_transform=aug.compose_transforms(params=cfg, dataset=subset, - input_space=BGR_to_RGB, scale=scale, dontcare2backgr=dontcare2backgr, dontcare=dontcare_val, @@ -472,9 +469,6 @@ def train(cfg: DictConfig) -> None: batch_size = get_key_def('batch_size', cfg['training'], expected_type=int) eval_batch_size = get_key_def('eval_batch_size', cfg['training'], expected_type=int, default=batch_size) num_epochs = get_key_def('max_epochs', cfg['training'], expected_type=int) - # TODO need to keep in parameters? see victor stuff - # BGR_to_RGB = get_key_def('BGR_to_RGB', params['global'], expected_type=bool) - BGR_to_RGB = False # OPTIONAL PARAMETERS debug = get_key_def('debug', cfg) @@ -609,7 +603,6 @@ def train(cfg: DictConfig) -> None: dontcare_val=dontcare_val, crop_size=crop_size, num_bands=num_bands, - BGR_to_RGB=BGR_to_RGB, scale=scale, cfg=cfg, dontcare2backgr=dontcare2backgr, diff --git a/utils/augmentation.py b/utils/augmentation.py index 88e97b0c..98563621 100644 --- a/utils/augmentation.py +++ b/utils/augmentation.py @@ -13,14 +13,13 @@ import numpy as np from skimage import transform, exposure from torchvision import transforms -from utils.utils import get_key_def, pad, minmax_scale, BGR_to_RGB +from utils.utils import get_key_def, pad, minmax_scale logging.getLogger(__name__) def compose_transforms(params, dataset, - input_space: bool = False, scale: Sequence = None, aug_type: str = '', dontcare=None, @@ -29,7 +28,6 @@ def compose_transforms(params, print_log=True): """ Function to compose the transformations to be applied on every batches. - :param input_space: (bool) if True, flip BGR channels to RGB :param params: (dict) Parameters found in the yaml config file :param dataset: (str) One of 'trn', 'val', 'tst' :param aug_type: (str) One of 'geometric', 'radiometric' @@ -84,15 +82,6 @@ def compose_transforms(params, trim_at_eval = round((random_radiom_trim_range[-1] - random_radiom_trim_range[0]) / 2, 1) lst_trans.append(RadiometricTrim(random_range=[trim_at_eval, trim_at_eval])) - if input_space: - lst_trans.append(BgrToRgb(input_space)) - else: - if print_log: - logging.info( - f"\nThe '{dataset}' images will be fed to model as is. " - f'First 3 bands of imagery should be RGB, not BGR.' - ) - if scale: lst_trans.append(Scale(scale)) # TODO: assert coherence with below normalization else: @@ -360,19 +349,6 @@ def __call__(self, sample): return sample -class BgrToRgb(object): - """Normalize Image with Mean and STD and similar to Pytorch(transform.Normalize) function """ - - def __init__(self, bgr_to_rgb): - self.bgr_to_rgb = bgr_to_rgb - - def __call__(self, sample): - sat_img = BGR_to_RGB(sample['sat_img']) if self.bgr_to_rgb else sample['sat_img'] - sample['sat_img'] = sat_img - - return sample - - class ToTensorTarget(object): """Convert ndarrays in sample to Tensors.""" def __init__(self, dontcare2backgr: bool = False, dontcare_val: int = None): diff --git a/utils/geoutils.py b/utils/geoutils.py index e4a03974..fd963500 100644 --- a/utils/geoutils.py +++ b/utils/geoutils.py @@ -1,5 +1,7 @@ import collections import logging +from pathlib import Path +from typing import List, Union, Sequence import numpy as np @@ -159,14 +161,14 @@ def is_stac_item(path: str) -> bool: return False -def stack_singlebands_vrt(srcs, band=1): +def stack_singlebands_vrt(srcs: List, band: int = 1): """ Stacks multiple single-band raster into a single multiband virtual raster Source: https://gis.stackexchange.com/questions/392695/is-it-possible-to-build-a-vrt-file-from-multiple-files-with-rasterio @param srcs: - List of paths/urls to single-band raster + List of paths/urls to single-band rasters @param band: - TODO + Index of band from source raster to stack into multiband VRT (index starts at 1 per GDAL convention) @return: RasterDataset object containing VRT """ @@ -185,3 +187,36 @@ def stack_singlebands_vrt(srcs, band=1): vrt_dataset.append(vrt_band) return ET.tostring(vrt_dataset).decode('UTF-8') + + +def subset_multiband_vrt(src: Union[str, Path], band_request: Sequence = []): + """ + Creates a multiband virtual raster containing a subset of all available bands in a source multiband raster + @param src: + Path/url to a multiband raster + @param band_request: + Indices of bands from source raster to subset from source multiband (index starts at 1 per GDAL convention). + Order matters, i.e. if source raster is BGR, "[3,2,1]" will create a VRT with bands as RGB + @return: + RasterDataset object containing VRT + """ + vrt_bands = [] + if not isinstance(src, (str, Path)) and not Path(src).is_file(): + raise ValueError(f"Invalid source multiband raster.\n" + f"Got {src}") + with rasterio.open(src) as ras, MemoryFile() as mem: + riocopy(ras, mem.name, driver='VRT') + vrt_xml = mem.read().decode('utf-8') + vrt_dataset = ET.fromstring(vrt_xml) + vrt_dataset_dict = {int(band.get('band')): band for band in vrt_dataset.iter("VRTRasterBand")} + for dest_band_idx, src_band_idx in enumerate(band_request, start=1): + vrt_band = vrt_dataset_dict[src_band_idx] + vrt_band.set('band', str(dest_band_idx)) + vrt_bands.append(vrt_band) + vrt_dataset.remove(vrt_band) + for leftover_band in vrt_dataset.iter("VRTRasterBand"): + vrt_dataset.remove(leftover_band) + for vrt_band in vrt_bands: + vrt_dataset.append(vrt_band) + + return ET.tostring(vrt_dataset).decode('UTF-8') diff --git a/utils/utils.py b/utils/utils.py index f7fd04e6..a3e3c7ad 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -281,14 +281,6 @@ def unnormalize(input_img, mean, std): return (input_img * std) + mean -def BGR_to_RGB(array): - assert array.shape[2] >= 3, f"Not enough channels in array of shape {array.shape}" - BGR_channels = array[..., :3] - RGB_channels = np.ascontiguousarray(BGR_channels[..., ::-1]) - array[:, :, :3] = RGB_channels - return array - - def checkpoint_url_download(url: str): mime_type = ('application/tar', 'application/x-tar', 'applicaton/x-gtar', 'multipart/x-tar', 'application/x-compress', 'application/x-compressed') diff --git a/verify_segmentation.py b/verify_segmentation.py index a7a95c04..b5a5b09f 100644 --- a/verify_segmentation.py +++ b/verify_segmentation.py @@ -38,7 +38,7 @@ def verify_per_aoi( Returns info on AOI or error raised, if any. """ try: - aoi.raster_read() # in case of multiprocessing + aoi.raster_open() # in case of multiprocessing # get aoi info logging.info(f"\nGetting data info for {aoi.aoi_id}...") @@ -77,7 +77,7 @@ def verify_per_aoi( plt.close() return aoi_dict, None except Exception as e: - logging.error(e) + raise e #logging.error(e) return None, e From 7ce428eb95405fef7f06d119fb1ba6fa0961a24a Mon Sep 17 00:00:00 2001 From: mpelchat04 <38693210+mpelchat04@users.noreply.github.com> Date: Wed, 21 Sep 2022 14:49:19 -0400 Subject: [PATCH 2/5] Tests for metrics calculation + Torchmetrics for IoU (#345) * add tests for metrics calculation * replace iou calculation with torchmetrics iou * fixes #190 --- tests/utils/test_metrics.py | 170 ++++++++++++++++++++++++++++++++++++ train_segmentation.py | 31 ++++--- utils/metrics.py | 102 ++++++++++++++-------- 3 files changed, 254 insertions(+), 49 deletions(-) create mode 100644 tests/utils/test_metrics.py diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py new file mode 100644 index 00000000..3d97f43a --- /dev/null +++ b/tests/utils/test_metrics.py @@ -0,0 +1,170 @@ +from utils.metrics import create_metrics_dict, report_classification, iou +import torch +import pytest + +# Test arrays: [bs=2, h=2, w=2] +def init_tensors(): + pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2]) + pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]) + + lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1]) + lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1]) + # array with dont care + lbl_multi_dc = torch.tensor([-1, -1, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1]) + lbl_binary_dc = torch.tensor([-1, -1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1]) + return {'pred_multi': pred_multi, + 'pred_binary': pred_binary, + 'lbl_multi': lbl_multi, + 'lbl_binary': lbl_binary, + 'lbl_multi_dc': lbl_multi_dc, + 'lbl_binary_dc': lbl_binary_dc} + + +class TestMetrics(object): + def test_create_metrics_dict(self): + """Evaluate the metrics dictionnary creation. + Binary and multiclass""" + # binary tasks have 1 class at class definition. + num_classes = 1 + metrics_dict = create_metrics_dict(num_classes) + assert 'iou_1' in metrics_dict.keys() + assert 'iou_2' not in metrics_dict.keys() + + num_classes = 3 + metrics_dict = create_metrics_dict(num_classes) + assert 'iou_1' in metrics_dict.keys() + assert 'iou_2' in metrics_dict.keys() + assert 'iou_3' not in metrics_dict.keys() + del metrics_dict + + def test_report_classification_multi(self): + """Evaluate report classification. + Multiclass, without ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(3) + metrics_dict = report_classification(t['pred_multi'], + t['lbl_multi'], + batch_size=2, + metrics_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['precision'].val) == "0.327083" + assert "{:.6f}".format(metrics_dict['recall'].val) == "0.312500" + assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.314935" + + def test_report_classification_multi_ignore_idx(self): + """Evaluate report classification. + Multiclass, with ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(3) + metrics_dict = report_classification(t['pred_multi'], + t['lbl_multi_dc'], + batch_size=2, + metrics_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['precision'].val) == "0.297619" + assert "{:.6f}".format(metrics_dict['recall'].val) == "0.285714" + assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.283163" + + def test_report_classification_binary(self): + """Evaluate report classification. + Binary, without ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(1) + metrics_dict = report_classification(t['pred_binary'], + t['lbl_binary'], + batch_size=2, + metrics_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['precision'].val) == "0.547727" + assert "{:.6f}".format(metrics_dict['recall'].val) == "0.562500" + assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.553030" + + def test_report_classification_binary_ignore_idx(self): + """Evaluate report classification. + Binary, without ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(1) + metrics_dict = report_classification(t['pred_binary'], + t['lbl_binary_dc'], + batch_size=2, + metrics_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['precision'].val) == "0.528139" + assert "{:.6f}".format(metrics_dict['recall'].val) == "0.571429" + assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.539286" + + def test_iou_multi(self): + """Evaluate iou calculation. + Multiclass, without ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(3) + metrics_dict = iou(t['pred_multi'], + t['lbl_multi'], + batch_size=2, + num_classes=3, + metric_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['iou'].val) == "0.185185" + + def test_iou_multi_ignore_idx(self): + """Evaluate iou calculation. + Multiclass, with ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(3) + # wih ignore_index == -1 + metrics_dict = iou(t['pred_multi'], + t['lbl_multi_dc'], + batch_size=2, + num_classes=3, + metric_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['iou'].val) == "0.233333" + + # with ignore_index == 0 + t = init_tensors() + metrics_dict = create_metrics_dict(3) + metrics_dict = iou(t['pred_multi'], + t['lbl_multi'], + batch_size=2, + num_classes=3, + metric_dict=metrics_dict, + ignore_index=0) + assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333" + + def test_iou_binary(self): + """Evaluate iou calculation. + Binary, without ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(1) + metrics_dict = iou(t['pred_binary'], + t['lbl_binary'], + batch_size=2, + num_classes=1, + metric_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['iou'].val) == "0.361111" + + def test_iou_binary_ignore_idx(self): + """Evaluate iou calculation. + Binary, with ignore_index in array.""" + t = init_tensors() + metrics_dict = create_metrics_dict(1) + # with ignore_index == -1 + metrics_dict = iou(t['pred_binary'], + t['lbl_binary_dc'], + batch_size=2, + num_classes=1, + metric_dict=metrics_dict, + ignore_index=-1) + assert "{:.6f}".format(metrics_dict['iou'].val) == "0.435897" + + # with ignore_index == 0 + t = init_tensors() + metrics_dict = create_metrics_dict(3) + with pytest.raises(ValueError): + metrics_dict = iou(t['pred_binary'], + t['lbl_binary_dc'], + batch_size=2, + num_classes=3, + metric_dict=metrics_dict, + ignore_index=0) \ No newline at end of file diff --git a/train_segmentation.py b/train_segmentation.py index 60d75288..ad59a571 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -336,7 +336,8 @@ def evaluation(eval_loader, batch_metrics=None, dataset='val', device=None, - debug=False): + debug=False, + dontcare=-1): """ Evaluate the model and return the updated metrics :param eval_loader: data loader @@ -401,14 +402,14 @@ def evaluation(eval_loader, f"{len(eval_loader)}. Metrics in validation loop won't be computed") if (batch_index + 1) % batch_metrics == 0: # +1 to skip val loop at very beginning a, segmentation = torch.max(outputs_flatten, dim=1) - eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics) + eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics, dontcare) eval_metrics = report_classification(segmentation, labels_flatten, batch_size, eval_metrics, - ignore_index=eval_loader.dataset.dontcare) - elif (dataset == 'tst') and (batch_metrics is not None): + ignore_index=dontcare) + elif (dataset == 'tst'): a, segmentation = torch.max(outputs_flatten, dim=1) - eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics) + eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics, dontcare) eval_metrics = report_classification(segmentation, labels_flatten, batch_size, eval_metrics, - ignore_index=eval_loader.dataset.dontcare) + ignore_index=dontcare) logging.debug(OrderedDict(dataset=dataset, loss=f'{eval_metrics["loss"].avg:.4f}')) @@ -421,11 +422,11 @@ def evaluation(eval_loader, if eval_metrics['loss'].avg: logging.info(f"\n{dataset} Loss: {eval_metrics['loss'].avg:.4f}") - if batch_metrics is not None: - logging.info(f"\n{dataset} precision: {eval_metrics['precision'].avg}") - logging.info(f"\n{dataset} recall: {eval_metrics['recall'].avg}") - logging.info(f"\n{dataset} fscore: {eval_metrics['fscore'].avg}") - logging.info(f"\n{dataset} iou: {eval_metrics['iou'].avg}") + if batch_metrics is not None or dataset == 'tst': + logging.info(f"\n{dataset} precision: {eval_metrics['precision'].avg:.4f}") + logging.info(f"\n{dataset} recall: {eval_metrics['recall'].avg:.4f}") + logging.info(f"\n{dataset} fscore: {eval_metrics['fscore'].avg:.4f}") + logging.info(f"\n{dataset} iou: {eval_metrics['iou'].avg:.4f}") return eval_metrics @@ -518,7 +519,7 @@ def train(cfg: DictConfig) -> None: # info on the hdf5 name samples_size = get_key_def("input_dim", cfg['dataset'], expected_type=int, default=256) overlap = get_key_def("overlap", cfg['dataset'], expected_type=int, default=0) - min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], expected_type=int, default=0) + min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], default=0) samples_folder_name = ( f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands_{experiment_name}' ) @@ -681,7 +682,8 @@ def train(cfg: DictConfig) -> None: device=device, scale=scale, vis_params=vis_params, - debug=debug) + debug=debug, + dontcare=dontcare_val) val_loss = val_report['loss'].avg if 'val_log' in locals(): # only save the value if a tracker is setup if batch_metrics is not None: @@ -738,7 +740,8 @@ def train(cfg: DictConfig) -> None: dataset='tst', scale=scale, vis_params=vis_params, - device=device) + device=device, + dontcare=dontcare_val) if 'tst_log' in locals(): # only save the value if a tracker is setup tst_log.add_values(tst_report, num_epochs) diff --git a/utils/metrics.py b/utils/metrics.py index b9728933..dc295a73 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -1,17 +1,25 @@ +from ast import Raise +from xml.dom import ValidationErr import numpy as np from sklearn.metrics import classification_report +from math import sqrt +from torch import IntTensor +from torchmetrics import JaccardIndex min_val = 1e-6 -def create_metrics_dict(num_classes): - num_classes = num_classes if num_classes == 1 else num_classes + 1 +def create_metrics_dict(num_classes, ignore_index=None): + + num_classes = num_classes + 1 if num_classes == 1 else num_classes + metrics_dict = {'precision': AverageMeter(), 'recall': AverageMeter(), 'fscore': AverageMeter(), 'loss': AverageMeter(), 'iou': AverageMeter()} for i in range(0, num_classes): - metrics_dict['precision_' + str(i)] = AverageMeter() - metrics_dict['recall_' + str(i)] = AverageMeter() - metrics_dict['fscore_' + str(i)] = AverageMeter() - metrics_dict['iou_' + str(i)] = AverageMeter() + if ignore_index != i: + metrics_dict['precision_' + str(i)] = AverageMeter() + metrics_dict['recall_' + str(i)] = AverageMeter() + metrics_dict['fscore_' + str(i)] = AverageMeter() + metrics_dict['iou_' + str(i)] = AverageMeter() # Add overall non-background iou metric metrics_dict['iou_nonbg'] = AverageMeter() @@ -59,7 +67,15 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1 """Computes precision, recall and f-score for each class and average of all classes. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html """ - class_report = classification_report(label.cpu(), pred.cpu(), output_dict=True, zero_division=1) + pred = pred.cpu() + label = label.cpu() + pred[label == ignore_index] = ignore_index + + # Required to remove ignore_index from scikit-learn's classification report + n = max(IntTensor.item(pred.amax()), IntTensor.item(label.amax())) + labels = np.arange(n+1) + + class_report = classification_report(label, pred, labels=labels, output_dict=True, zero_division=1) class_score = {} for key, value in class_report.items(): @@ -77,35 +93,51 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1 return metrics_dict -def iou(pred, label, batch_size, num_classes, metric_dict, only_present=True): +def iou(pred, label, batch_size, num_classes, metric_dict, ignore_index=None): """Calculate the intersection over union class-wise and mean-iou""" - ious = [] - num_classes = num_classes if num_classes == 1 else num_classes + 1 - pred = pred.cpu() - label = label.cpu() - for i in range(num_classes): - c_label = label == i - if only_present and c_label.sum() == 0: - ious.append(np.nan) - continue - c_pred = pred == i - intersection = (c_pred & c_label).float().sum() - union = (c_pred | c_label).float().sum() - iou = (intersection + min_val) / (union + min_val) # minimum value added to avoid Zero division - ious.append(iou) - metric_dict['iou_' + str(i)].update(iou.item(), batch_size) - # Add overall non-background iou metric - c_label = (1 <= label) & (label <= num_classes - 1) - c_pred = (1 <= pred) & (pred <= num_classes - 1) - intersection = (c_pred & c_label).float().sum() - union = (c_pred | c_label).float().sum() - iou = (intersection + min_val) / (union + min_val) # minimum value added to avoid Zero division - metric_dict['iou_nonbg'].update(iou.item(), batch_size) - - mean_IOU = np.nanmean(ious) - if (not only_present) or (not np.isnan(mean_IOU)): - metric_dict['iou'].update(mean_IOU, batch_size) + num_classes = num_classes + 1 if num_classes == 1 else num_classes + # Torchmetrics cannot handle ignore_index that are not in range 0 -> num_classes-1. + # if invalid ignore_index is provided, invalid values (e.g. -1) will be set to 0 + # and no ignore_index will be used. + if ignore_index and ignore_index not in range(0, num_classes-1): + pred[label == ignore_index] = 0 + label[label == ignore_index] = 0 + ignore_index = None + + cls_lst = [j for j in range(0, num_classes)] + if ignore_index is not None: + cls_lst.remove(ignore_index) + + jaccard = JaccardIndex(num_classes=num_classes, + average='none', + ignore_index=ignore_index, + absent_score=1) + cls_ious = jaccard(pred, label) + + + if len(cls_ious) > 1: + for i in range(len(cls_lst)): + metric_dict['iou_' + str(cls_lst[i])].update(cls_ious[i], batch_size) + + elif len(cls_ious) == 1: + if f"iou_{cls_lst[0]}" in metric_dict.keys(): + metric_dict['iou_' + str(cls_lst[0])].update(cls_ious, batch_size) + + jaccard_nobg = JaccardIndex(num_classes=num_classes, + average='macro', + ignore_index=0, + absent_score=1) + iou_nobg = jaccard_nobg(pred, label) + metric_dict['iou_nonbg'].update(iou_nobg.item(), batch_size) + + jaccard = JaccardIndex(num_classes=num_classes, + average='macro', + ignore_index=ignore_index, + absent_score=1) + mean_iou = jaccard(pred, label) + + metric_dict['iou'].update(mean_iou, batch_size) return metric_dict #### Benchmark Metrics #### @@ -122,7 +154,7 @@ class ComputePixelMetrics(): def __init__(self, label, pred, num_classes): self.label = label self.pred = pred - self.num_classes = num_classes if num_classes == 1 else num_classes + 1 + self.num_classes = num_classes + 1 if num_classes == 1 else num_classes def update(self, metric_func): metric = {} From e5ed1049f07af5e377ae230f8276109594622b3c Mon Sep 17 00:00:00 2001 From: CharlesAuthier Date: Tue, 4 Oct 2022 15:33:40 -0400 Subject: [PATCH 3/5] update hydra-core for 1.2 --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 51e771c0..97f23301 100644 --- a/environment.yml +++ b/environment.yml @@ -7,7 +7,7 @@ dependencies: - docker-py>=4.4.4 - geopandas>=0.10.2 - h5py>=3.7 - - hydra-core>=1.1.0 + - hydra-core>=1.2.0 - pip - pystac>=0.3.0 - pytest>=7.1 From 61341007e9ef6f5e9a5b815342b83d07422661c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Tavon?= <34774759+remtav@users.noreply.github.com> Date: Tue, 4 Oct 2022 16:08:49 -0400 Subject: [PATCH 4/5] tiling: suggested refactoring for "sampling" (#348) --- README.md | 2 +- config/README.md | 11 +++-- config/dataset/README.md | 4 +- .../dataset/test_ci_segmentation_binary.yaml | 7 +-- .../test_ci_segmentation_binary_stac.yaml | 9 +--- .../test_ci_segmentation_multiclass.yaml | 9 +--- config/gdl_config_template.yaml | 7 +-- config/tiling/default_tiling.yaml | 8 ++++ tests/dataset/test_aoi.py | 44 +++++++++---------- tests/test_verify_segmentation.py | 2 +- tests/{sampling => tiling}/header.csv | 0 tests/{sampling => tiling}/point_virgule.csv | 0 ...ling_segmentation_binary-multiband_ci.csv} | 0 ...segmentation_binary-singleband-url_ci.csv} | 0 ...ing_segmentation_binary-singleband_ci.csv} | 0 .../tiling_segmentation_binary-stac_ci.csv} | 0 .../tiling_segmentation_binary_ci.csv} | 0 .../tiling_segmentation_multiclass_ci.csv} | 0 tests/utils/test_geoutils.py | 2 +- tests/utils/test_utils.py | 4 +- ..._segmentation.py => tiling_segmentation.py | 20 ++++----- train_segmentation.py | 12 ++--- utils/augmentation.py | 2 +- utils/utils.py | 8 ++-- 24 files changed, 72 insertions(+), 79 deletions(-) create mode 100644 config/tiling/default_tiling.yaml rename tests/{sampling => tiling}/header.csv (100%) rename tests/{sampling => tiling}/point_virgule.csv (100%) rename tests/{sampling/sampling_segmentation_binary-multiband_ci.csv => tiling/tiling_segmentation_binary-multiband_ci.csv} (100%) rename tests/{sampling/sampling_segmentation_binary-singleband-url_ci.csv => tiling/tiling_segmentation_binary-singleband-url_ci.csv} (100%) rename tests/{sampling/sampling_segmentation_binary-singleband_ci.csv => tiling/tiling_segmentation_binary-singleband_ci.csv} (100%) rename tests/{sampling/sampling_segmentation_binary-stac_ci.csv => tiling/tiling_segmentation_binary-stac_ci.csv} (100%) rename tests/{sampling/sampling_segmentation_binary_ci.csv => tiling/tiling_segmentation_binary_ci.csv} (100%) rename tests/{sampling/sampling_segmentation_multiclass_ci.csv => tiling/tiling_segmentation_multiclass_ci.csv} (100%) rename sampling_segmentation.py => tiling_segmentation.py (97%) diff --git a/README.md b/README.md index 0e6852cd..94a49365 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ cd geo-deep-learning 2. Run the wanted script (for segmentation). ```shell # Creating the hdf5 from the raw data -python GDL.py mode=sampling +python GDL.py mode=tiling # Training the neural network python GDL.py mode=train # Inference on the data diff --git a/config/README.md b/config/README.md index 187fe168..60ec0dbb 100644 --- a/config/README.md +++ b/config/README.md @@ -65,7 +65,7 @@ The **_tracker section_** is set to `None` by default, but will still log the in If you want to set a tracker you can change the value in the config file or add the tracker parameter at execution time via the command line `python GDL.py tracker=mlflow mode=train`. The **_inference section_** contains the information to execute the inference job (more options will follow soon). -This part doesn't need to be filled if you want to launch sampling, train or hyperparameters search mode only. +This part doesn't need to be filled if you want to launch tiling, train or hyperparameters search mode only. The **_task section_** manages the executing task. `Segmentation` is the default task since it's the primary task of GDL. However, the goal will be to add tasks as need be. The `GDL.py` code simply executes the main function from the `task_mode.py` in the main folder of GDL. @@ -83,7 +83,7 @@ general: max_epochs: 2 # for train only min_epochs: 1 # for train only raw_data_dir: data - raw_data_csv: tests/sampling/sampling_segmentation_binary_ci.csv + raw_data_csv: tests/tiling/tiling_segmentation_binary_ci.csv sample_data_dir: data # where the hdf5 will be saved state_dict_path: save_weights_dir: saved_model/${general.project_name} @@ -95,10 +95,10 @@ If `True`, will save the config in the log folder. #### Mode Section ```YAML -mode: {sampling, train, inference, evaluate, hyperparameters_search} +mode: {tiling, train, inference, evaluate, hyperparameters_search} ``` -**GDL** has five modes: sampling, train, evaluate, inference and hyperparameters search. -- *sampling*, generates `hdf5` files from a folder containing folders for each individual image with their ground truth. +**GDL** has five modes: tiling, train, evaluate, inference and hyperparameters search. +- *tiling*, generates .geotiff and .geojson [chips](https://torchgeo.readthedocs.io/en/latest/user/glossary.html#term-chip) from each source aoi (image & ground truth). - *train*, will train the model specified with all the parameters in `training`, `trainer`, `optimizer`, `callbacks` and `scheduler`. The outcome will be `.pth` weights. - *evaluate*, this function needs to be filled with images, their ground truth and a weight for the model. At the end of the evaluation you will obtain statistics on those images. - *inference*, unlike the evaluation, the inference doesn't need a ground truth. The inference will produce a prediction on the content of the images fed to the model. Depending on the task, the outcome file will differ. @@ -148,4 +148,3 @@ new: $ python GDL.py --config-name=/path/to/new/gdl_config.yaml mode=train ``` - diff --git a/config/dataset/README.md b/config/dataset/README.md index 3c31e187..c3f9d339 100644 --- a/config/dataset/README.md +++ b/config/dataset/README.md @@ -5,7 +5,7 @@ ### Input dimensions and overlap These parameters respectively set the width and length of a single sample and stride from one sample to another as -outputted by sampling_segmentation.py. Default to 256 and 0, respectively. +outputted by tiling_segmentation.py. Default to 256 and 0, respectively. ### Train/validation percentage @@ -31,7 +31,7 @@ For more information on the concept of stratified sampling, see [this Medium art ### Modalities -Bands to be selected during the sampling process. Order matters (ie "BGR" is not equal to "RGB"). +Bands to be selected during the tiling process. Order matters (ie "BGR" is not equal to "RGB"). The use of this feature for band selection is a work in progress. It currently serves to indicate how many bands are in source imagery. diff --git a/config/dataset/test_ci_segmentation_binary.yaml b/config/dataset/test_ci_segmentation_binary.yaml index ce6af4d9..26a20cab 100644 --- a/config/dataset/test_ci_segmentation_binary.yaml +++ b/config/dataset/test_ci_segmentation_binary.yaml @@ -2,10 +2,6 @@ dataset: # dataset-wide name: - input_dim: 32 - overlap: - use_stratification: False - train_val_percent: {'trn':0.7, 'val':0.3, 'tst':0} raw_data_csv: ${general.raw_data_csv} raw_data_dir: ${general.raw_data_dir} download_data: False @@ -16,12 +12,11 @@ dataset: # ground truth attribute_field: properties/class attribute_values: [1] - min_annotated_percent: class_name: # will follow in the next version classes_dict: {'BUIL':1} class_weights: ignore_index: -1 # outputs - sample_data_dir: ${general.sample_data_dir} + tiling_data_dir: ${general.tiling_data_dir} diff --git a/config/dataset/test_ci_segmentation_binary_stac.yaml b/config/dataset/test_ci_segmentation_binary_stac.yaml index 51b95d9e..e338599b 100644 --- a/config/dataset/test_ci_segmentation_binary_stac.yaml +++ b/config/dataset/test_ci_segmentation_binary_stac.yaml @@ -2,11 +2,7 @@ dataset: # dataset-wide name: - input_dim: 32 - overlap: - use_stratification: False - train_val_percent: {'trn':0.7, 'val':0.3, 'tst':0} - raw_data_csv: tests/sampling/sampling_segmentation_binary-stac_ci.csv + raw_data_csv: tests/tiling/tiling_segmentation_binary-stac_ci.csv raw_data_dir: ${general.raw_data_dir} download_data: False @@ -16,12 +12,11 @@ dataset: # ground truth attribute_field: attribute_values: - min_annotated_percent: class_name: # will follow in the next version classes_dict: {'BUIL':1} class_weights: ignore_index: -1 # outputs - sample_data_dir: ${general.sample_data_dir} + tiling_data_dir: ${general.tiling_data_dir} diff --git a/config/dataset/test_ci_segmentation_multiclass.yaml b/config/dataset/test_ci_segmentation_multiclass.yaml index 6105e59e..a6d39d23 100644 --- a/config/dataset/test_ci_segmentation_multiclass.yaml +++ b/config/dataset/test_ci_segmentation_multiclass.yaml @@ -2,11 +2,7 @@ dataset: # dataset-wide name: - input_dim: 32 - overlap: - use_stratification: False - train_val_percent: {'trn':0.7, 'val':0.3, 'tst':0} - raw_data_csv: tests/sampling/sampling_segmentation_multiclass_ci.csv + raw_data_csv: tests/tiling/tiling_segmentation_multiclass_ci.csv raw_data_dir: ${general.raw_data_dir} download_data: False @@ -16,12 +12,11 @@ dataset: # ground truth attribute_field: properties/Quatreclasses attribute_values: [1,2,3,4] - min_annotated_percent: class_name: # will follow in the next version classes_dict: {'WAER':1, 'FORE':2, 'ROAI':3, 'BUIL':4} class_weights: ignore_index: 255 # outputs - sample_data_dir: ${general.sample_data_dir} + tiling_data_dir: ${general.tiling_data_dir} diff --git a/config/gdl_config_template.yaml b/config/gdl_config_template.yaml index 3fdb044b..4ff1bbd6 100644 --- a/config/gdl_config_template.yaml +++ b/config/gdl_config_template.yaml @@ -1,6 +1,7 @@ defaults: - model: gdl_unet - verify: default_verify + - tiling: default_tiling - training: default_training - loss: binary/softbce - optimizer: adamw @@ -31,10 +32,10 @@ general: max_epochs: 2 # for train only min_epochs: 1 # for train only raw_data_dir: dataset - raw_data_csv: tests/sampling/sampling_segmentation_binary_ci.csv - sample_data_dir: dataset # where the hdf5 will be saved + raw_data_csv: tests/tiling/tiling_segmentation_binary_ci.csv + tiling_data_dir: dataset # where the hdf5 will be saved save_weights_dir: saved_model/${general.project_name} print_config: True # save the config in the log folder -mode: {verify, sampling, train, inference, evaluate} +mode: {verify, tiling, train, inference, evaluate} debug: True #False # will print the complete yaml config plus run a validation test diff --git a/config/tiling/default_tiling.yaml b/config/tiling/default_tiling.yaml new file mode 100644 index 00000000..462ce5ef --- /dev/null +++ b/config/tiling/default_tiling.yaml @@ -0,0 +1,8 @@ +# @package _global_ +tiling: + tiling_data_dir: ${general.tiling_data_dir} + train_val_percent: {'trn':0.7, 'val':0.3, 'tst':0} + chip_size: 32 + overlap_size: + min_annot_perc: 1 + use_stratification: False \ No newline at end of file diff --git a/tests/dataset/test_aoi.py b/tests/dataset/test_aoi.py index 214617dd..a0e3648e 100644 --- a/tests/dataset/test_aoi.py +++ b/tests/dataset/test_aoi.py @@ -18,7 +18,7 @@ class Test_AOI(object): def test_multiband_input(self): """Tests reading a multiband raster as input""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") row = data[0] aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) src_count = rasterio.open(aoi.raster_raw_input).count @@ -28,7 +28,7 @@ def test_multiband_input(self): def test_multiband_input_band_selection(self): """Tests reading a multiband raster as input with band selection""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") row = data[0] bands_request = [2, 1] aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=bands_request) @@ -43,7 +43,7 @@ def test_multiband_input_band_selection(self): def test_multiband_input_band_selection_from_letters(self): """Tests error when selecting bands from a multiband raster using letters, not integers""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") row = data[0] bands_request = ["R", "G"] with pytest.raises(ValueError): @@ -53,7 +53,7 @@ def test_multiband_input_band_selection_from_letters(self): def test_multiband_input_band_selection_too_many(self): """Tests error when selecting too many bands from a multiband raster""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") row = data[0] bands_request = [1, 2, 3, 4, 5] with pytest.raises(ValueError): @@ -62,7 +62,7 @@ def test_multiband_input_band_selection_too_many(self): def test_singleband_input(self): """Tests reading a singleband raster as input with ${dataset.bands} pattern""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv") for row in data: aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B']) aoi.close_raster() @@ -70,7 +70,7 @@ def test_singleband_input(self): def test_stac_input(self): """Tests singleband raster referenced by stac item as input""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-stac_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv") for row in data: aoi = AOI( raster=row['tif'], label=row['gpkg'], split=row['split'], @@ -80,7 +80,7 @@ def test_stac_input(self): def test_stac_url_input(self): """Tests download of singleband raster as url path referenced by a stac item""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband-url_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv") for row in data: aoi = AOI( raster=row['tif'], label=row['gpkg'], split=row['split'], @@ -94,7 +94,7 @@ def test_stac_url_input(self): def test_missing_label(self): """Tests error when missing label file""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['gpkg'] = "missing_file.gpkg" with pytest.raises(AttributeError): @@ -131,7 +131,7 @@ def test_bounds_iou(self) -> None: def test_corrupt_raster(self) -> None: """Tests error when reading a corrupt file""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['tif'] = "tests/data/massachusetts_buildings_kaggle/corrupt_file.tif" with pytest.raises(BaseException): @@ -141,7 +141,7 @@ def test_corrupt_raster(self) -> None: def test_image_only(self) -> None: """Tests AOI creation with image only, ie no label""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: aoi = AOI(raster=row['tif'], label=None) assert aoi.label is None @@ -150,7 +150,7 @@ def test_image_only(self) -> None: def test_missing_raster(self) -> None: """Tests error when pointing to missing raster""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['tif'] = "missing_raster.tif" with pytest.raises(RasterioIOError): @@ -160,7 +160,7 @@ def test_missing_raster(self) -> None: def test_wrong_split(self) -> None: """Tests error when setting a wrong split, ie not 'trn', 'tst' or 'inference'""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['split'] = "missing_split" with pytest.raises(ValueError): @@ -170,7 +170,7 @@ def test_wrong_split(self) -> None: def test_download_data(self) -> None: """Tests download data""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['tif'] = "http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/" \ "spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-N.tif" @@ -183,7 +183,7 @@ def test_download_data(self) -> None: def test_stac_input_missing_band(self): """Tests error when requestinga non-existing singleband input rasters from stac item""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-stac_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv") for row in data: with pytest.raises(ValueError): aoi = AOI( @@ -206,7 +206,7 @@ def test_stac_input_empty_band_request(self): def test_no_intersection(self) -> None: """Tests error testing no intersection between raster and label""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-multiband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: row['gpkg'] = "tests/data/new_brunswick_aerial/BakerLake_2017_clipped.gpkg" aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) @@ -217,7 +217,7 @@ def test_no_intersection(self) -> None: def test_write_multiband_from_single_band(self) -> None: """Tests the 'write_multiband' method""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv") row = data[0] aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'], write_multiband=True, root_dir="data") @@ -228,7 +228,7 @@ def test_write_multiband_from_single_band(self) -> None: def test_write_multiband_from_single_band_url(self) -> None: """Tests the 'write_multiband' method with singleband raster as URL""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband-url_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv") for row in data: aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['R', 'G', 'B'], write_multiband=True, root_dir="data", download_data=True) @@ -241,7 +241,7 @@ def test_write_multiband_from_single_band_url(self) -> None: def test_download_true_not_url(self) -> None: """Tests AOI creation if download_data set to True, but not necessary (local image)""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv") for row in data: aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], download_data=True, raster_bands_request=['R', 'G', 'B']) @@ -250,7 +250,7 @@ def test_download_true_not_url(self) -> None: def test_raster_stats_from_stac(self) -> None: """Tests the calculation of statistics of raster data as stac item from an AOI instance""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-stac_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv") bands_request = ['red', 'green', 'blue'] expected_stats = { 'red': {'statistics': {'minimum': 0, 'maximum': 255, 'mean': 10.133578590682399, 'median': 9.0, @@ -273,7 +273,7 @@ def test_raster_stats_from_stac(self) -> None: def test_raster_stats_not_stac(self) -> None: """Tests the calculation of statistics of local multiband raster data from an AOI instance""" extract_archive(src="tests/data/new_brunswick_aerial.zip") - data = read_csv("tests/sampling/sampling_segmentation_multiclass_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_multiclass_ci.csv") expected_stats = { 'band_0': {'statistics': {'minimum': 11, 'maximum': 254, 'mean': 159.36075617930456, 'median': 165.0, 'std': 48.9924913616138}}, @@ -295,7 +295,7 @@ def test_raster_stats_not_stac(self) -> None: def test_to_dict(self): """Test the 'to_dict()' method on an AOI instance""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-stac_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-stac_ci.csv") for row in data: aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], raster_bands_request=['red', 'green', 'blue']) aoi.to_dict() @@ -304,7 +304,7 @@ def test_to_dict(self): def test_for_multiprocessing(self) -> None: """Tests multiprocessing on AOI instances""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_multiclass_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_multiclass_ci.csv") inputs = [] for row in data: aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], for_multiprocessing=True) diff --git a/tests/test_verify_segmentation.py b/tests/test_verify_segmentation.py index 21664c01..815f528f 100644 --- a/tests/test_verify_segmentation.py +++ b/tests/test_verify_segmentation.py @@ -12,7 +12,7 @@ class TestVerify(object): def test_verify_per_aoi(self): """Test stats outputs from an AOI""" extract_archive(src="tests/data/new_brunswick_aerial.zip") - data = read_csv("tests/sampling/sampling_segmentation_multiclass_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_multiclass_ci.csv") aoi = AOI(raster=data[0]['tif'], label=data[0]['gpkg'], split=data[0]['split']) aoi_dict, error = verify_per_aoi( aoi=aoi, diff --git a/tests/sampling/header.csv b/tests/tiling/header.csv similarity index 100% rename from tests/sampling/header.csv rename to tests/tiling/header.csv diff --git a/tests/sampling/point_virgule.csv b/tests/tiling/point_virgule.csv similarity index 100% rename from tests/sampling/point_virgule.csv rename to tests/tiling/point_virgule.csv diff --git a/tests/sampling/sampling_segmentation_binary-multiband_ci.csv b/tests/tiling/tiling_segmentation_binary-multiband_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_binary-multiband_ci.csv rename to tests/tiling/tiling_segmentation_binary-multiband_ci.csv diff --git a/tests/sampling/sampling_segmentation_binary-singleband-url_ci.csv b/tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_binary-singleband-url_ci.csv rename to tests/tiling/tiling_segmentation_binary-singleband-url_ci.csv diff --git a/tests/sampling/sampling_segmentation_binary-singleband_ci.csv b/tests/tiling/tiling_segmentation_binary-singleband_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_binary-singleband_ci.csv rename to tests/tiling/tiling_segmentation_binary-singleband_ci.csv diff --git a/tests/sampling/sampling_segmentation_binary-stac_ci.csv b/tests/tiling/tiling_segmentation_binary-stac_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_binary-stac_ci.csv rename to tests/tiling/tiling_segmentation_binary-stac_ci.csv diff --git a/tests/sampling/sampling_segmentation_binary_ci.csv b/tests/tiling/tiling_segmentation_binary_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_binary_ci.csv rename to tests/tiling/tiling_segmentation_binary_ci.csv diff --git a/tests/sampling/sampling_segmentation_multiclass_ci.csv b/tests/tiling/tiling_segmentation_multiclass_ci.csv similarity index 100% rename from tests/sampling/sampling_segmentation_multiclass_ci.csv rename to tests/tiling/tiling_segmentation_multiclass_ci.csv diff --git a/tests/utils/test_geoutils.py b/tests/utils/test_geoutils.py index 49f47f19..083b9dd6 100644 --- a/tests/utils/test_geoutils.py +++ b/tests/utils/test_geoutils.py @@ -10,7 +10,7 @@ class TestGeoutils(object): def test_multiband_vrt_from_single_band(self) -> None: """Tests the 'stack_singlebands_vrt' utility""" extract_archive(src="tests/data/spacenet.zip") - data = read_csv("tests/sampling/sampling_segmentation_binary-singleband_ci.csv") + data = read_csv("tests/tiling/tiling_segmentation_binary-singleband_ci.csv") row = data[0] bands_request = ['R', 'G', 'B'] aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'], diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 87f80949..2307da60 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -7,7 +7,7 @@ class TestUtils(object): def test_wrong_seperation(self) -> None: extract_archive(src="tests/data/spacenet.zip") with pytest.raises(TypeError): - data = read_csv("tests/sampling/point_virgule.csv") + data = read_csv("tests/tiling/point_virgule.csv") ##for row in data: ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) @@ -15,6 +15,6 @@ def test_wrong_seperation(self) -> None: def test_with_header_in_csv(self) -> None: extract_archive(src="tests/data/spacenet.zip") with pytest.raises(TypeError): - data = read_csv("tests/sampling/header.csv") + data = read_csv("tests/tiling/header.csv") ##for row in data: ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) \ No newline at end of file diff --git a/sampling_segmentation.py b/tiling_segmentation.py similarity index 97% rename from sampling_segmentation.py rename to tiling_segmentation.py index 5d2be78d..6ae709e4 100644 --- a/sampling_segmentation.py +++ b/tiling_segmentation.py @@ -339,14 +339,14 @@ def main(cfg: DictConfig) -> None: # RAW DATA PARAMETERS data_path = get_key_def('raw_data_dir', cfg['dataset'], to_path=True, validate_path_exists=True) csv_file = get_key_def('raw_data_csv', cfg['dataset'], to_path=True, validate_path_exists=True) - out_path = get_key_def('sample_data_dir', cfg['dataset'], default=data_path, to_path=True, validate_path_exists=True) - - # SAMPLE PARAMETERS - samples_size = get_key_def('input_dim', cfg['dataset'], default=256, expected_type=int) - overlap = get_key_def('overlap', cfg['dataset'], default=0) - min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], default=0) - val_percent = get_key_def('train_val_percent', cfg['dataset'], default=0.3)['val'] * 100 - samples_folder_name = f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}' \ + + # TILING PARAMETERS + out_path = get_key_def('tiling_data_dir', cfg['tiling'], default=data_path, to_path=True, validate_path_exists=True) + samples_size = get_key_def('chip_size', cfg['tiling'], default=256, expected_type=int) + overlap = get_key_def('overlap_size', cfg['tiling'], default=0) + min_annot_perc = get_key_def('min_annot_perc', cfg['tiling'], default=0) + val_percent = get_key_def('train_val_percent', cfg['tiling'], default={'val': 0.3})['val'] * 100 + samples_folder_name = f'chips{samples_size}_overlap{overlap}_min-annot{min_annot_perc}' \ f'_{num_bands}bands_{cfg.general.project_name}' samples_dir = out_path.joinpath(samples_folder_name) if samples_dir.is_dir(): @@ -381,12 +381,12 @@ def main(cfg: DictConfig) -> None: raise logging.critical(ValueError(f'\nAttribute value "{item}" is {type(item)}, expected int.')) # OPTIONAL - use_stratification = cfg.dataset.use_stratification if cfg.dataset.use_stratification is not None else False + use_stratification = cfg.tiling.use_stratification if cfg.tiling.use_stratification is not None else False if use_stratification: stratd = { 'trn': {'total_pixels': 0, 'total_counts': {}, 'total_props': {}}, 'val': {'total_pixels': 0, 'total_counts': {}, 'total_props': {}}, - 'strat_factor': cfg['dataset']['use_stratification'] + 'strat_factor': cfg['tiling']['use_stratification'] } else: stratd = None diff --git a/train_segmentation.py b/train_segmentation.py index ad59a571..7e660dbb 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -517,16 +517,16 @@ def train(cfg: DictConfig) -> None: # PARAMETERS FOR hdf5 SAMPLES # info on the hdf5 name - samples_size = get_key_def("input_dim", cfg['dataset'], expected_type=int, default=256) - overlap = get_key_def("overlap", cfg['dataset'], expected_type=int, default=0) - min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], default=0) + samples_size = get_key_def('chip_size', cfg['tiling'], default=256, expected_type=int) + overlap = get_key_def('overlap_size', cfg['tiling'], default=0) + min_annot_perc = get_key_def('min_annot_perc', cfg['tiling'], default=0) samples_folder_name = ( - f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands_{experiment_name}' + f'chips{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands_{experiment_name}' ) data_path = get_key_def('raw_data_dir', cfg['dataset'], to_path=True, validate_path_exists=True) - my_hdf5_path = get_key_def('sample_data_dir', cfg['dataset'], default=data_path, to_path=True, - validate_path_exists=True) + my_hdf5_path = get_key_def('tiling_data_dir', cfg['tiling'], default=data_path, to_path=True, + validate_path_exists=True) samples_folder = my_hdf5_path.joinpath(samples_folder_name).resolve(strict=True) logging.info("\nThe HDF5 directory used '{}'".format(samples_folder)) diff --git a/utils/augmentation.py b/utils/augmentation.py index 98563621..06d05699 100644 --- a/utils/augmentation.py +++ b/utils/augmentation.py @@ -1,7 +1,7 @@ # WARNING: data being augmented may be scaled to (0,1) rather, for example, (0,255). # Therefore, implementing radiometric # augmentations (ex.: changing hue, saturation, brightness, contrast) may give undesired results. -# Scaling process is done in sampling_segmentation.py l.215 +# Scaling process is done in tiling_segmentation.py l.215 import logging import numbers from typing import Sequence diff --git a/utils/utils.py b/utils/utils.py index a3e3c7ad..9411b701 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -479,11 +479,11 @@ def print_config( save_dir = tree.add('Saving directory', style=style, guide_style=style) save_dir.add(os.getcwd()) - if config.get('mode') == 'sampling': + if config.get('mode') == 'tiling': fields += ( "general.raw_data_dir", "general.raw_data_csv", - "general.sample_data_dir", + "general.tiling_data_dir", ) elif config.get('mode') == 'train': fields += ( @@ -493,14 +493,14 @@ def print_config( 'callbacks', 'scheduler', 'augmentation', - "general.sample_data_dir", + "general.tiling_data_dir", "general.save_weights_dir", ) elif config.get('mode') == 'inference': fields += ( "inference", "model", - "general.sample_data_dir", + "general.tiling_data_dir", ) if config.get('tracker'): From cdea293bd810ea704005615353ebd76e49e2f898 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Tavon?= <34774759+remtav@users.noreply.github.com> Date: Wed, 5 Oct 2022 15:06:12 -0400 Subject: [PATCH 5/5] test_utils.py: add tests for updating utility (#357) utils.py: apply update to more params --- tests/utils/gdl_current_test.pth.tar | Bin 0 -> 47855 bytes tests/utils/gdl_pre20_test.pth.tar | Bin 0 -> 9391 bytes tests/utils/test_utils.py | 32 ++++++++- utils/utils.py | 97 ++++++++++++++++++--------- 4 files changed, 97 insertions(+), 32 deletions(-) create mode 100644 tests/utils/gdl_current_test.pth.tar create mode 100644 tests/utils/gdl_pre20_test.pth.tar diff --git a/tests/utils/gdl_current_test.pth.tar b/tests/utils/gdl_current_test.pth.tar new file mode 100644 index 0000000000000000000000000000000000000000..298be2ea3e1d1170be16449f4b3d480282f38424 GIT binary patch literal 47855 zcmdVD2Y6k@wf=27h~7I0u)zjPvQ00JF}UK&HU?A0k#y9ocIHSfIGf&kZ=r-vfKWmS zBtSv}BtU4P69^?hD1ih>=!E}oz59%`j;x#He&6%l=UzXsH0xcn=3P^EnLV@jnK@;t zr3VdKe)&O*{abC&8iR77zPV*yzNR79ovZGe+csg3RR$eU_V36g)54&kbC$t9-_=>) zTv%etpgGH;b>%{?y;xYXzOy~wl&kOTXsm8%sqeN#OH;)-l6z*!!cxPIDlDBAmKi!{ z1^jB;^W8b)D=h0=>vF|>1rFW0mX3UwR?1RQSPtLi)4~cv=U5FwzOlA@VOPFTQO}>2 zj;4yS9SbW8E6!OFV`p7!zP>xNuT)=KTi4Ul*4@%kTU%jqRTNfM5uIJ#Eu9^?w!$hW z&RGhtx%q{KRi~?djcvK6Vqvwku(~Q(><&Hk-GwzyEY7JQqpY=|v$MOf=1Fr_Q8dNQ zwt4wbTc4|M&KK6IukUPYBLvDXRy5=rb3JX{8TH!DD<4-_Cn>Be%8t(Yh4s?H`dOWZ z>|Zscy=q89VT0oe8;WVvh*6`eM(kWQYNx_Rih0PSsv*;=hRiB#d|Y7@AtQIF8a1+N z`<)7#_7nwBjjZOC%`47XiQw{W`SyHAsW~=lUa_Zc&T_!DbQ^GC^XYP{rMtRK-y$sx z&U(3dVME9jhBRBRZJ8Ff8rr;Uv-QHzv@mRFb4Bxt%`4R_PqlS*#U#saJzY`Qx_Olo zn^$cf(o7R>qvAt8*U;I~wy-e##A5TBCpE9ttd^)s3f0X7+N`*mH*a3L8UC7~YMs#2 zQ5Yfik!fL6WZyO|YzMn%PI@}VwSKz=@j!9vsGNvt!X=f2N zw1mPgZU}aCHsp&Hv$_N0YC7qP!md<%x3n;tAY48T)xHt>Ju;0+3%gr%e$e%FwA8DV zvBWo4@%7T~VR7#1;vAP0#`o0G5+mnuv<|AB#{iu$XBf%0&Zc4wZBncmyoiCbOB^m~ z)JtO{ElljGE7Hq$q!#Dr^Ie5WYQ@QEVXxSVQ_{lJeyupmvZ@`Y5$Ck59Vc0Q(_MTs zlETcz+i~x#$E+cfW^1fjX<;AZvNoBW7WTFA(pUXvReL{I?3|>qzp6dCdBU8nVr|EM75g}fTLF)cL325wFZEk@gSkkqztfCWxC94;YW z+tt?7)Y4I`W)QX1*T#}6_Aqs7sod;nYR!(OyoEZ~h1!-B+86JNj_h#m&V{CYcWt3l z4bqhs3b8>#S}2w`NOPwZ-R+9*Nec7S3XOEf&eSw(ga4=s)n!#E%+Cg+opLl7X{`ln zt%Ygfu-ICQ(!%%u8?99;C~GW@+kD8{D$UwzfmP^mSD_=4!jX%&)loyKx((N()50;a zeU42F#|_jz$GhTAND3#aeHPL_JN{?wQwln3o>I_-lNd>M1!>PMsIAL&*EiP|TMo+? zPIhOR>})id!PArv+3DyMLOwMuoEAeqJuRF;$Yp1pwoWdJg)=#n&&=B7WNVYNTqVv< z3g^U1Oee9Vy5-!oa9-?|^V7lwR&MDMYnSqh!iCK%P0vbP(n`F@m3VPd_yMgvmecvz zQpsCTvK=mVGHcGabrmiV%cW`IvdD6ITKJ){_-hF{vktw23S5z$i!L#`D;?ccN#W|n zJN26EYMSfns_n?N=L^@Whp$Tu*T?4iQChg6yoYtn^m^QAL^nC2A18&IRi8`g?$xPD z*+~JCv~Wx5Jk-pEjv;=lNN-CEKZ&Hbr-eKEk=l8yl>bhnyvtGkG%4KOOF7(8S|e7+ zP}9nmVr5t7{CvnaRMsu5)HS%eqHs?snnrF?h~{UC=I3eQ-Wbh&Y2p3>qABJ7g+=pA z7tI4n;a9!U>`gQ?`*BtdZSL;wD(+HKQ|iU~VHLBB`O2y)=~WhWWp`)gyq5g@EXvAa zSH8Zbv8BG!Y!!uH6SAFCyYuZ`ZMp7z;XwucP+E962K}3~@Y{Ytmw7y5VL$4^ek>{c zt~cyg`!{0N^l8QD(UmsVTAF%7&TiQ%n?h$#SFv*RXzoMXyE>`H!p@#hSznf^F?6<9 zb~l@~lz*uHj?V5YhgDPz9y@bp?VdBHPnbNhc49JP_RLuoeOy-h3~m`iOGCag*Rim> znVXHumJW){mzXMt(j8nM!f0*+dh%9cRvL-YMM;Ap>$J3@@OXBp**&USz1GZyCv?C) znHHY1cC;I~r_;jkX~(hy&Tif&kv%)=p5dr_COhh$utt5>HR^Lo;rYdnx)-veuDADI zRPVi%7G93M_lLCbN_p=!tHWM(W&bfLyrvF&k^^m&v}BcT0P|nuaUV zQ|YjZS^0cr|Djje*cmF@xyaIZq^}-w0 z!|%9W_;XTtx3?E;YBXldn8JJ6;99wNaFr@mcwgh|gS7CMIKDni3m+K=SBmD9vr+WE zas1VB{4FVbT)I>+BjJ8m_v4=o)!d}iy@gK~y9_eR%PxU`XB__{EqoU1^?6$O!Z^9# zDqWpEHOem?@o%w)CMpznDr#5DVoq=Wx%UC|9Rbg2YX;=cCYIWdfT_2iQ!RkG&ruhbrimY+NI?Am(@2thcy5?r_tVbdZ z>my@48wlA@h>u4u#Iq3=7thA1N!SGG;~7Uh2M!rpzcAM^tT<$7uE=z}J#Wd7FXp?4 z6%WSR)e|^khZS}3^;s&13~kRBi@7FDVN(^%J&et%?0VD8VcW)l36&Oyb>n6v(y%!) zcHOA(&hpwnR6aKo#zbsFwC+U@75&JVCTumJ6uSW zkTPERz*~*U@z$UwVT6H>x78|SBWNUlvxfooZMkAGj56FhZCes)*v`zQLA1S)9fbH! zlMBQ||vKH8nL>lIp zxilX37jl3Q-vV-h^*}6+^&r$F9E?;8STz_AR=>&9C(HR`dq1nJn zCR(M(Sgi_wi-SXp8LWL;Nu*&eGPX~fkai)ieWr+|1K6>2q9&mWsm8Io&^R`wj;Dlz zY`N?jy&i_rtwwEr0TtoiL-!zd(hBBmAr(oap&J=P>Jc(ehzqGzF4TQK78lh5)FdoK zT2u^>3Cw0~u4m6+-EB`e%=~O{E+UbJ?;~R;q(TlC;%^vQlPks7W}=z-*q` zck>XAwo?54Lpa9rSzV7Mk%r?eLuq&(FXRLvzOHhC_e3m?_axLLoQzZtSVacBi3q2d z#p-%0i8P#sjCDO-$QeRJ887Yff%8mEj`J+kB%E!aI}lsN&AIs;D=0f1wb*@UIM;F- z^?4-HaK728EqDCk0wEU)DWjGT)E8lL)EA>B;RgnKYO9l77B1nhzm{KWxV7_TB+_s> zGPd&%g~|>a<32{k6ei7J}fSt`%#nd3#5;yE$g;_eG67+Q(NxJvJULK2cwI*6JS|_ zJ4v&652V46j%TM*b>^H2e`6$zK!lx)4t;7s%hh;>h1bO~Rj$p4{ra zFyGeJIX}E5liGZIlFUfL+g5;aze6Goe?~^`cZIws#B<99?)R}c?hjCt@E4@#w)rqq zoA9A<-I45>Ox|bmo@1wmkE{UW{ws+z{0$koKNj+d5YH_axIe|>xc`otgnu9h;{HrH zm)h*V^4p@%tt1=nUyw+{mq-i3E(2c)`C5pAh*u!_5XCo`Tom7;CgGn3`r*FF0EYW_ z{LL=7@it@-lTC3iK_U%HA|v-wLY5X%#w{PXm%-$?mqks&at3;CtCVpsZw0B_vzf^X zmeHENf^uLVVN91>#k(IO0`Nldu|6O>Z^fQ2VcECaWuFznRGz=4Iim zNg@qvA!9ge3t2~q4@WM9vo01F&U&axSRd)bDVv#WAY*nbph;8M(8@6SjYyJiT0?uf*c$H$zRr=15OJ;LKzTWwIMUyKQf5;g)2u6=q{`2#GXoiHzf6D>|VmBNu|%6^jdI zH`FAIM*3j10GizdeorQy5v&L7sv5>v8AiW5i8PEwM*2O3>?y?4%LV#zSRDO$)Fezm zdU`uH;uBgeliG?%HLVjSDre~|)jL_8WQE$Om`oxKdm-bfm?C7Vkg`!B9|D?&$pthW zH3>5e^rIr4p=Me^8Wq`mb#KdPEs>B&!z?qGM#VltW()BxAs2}E#o~zfLrsF+*{LP0 zCX9;S`Re{Ml%JRn0Bb#z^B4RFl1{@x$XN4(g&ZP8{H=LA)!7rrR`ayUXRZD1n|vOM znuMHze)6eUn0(e@7^>RVlS)H_bgjvUypTpAuD03ivkBPIG@~Y=1*xi9=v39uK3g-J zHSAoHX=p>nhHV$pA;eQHVKwRmc2r%cNhlyaRqyE}$dLaJX1PTx*W&3Wk%k^*jAx#Z z`9fSgt#Tox1z2223sIAB7*Zi=WW|ZD-3#P9=7mM(Vgvg7B+`%~W49hI|4QM?b566I(&aUljC<(`+y9XfILh*PuDd-V-I1V#09nYUMoPf+` zi;XS$wuWLjQSeEEtscqj>EUFwDNAV7PMbY-pUH-uV%U;BI-}(&Tbql*si39xBin{m ztY8n$G~477PQy|%G#zsG1$OS@rSFi1%P(`>9uz97@vBgfIRx8P(xXAwY zMR74IGhhR=V=G(GW|iXLa0w_O*&ep~Tw7aRu6}MYTxwZu3|vNnIWRJY^Ftw52q}BW zAs@J}#N@cILS+(c;9?Ist^tMHmSK1r*wxcr%j}k&FX39t&R&La9SP>N$T;GEB;*Dm zz1wi?ydod?Z^Y#IZ$f2Gi;O*>dy8rJc7)swnw{^nSnolP!{leA z)9`c4!T7Y>;d_PLC#0fvv5{fUHZtzFzkU3_Kuy9g4Rj-;)hcd-;{p3?{qQRi%r=p+ zA08C)kdWSfpr+Yb_hE2H^BYvAnn;~@EPBqmtJ3ePT6JkZ(&!Q45#{08pDsV$x#HaL zC{V2TW2BjTB4fQD7xILV0qSkeR_`b6Z(qHiLS+hS;6U~MJ%*vG_%o!^@T~OgYV(|s z=Y=d@@fW}y&x@!@cnPVBTliGm7B=mqH@Z8pF*LlKnXR4wKr#)lAY(hfD&&ts%II{- zc@5mry^flMH;|srwxHS8R91>ccjt@Uwe>Bv#e9=(q|_dPT5V0srXFu*z81@$NT%T} zWQ^r)A@2w&i^Vq2gg=A3Sl&fV!h1*`%SOb)LNZQx?3 zfp0*E>S)mY@3)r8TIZi6n6{d=bcq^7%q-?2ZIQQchhnRIU|kZEV_gb02}>i@Kvu!h zy-;bCgk@yQW+KgYVqO+tv8(hRoh;<#Ixp^}-887BA*^rFSc&1{m;)?V(D!e-{cPQ9=>i8O41jAVm_ z3=vYnvXS=a*(FatkZp;{k!^*VgrP`BHU&jXEj)Z-rAwRL!Z35R8g5O3rD0?w9WJCw zNN)}8E7yg_zATp&}$0bxziOU+HAs(iCQaA0PHX zCSf*`vyuMj%q>bn-`5O%q3@^McAWS#cW-CzkD*oP6N~l$G!~1IG1`NK94w@wRe>3$ zZjItC<+#$R><|pchgxJ34n-24{wTb7&dQmgFT6VCt~bzKSX!;w?7?FLf3u0yTpQ+H zA#Vm7XN@G%&}8<~)umZTi;%K$CLda+6_cZ$i<*Qsq{f-mopHuvfI*`}yHL%|S~%)N z-C-Gwx|2j2x{#5&AS4JWqm~cUMNE#m8#M_%$o|yB!#ttc%Onfx+j<)EwY+6g4D-R` zxL!ay4GWR6pAHkUNXP)=+MI1%f8YN0wSJ14gu@N=<62F@-SrU|H2vowvnP%A6r!{# zFdS)4HfD|@k%ps@F_vS594n;kwu5P_Z41h7J&wc92+<#n5S2)b$D5%q#uHGJaH4^3 zgk*E{lR(QCp-;Aa)=8(3NW-a?p)@{D6LPwcvQCl@#Aje~#Al)=;Vh&&$?D1Yu%(!h z)uXCMgtKK)N0qkFv8lY659e4$3*lT6OqG!_g!6@5Afzk=`4GZ|m|O@Kp(f#Cqz?g= z5JrX{2)9FfQ2qRda0vYjmmwqXNWS@Lqw*@m`6VgsYIA*EUAiH1pnB zO~|*m)XZF1-<%8WGef7n09LGN%6C>Z@K#F|_gbNYT@c}F6`(Emr7KvGt-ba6Vz>rN z9L(2}PQ!J`IDoDf@*^Pw3}$n-!F+@L?HedJq9)-c12vfaj$0dYElsn;KACN*72(I^ z>#H1t{bqon8qv3qO2e(vvn%LrLVhBoqSag(x%y)xH=9x3ZfIXLcc3QWP6OSrZMDkU ziR~`_W;e}x!tzstZOq(FA`SPLwKQyhCgkTr%KA<|G{e1^9PxdqNw^=W>RDwNw!K?b z_4-)&h0NI^2Jh6g=fW>7pGEKhi8TBQ86)_$kOzg7MIav{cnFh=;9=Ax{08YGKpoUI zw<-KqxK6EIp>rNPvO{3kD9KH#mO~U5}4%FpeSXq5t{v~P}zA~`X z5U@ocsE9dF;%qk&&lD$cjSBcw)bfv>* ztiB#z50%Z-29|nw10fp zE@TTKWn{4f2V-(%Lr{~jCDN1Gl_7iQy2o}}Y$bDFe-1@wc^DaKw-&OEkTTlXpTjXZ z+A7o}R3ja&>ZLO>kF#`jDm`Gghd4D>fL)75kYHalGLnxHvaOIZa`|vM*bb8;-ySsy zJ0L}#-Nfn1ftJ*oJ6r5_WJhzdo7tU6q+w^I5!&Q=7a_X}DVsdY2g2PjIl|GXN%)?D zZaJY<#9a5sSV6jp)L;&~TSi+t8%rV$dziU2mi81fPKaMRlMBS-u{h!hs7aWJR86fW z3?-{&=^-dC7x*^GZL5^0!%jQmrDOcUbyz+L1}oI!$&+AbG&081B&1o0k4G-V(}Kmt(~6pexkw+6 zZ7Yb4Q_kF`oEk|i2h`8yey!a~v|u_&u&#`Z!E^~J2=T$lg1iltwV%V zKG{K)tyhG@t!xYH2ofwXBV$-c2{~Gb4@)kDbqp34*0HEbI1brAtbeyAa=d&rfQ#JC zH}tR>D4bxWTVN-WNW)3U7}&`|P7&e*lM8{Jip2$X8fp?wNBY3*%%XQkb3Ne<*%UzE z8ts`@iUn{Mi8P#zi~*b@0J#vrd01Qk=c6X!0;CVXUJZ^TwvY8f_*@I#~z#U2Xqah3M@;R;#o{G&yj za3v`*UqvDfS0f|yHA1cx;+f?F^L1Dp^Yy4n_z}`GTO;(HKH@^r4a%i6y;ksSGJc~K zXKip3i8TBe86&w_$Sp#AByu5=Td}xEZbMDNPmn$mf2QvxyKVc7fqaa5Ej?^PCEbZxPsE!<}XSq%4+ zNW(9XF@|3Xc|eGdK`zAbD=aRCU!x}BL8Ont_6yo$`<@OS%s(Wf2Ba?eHMWQ4VJpTW z_zj6P{1zD_ctpsfLVN^rA%e%SxCnlSnuN!ZJ_3K)e?qpt#q65CgXKvp%ffhyL>it( z#xQ;_Gj z#1qQ};#aUZ;#X0V@JFO0*2wRjC%h)37E`iFBzt3d-3qi=-XOso7#U;vlaRNB_*mpZ zEN^3RvAl!I6d3tm#`3O=S_<^B*t_H5JuA%ud7nfYK0wAm{vzZ7_&d}j4B{;8h;_W{f*zKTrEdyZ6F&eG zmc*hMmLkDq7#U+&M#!>4d<=3ShUKuh7?wv(!U{+iLu-lIZf3IX&Sp8G!kn${E0RdV zO328&vXE7TcwV`{yDApPyBcZ|R!4f?tgA~~LG5I~sj!Z(CxkW3+lbdB!B%NxBwky{ zIzk+=?slU4x>)Ve>w5OL?_uZqs7crWX+hX7G97_~7d7PTdYV}KmBp42NN8tCTiA%C z=3E<-V8s_{k=UGTQz4Z?6iM91Cm-lH!{q2UM@_;O2D(|hK5rEUTPcdBw)B;rFvK!i z8*E7;4O>}`(wu9kkYPex8??#=;;peb;%!iqFdV7Ix0+COyAtX`Q>?ZtK=thDTn4I? zDZ6zHon7{FX|-jyIBH0wVFWV9F;d7VAwCYd5XZJyTpZh>CSiM|i^INl!+|k~=~Z3W z!5oZhM-ps`Mn-B8&SjV$M~-nA9b(na@s=4DgrF(g>yMaJ$O zD`XELW!)a%?rXy8ns}lXZF0~!Js435v^Z3qG$jz^{eHKQ|im&qOniB)d z6`60>7FpvhhM87^g|;^dUJ6FW&}IqQM~DwiE`&B4iwkXE)FkYObfNLNqyebA+B%`O zLr~QFlSsn>$Vh#lkb{JHYPmptFcwFB2x<~)k^QLGqG5QIn$L3N8~E5(UB0bU?L(~q zt9FhAF9;(ecfF7XA)Z?Bem8P9I=fwKpb zV!KINmc*lW-=|@!IXH-L_;k<@=eTrM?)>GEXBt zn*=WgBO~d#Le3N7N#z3R`B)t31*l255ZRB^Hc+)OYqjjgbTJ8Lt;oo9iI7W$cqX~P zbQu=MbU7-sR-|K^IA((#-uz&rJ*%juHhl6`G<|Ed7_KmPW4@9^8m>Y{=BtHVBg8Yy z1?Fq9IOgk6lW;xKG3))VzB8wNjx8fj!;h3Ri-(Vhmg2d=3blA{B$0-jkTISg3%OZ{ zk4G-Va|;$1&#kCQxDEL);^`$$!%vh`k1Fj5o8K9D{#6XOTd5Y%9VFN^jf?@^CFG|< zd_Zy`pu4fSfbKy}!q1Q{ARBph>9)Iw@N;u8u6s%FGB7f7-7n-9LOhpT;QA#N$Mpbe z5`KkrT-ra-&Vt#mh3mna^+->VlQ-X*4G&r&#{UorUeZNI{@)1stq{*I7x*8+;`kp$ zO~PYH#~+WfUe2z3s1Lt0Pa}Pt1dF`LNcyCZr-XP?xj_0f7DxJf)FeEE>_=*66uTPd zxnlDHTdl)CYraPN9EmhMkBqc02zgP6rm56FJBw!hcWvh-r?ZJoJ> z@QM{-?5~pGOs? z$j3r_MdbqVCs-Wur>IHzJ5m+3ns8@US*l{b$cu5db(;1q7OST!J*%^^o52?TAt#OZ zmZlE-N@xo=EDOSCR)z)lIf*oUfsDa@Dda05J~+7$+}BuKaNnRN;aj8+F1u|kZON_n z)J$xZ*Lum`cj4LlS%(3mHS%SjZ+qd?<1ulufa?P%2TA zuo=>YGI7k(U)*`(pPE0;4%PCPo;J4vY!GZgA`OF)aS#jznijaikp3dJ&ZO#Yme`dU}+Z_+hcbjV}+Ep zhkS@=4@@qiJyDY|&OqNDij4LcZw2YuL3?LM3x`XXU|Fp(CXz_QB+F20jLAaw65<;} zE>KUw;;5&hCSe*Q3o_C-4WDQo#Vp#99l$gK6#us<@A9U$aDA)ZVwkR61@ksXYhghP;y zOb=|_@o#rsy#A|9oQmyi-1|tUmi4vzxm}%nUOF6VWn2H}NHCj4#{RDt(jcU)|K&qm zc}y;@MpUM=2KxSw$3wFf)c<&BvE0@wtt9x|mE|c7fi@xSLVT;p1@;asj=d9=FJ2+l zDpnnafV~M-c05e8#gjez>R*t9KO%y8T8k7(@VP5wY>^%z^MsVOhmq0}6Q3pqlFZw|RYeIyn~eH1Eh1tZlQR+oN9 z#4)m|LyEb1`C85qwR6LN zYGJ1(=w|g)b27fuNHC>EM!qwIoGHZf$pyZ%usFW6QIl{E((!5TYt?J0&7Nk~gmaan zZ)$R$6=M;cPa+K$AY%j<3b{y#k3cR&a4{Aa!4FWAa0${!V29me2j->9lpUDyPN{F% z`!XxnhT!ES((ps1g=9CiR|vUMh(d}twelgRt1!8ku0~D5H3qsNm_1g$)=KGrL|$iE ztvRkI!Ccnzl!o99LT(h|n?o*8--N|c{}?q1HzU;?Ru_ig)-;E|S8C&NVqEX9PFVM^ z&reNebj-9bI;)JbLU3w^}I{-fbk9+9G3kw+p#Lh!0OLgm)(v7v5c{ycLY} z;c<_e{a8u3Tb8oJ^Bya~$bUwHH-nLp{9YmV3Gw7|f&6|fj{Fy>yb_G`kDIkLB%ToRq!3?ixxo7r7RUQEY7%~rRJE-l42c1ckY|)h^>578+wH=$mfOO3 zjs){sWDMg4AukH?VaSCrUc%zSco~&>Ez*TCam?%VJ^wgE*<;Y&+k{ul*?3D`lF8 z+S~5-c44hOgW~P;u(EmENLqyiYp+NnwyVi%LRJ@2Hj?B6^BR~O^O~qhSj#{+l3LZd zoHy3Cg7hq|J9LC~ESpt$T@oy}nz1yN))%sY5MN=rK)WFpN4pVf5;jJv!d8#w2@}V> z)2M$`oK59}sB+lEEJnF03BF8)jFg)R*<6UHlnazwU~!a#QIjwPSw{KHw)$71?52yi zG(V%xcUJg;2f-4=`E-42z7SEL(Gb9R8x_UsjFiuRFu=FYu|vi*eavjTkbpxQ#) zK}D5^Y-^c28n1=C6A31?$Qbf2LUtA6LzW96?}o*NJQ|e^%19sbzgo*2qrAEuZ0MZd zVUH@c(iwKQk}asQB$(48V^Dhv87IUCB^QDkkHrNw0hP_mNEehHmUd53$5*L~VG_pp z%GhMme98)Ge0DTW5i(VX_}mAb%+oe4OtZgzn-->{@--_1vzH(Bb=%Sbd2pDCA^V`y z-lSNgm7aZKW0sJ8gt*#eFM`bmb~O8%ii| z=tO1Cic}ccHFA!9jG2XBPTxGIDAw%F9`}&fGbDET7nE7IOnq_LMtE!iTO3_Q(o9~F zadhTK zNBXjt?y0j;Y5Uw+?W)anbadL{Pebpu{sJq-YIY$BK2e2?>=z69fe^>uDi`|x5-cu& zOHq?>8BziC-M`r~t9zn(xw#qR4@vM6E;2G+DdZ|4o>49^UX8^uUW1y1YmxmJ?b=rw z|61k>*O{Z;(Ogf0IV#dP?Hqc8kQ;@FGtQ6Y1L;ke9O;iynWP#x&>hV!R*-Iw_ZvGt z3Ab84tL$wgn0=bHbW3x)kUNC<%E|@aJFz(4yHJ^ZB2`(dNNHb!orNaMNXCb|4KuQP zNbs>`WF-5!kb8xAGPyu@9~MV;KPn$vMtU+10@W;;F?MqJrFj_H10>iYjErQz7V@AF zPbL@09>U_t9!BNOS)?bk=RQ`mJ!Vgy8h&dYM)n8^rklt}_Lz|03Grldf$VWCj_e84 zBs_`qWE>`T&t&@^vXPoCe-y)0=554Jli=-HWF&q@$g@H`v0Na24vQmx9+fv|ks|Ip zJZ!M~^|0`wIU46nBzS!m89DzTwmLzKzB4zJtn^V5H;KQbXVV2i+vH3f;ST>|JG#x32zm zl<=MvZ-?*uB-kU2w8-r6{fm$fg($Lk_{xXSKEmWe`ztD&gbj3uZ#K34*hm5Xxk|_IKZJZH#J7uFApabTBmV+5311@BE>;_kVcUvgzh@9K zQte#)%1p-jH3>eNjEtP$3i+oH&nXu;zr*4<2k~ek2}>Xyr+;*zozabLEnT%@UXq+* zUWx?ksL05?jF4r8cxJi4yc`zCY`+7{=3wN1!K{To?pySOgrUN`t;#EsV0SPw60a;| z6(OEjE)cJZ#SyQD%I;vKBesdQ-57Rsw%Cqf^DweCNw6ar8OhcbvW^f>CKt%o#p1}; zL*?Vi$bMvYR~ER5?`&s7Nka>NxGS`D2S2j70r}&d*M_87R7F|{HmElivWXCd;O@N4 z)9$=BwZDCLUX`dUtQzR=yfno4&co&yvOBLWNU@A6J-hQ7B4kS;uEN=!*H*xeW+*C4 zs7Te;Lbr+aJr~)(i@~f-w+3ahii~YKTu7A=PoW#IYG6lEgUUytke*_9TGRez_cQi0 z1a-kys5^LvQa>yl*lAj7+>utG1u}{RvsYvcWIG|-3vq$8%7s?l0gH=eN7N+jgdCuA zxl%OeYx5oaSW|tzw%(4Q?1Zti6=M9mkYFn@GV<>xWV8^^FBkZ~hsE)aK~2K$$bS5G zrDFF+d*7>)W6jgfk$aF}H!(5}r*T5Y3lV8NN6H7*378z~MARfqGSHnPTeGpmmSQVN z*XDMY4IJoDt;10IrZOqM1|>Z^RP#a_g}7Q~hiVhB zqiIIvo61O4(L(29V~Vw4D=Je_WNg7UA?-qZ**Zo$fL+<0sLVr=zU(D*aK;;{^7m;A zR)7u9K!Q(0A!8GE3+WLe;y66zL;KCc15w&mS!<%~Gnu$x_8P9|F%v!{^e?N(%L z@Y95xE=0_3%$lc-*)!~K-57>gsn88r#FAVuCcB<=BLX-Kkrqr29g0OY%v5jQr5TdfSEzl{VFQ)HySUC13m zJiT0?zY~k2zYCR#DY75EeR;ssmp1=}yUpLI?;*i76&b02F63Szo?0$Y--pFf-;c^P z6({8qYE#yHV;*6&s`9S&*CP(@(D$`X17d!p@ z)=JS4nJrkc<+Z-nzNHr)vHVu)M@jHLt7R;mfqp0CaUs6aa)JK|ERO$4RKE3uRHdy# zJX)PXV`op_t(j|zT{)I=^4&GFEWxW-wPR;Yt{qi7HCL?NH{X;iR*l$xhmqTFzw?gU zj@W5L?TC@JGe?XVS+(E3Betz>Epp;|TIJ6{>ev4mG$H)nsH_gpklmV2E@B$WBhZj-#05j6pL2ofz^nD+scv)GpkE!+U!BH@Ouo5kpS4c4PMaE$MDC9LE zJ{Y+W%)U+YR_b}$B7^CX4`aE_@_A<=XWId$zIMb zk#h;stn~`AC zDuj4qxj?)k7Dv1iY7$mPI^rp#sIA>z^A=TD#T<-lRTBJuFEVniE@TZMo=Yxpt%=2P zt%b_(_aZ$P*IIRF`AZM$n5(g_OM+EjWMo}m$Ob|@t6X5+5Q}5o2$fY|q+``j*7@yJ z@huNq?dAIv!Pb5g*1!=Z{;FmNZSY*F2si<7ecDS;zFuMO+pRQg~Ts#WOJYT z=AMqZ*)y3D=4X^6N$|5p$Vj=ZknMzcO1VI}Jr+l~18Nd>M0(2Ip|rVtuBW@xz2oM_ z|4vqnMX)moKEsTR5$q~tHz7U(xe&ouo^z4x;vFt-7&EFVx|?OD}3)u;(J?Ot8+qvA1|^TrMr}Ugv=J= z>nsD4_e15!i;$|bRffBiO{~K9-UIiAOpmJj>do(YxAS?beAr(u8vA|LZ~$nW z&>TpbNjNe#%)vqq5u(`KgvLBI3Aev}6PiO&nS>kYCp3za2~8b_?7OJ-q?m+D&$=Nm zq)~|PhSDq?-0?J{G7C4*_kq>H_WZZ~-cVO4VJwl~t6Np7Ip!YeL=TY*F=) zU`mdRQOy^!K!}e@E=08ui;L?kab>}b>^9D{UZb9h3|K7__LzTg{lT2)Uc!L-|QlzRJ2A!iBkRh0`>JsXSTJ_j`k=OR^A zs|me5C_j_lV@pnZ|}3r^Q{<*?*bBhsu>yMyGY2zLVSF3A-*4Aaq(S(nuJS{ zK0f;Z+2BQe<8kp5G4J)tCw9^0Xl!vt#xA--$dy9Mx=2&?tH50tSEI7O*}#Fi=voZr zJ#?LAwHCOZ1Z&60*a9~QxlxF30lCluH(_zqKSpKg7^xPpiqJzt|7{O(^Im$dKinc` zzD&-81)JsCl&PD~FSt|XtyZjEKyD+!%pMsV>UJS_2=NWoDjx#A6O#-0E>x!W2Ku&& z7m&NHpz_AL$MRZZ{fq<~o-IddINvMeJ|VubI0F!wW)Q6rxbvHPk%q8v2s`?YoA)jLNQP1N}8r zVQ~$8#SGTyuaaOJG%}LCCggP?zR|Nw=o{dU=1o*KK^y2>+^S%g(6{(o-r#SW#Txt_ z68r)ZGB)_TLf#YN8(c2b@O>A}trwzBNM`}w(*WWDkQ34Zbj8GGd`Azusez0xWlqWA`ri{e{Ue!JH| z*D=|H`0uO~bxY{%%)U%Bh~`lBmmtB685yg;l#r!``0C3A;$^Tn;$>0!C^J&kx0=u` z*>8mQT}mP`vu5R7>{-nUR;u+!1qtTM$QaN{LRJ>y1Ck2?t%Ah`v??k;=!^6LWpli~ zfZTH;zS*qI13rvd!wR>!)+E8=G&07uwvcs%__*XkTMC>F;$3^fT`Bm1%1(-(auzQNw~ z%zkXA7`8EQBOXqIX*4ntR|}~T;)&$~@dzxAcqA(C{v!Jk+Z?boFp52mjV%kfTC%#cIn@8iQ>@ z+J*SK%LVQZERMSqmGx?*>TZ>3o?s^%5?N!nwCC8CQ!s-u1`@1NBO_zCkRBnPQ7$mf z!{Qj{qq0bibd1@P)`qsQP`Kuy{BB=6yM)4F=5N%CNU%hWjMS-+!-aTixj=mc7Ds&~ zDofNzPrW6-@l;dS(ow_LA)70!>MJ*|>{{5}+}TlC)1Gf=$?@|xr87Mb02WqORcQqa z*;twV{Efe2_$yg8uPl*0gZ@q4!He9jNOix{UUhE-y4Iwng6n!(+8V0Zv0llS6?sU( z@4i+0&ArvOLo_9V?0wxTR%c?B{Cm8$O}P#-@sr{FmO@8km0js#zW6KaW~kge9Id8{ z$D*B;vV-9mYiBzcjwO+X0XOGBREu zW$bQCM|9l5%;d}R~+%!)b@5@v5ICldD(Gwt@V=~ zSJ#o?ccqc>xcZTh8-$b{SMs5cZ^Y#K_$E|-SK7dVj;oukAl-ATFhui~Jl=a82Nbvj8$k-nb z3VBFKS%1ifFdoL_!uSm;KQV2f?+-N!x1Wz#LAn>z+Y*mjKCAv?B=|vTvz9vHaUo9# z@zs|Lyia0ryicL>OVUVH-zq{UJpbQ3*7Kb#gSFefX#1%S-e~N11ok^O{FnLc#88tB zwi>HwwbI*W=Z3c2!hHBQXNUg}N1I*1i+Ov*ZRZ5rn@~F9%ME*fT9;?v#Vh`U43!_JEpMm*?I=*sh`9=w7Z|TT|d@i$d^0IT4U0!=D zHNZvR0!sgNwaO=frIUzmG}L6OlC@^-C{8WES66E&>0Ov-tjBFAJxhX@F_3X6Jul=1 zA$};yg`xB!7B`e$LQTTUNI#UG`(H4WtfGBGsioup&Pdwv-;5;L|7S;%HKs<=AJo$t zNqW)g6&nsVl3pdjV@zZmNv{ccU5Fn^a$zLBfyIraH&J_uwte)$WU|qnhrGDK|$VNhZJ>>%F##kKbCa6i+6sdYzEw~}q zK1;Wd8!DASW81zut*O~&me+#VoCM!pLB=2k3mGEB2O$@N*b<8iVk^`n3`P1NY&mc6 zBK=ayF5Icxk9Ak`)6)F@qcUphrlohS1$0HZwslnpn z8G)LFkw_m;wrrHe!_NolmrYwTZnoelvwnr46q9|ZDDZ>A)>hhryseQ~eA|&=^#B>; z+d;^VLVSF3A-$ z-G%tTyC5lJx6H{RVbz=_k0;Fit)icud#qH@ve>9d85{+XNCU zHXvhclY~qb;$xEwvF(M$#Wn?%#RjC0%}yh{2cX}Mvk9d!hiS^BC;m3twP`<0x8f|4 z86;S9K*mV+7Lo|@k;sKeW?^xW?1Rdh1JXxg>kA9{P4KMV`^q|p^QbKy$4}4iX9Zad z)`zS*AY%*%2su!Qk3lZPa1a(3!@;PmIv`yPy<1J0IY!D_^E1jrNu(i%jFfdk>VZM6x#;n#|Kkn@RASk;q8eDrBw@PbwEk+psv&cGM(v zARVcGj%0whGD4^FmB(WrLh7i$h1rrYo{G=ezeY!qHZk#c>P?UbI8TIF1u?ybvFUT!`ZYEG~`{QF+l0 z>EqafIQm{vDAx478sbkQd>Tf6cJ;Qw))kn7*h9|7mZoru@mPqbl1Rg8$Qa`3Le3E4 zLzD|4o{7bUcou3B&PG~@+?r`Et9Se4IhNaq&LzR`N+KiC`9dxb;)&z}(S=wX(M709 zxEL9U^bS1_n|U_$1Hl6Hi&xfNw94X8F?=k@S1#~gfyMD&iJF9~kdc?g zy4GQAa8dbJV@Sg_$XNchLar0y%a;q~UysF=|0C2S+<=Vbw~nN6dJQPs2hf!LHo}eO zXmz@Y1h3v9Bj?RRZV}=+ z|1+m7#dulXp2PS1x9XswgK})qW?i9%TMOH#Jz8w06K(cg=0<@9j&reG`5EZlp@lyA<@yDa$T9yp&^;rK)=KDju}> jcTvf*#?P_@esgwbL^@SG}G| z2rxpFjfr?4D9E9R3W$K92nvb^B67$fpn?Y~;t7H$it_)x*E72t{y{&<=X~FEzxq|x ztK+?T_3Bk`RwL^}p|N8_|L@lrni6s;;iY%!NYX92ZN>iVx~_)Mj)Q&|)>TZ%ij6{b zT`VL#GeU)9qme4QNl1ge_bh6r|d^(?sv$td8_{cSR$;1veQfrO_J`j!g=B5!d7t`kc6%Px@xOs*i<%OI;_Ap_>U(u$;!CdM%Dp_1$On}%SNI>mG2mePg1nG|D*@DNkQvTquzT>XTb)eeu+O#TlO zyH%U8nX+K7nQEz!8lfv@niXpRQ;~#|MJJp@A#0`|n$Je$uPJ21W(HWyte7J#=`U9^ zO=@avBsg*PS*k%Yq)$gy%u#5~Hb+}(s+y$L2-_UPwN{vxg|6E&~2}i z+^}gz({rK~bF9eQ>B|;&JAP?TR-2an{ysA|4xT*MNgXez9G-D=+yOzlx_g~c!GXxk zya35%rFxmtW`1lVDjlY32)jrwOXa*DF;;8_SXSH6=M*)iS(btP7sMKoNzfc;)v|Oh z<#WHHW}&D>wGP+>g%^n)vzi5^T4NJY+1mq_g<{F!=FDOhNOBN`nnEMmun2fz@RBI` z!7_+MMUhhJlH<9)XO_gqpv1D}I^G%YH0|Yz>9D-%9t*j%VwPH7lbYboQgt3gxy=t~bs$D=pQ4WP_Thz45kLWvNL>PEu2~H^DZ?TdEPs zM%AFb2HTuqsmVxAR%5l-Xqyu)H3i8jYJ&DA+vX%zlsf!T_&D2~Y^iBTPUGY}+ni#l z=}1oJBH64$tfby-kJWA+)UMlWIRG^~oz;vjT409EdNEBKED;PAbkRxp zyPQ7Q1f!xNfL(6ZfxHsX0J9P1a0)7yeGT0utnR*aB3R=+s!pU)*rh4!7EM$&iAqR2 zusKdnm%KvKY?1~y$EJf#E}eJWe7*oPp_7A+Y*8cpis`LMa%;tGV@b|~BwmQ%^46l_%*)@+Bz7G$Ry=~m2}YV7Tbc{8)05B3er z-r*&Gwx#AnDTY|PU>)~lr6!m)+nyZEokJS-R_7DiBo9CVb&YRd#iW?;oSFi9mZxl! zw$vQR=U7#zy^XfXSWMy%1?;g+zolA`Y*90`7qv~+Qge}CoZiFr$n>@*kD$ff2&X%GdLafNoXx0bcsyj58B z6!M8uA>VG^CM7pT+m{3-FFGjQVFsi-o1&d&UzJ*6-d?3vn*CCG%bHs41G1>Lw6Ah> z(NA@DEN@?B-XU#oS<}&BF0PR)%{#H7y{o1<@2;5lu;#Qvb0)GzzE^gk_gShHqUK%b z{Y=i#s?IhaVDf0KM%w0sOwQG+-Zmd%vQ?{5wz-7Km0FFq&4-yhRV&!Rk63CkGlLDh zlu0my{kx3G7Oldz`6!c%w1OGD+)`~^7AEitCXdu=qHV5Z($Wf+@G2%FT1~Re$1D}$ zvM`7rXL7n$Fost%d8}4bZ1V{wJGGi>n`@XnVVJUh63Y6iV2ypcrmVis`vqFR)?6D{ zqg)|h@*KE1DQ&KciiM#<0qZ-La{3C{q$YEH?Bv5%U-Oy3NEXU4eC7sOD4(@N(nU9K zZag4luADVD9T2kzl3gZqvnccDVk6LnsJR7}=GK~AKVLClV7V@@$@Mn2G+(sTa?IF} zrMVrma7Qo;Uy8$`d|9+8@HDC(r@8acu13`l4V$}Q=e`0yU$vxjwQ6_U=4;kLRZE3_ zOmNMC-Nvb~`8sNU1GT>yc*58K%G%r$TZkdGG_QfP+00T&m*7h`TiD0bIk>XPX1GDk zaJ0(jG;i9vAZ)%RX2pm7EjhBz?>(qxO;n)pI&3sOFalxoZ8Ulx8vRaS2gKLBKUQK% zriy)jq|aYg^?1eabq?F_O%(Ed>69N4SB8BU2c9rCsy6oR%>!}>NavIK9P?mLXe`T3 zR03{dIhQvNsj20C=DUGK)%j#F#D_f|#v>TUqn3BS_kj0c&qA)c!h1-CQ2wzv9^Z>c zwXky1C5=PH;}!FSm;wiU%#*hHJ_`LHj?MaqEdQr!`tx+f{D}2uDfCA+!XI1SjkfuT zrIunRT&l)t?>yJ#%8rH!-p@V-8*1%tQ4*u%h>9J1#4fxBJ2R;6~<*}On1Hdbm zcTa55f2EvZ^H1+yk=4Jz;MJO}wy>;LLRPcO`zVAZ9QQ<`Gy-(ALUnl6%AU_O{4Kf&KS<|_t98+S7;pG)oG(J=n9P&B$zxFFHHd8J!jKI z&cFoK;~yR}6x0A5otnhO*terbP!MOiLX){l;0ej$r+R};vjph` z0tMCHvo;+m$O<4B&XBz6DBu{%(Rc+zIR;d13y?D=7ADQ+Y@DR28H5#hm+U$^$#epz zQ5=+M4sv3b|L=@O_3+-lq+_X5^>jOyIq+i_iyzzWER}<^BS)TXd^nCIJe}w-7Sj1r zTe0M|!5-4F+)Fv(bI18CO)Y?;a;B8Q@X0UdWx4m;_!oOuW3 zl>E|Q&V0$?LuTA9B|Hg1kOfekD#m)~Ku-&R?iY}c?Gg;}X(1QAL{0@&x0-_(vbl2^3)1U%*fhwLepYS@qlZ)1yLJA?2fugH(n^`lN4cu&xf^;1-AK;rX_$6 z+BrCE;1UkNZYjJ~=ygI~`KiM&z9404~rxidCNH>9?U&$r3l4}NIb8tAORY;3D)O0+PJSIn1 zn>m3C2QvdqSe^)Ie-Hz}s5SYC0%qz{=|n=BZ=grlEGvkT>L zcH@~zxjF3IYq(<2gF7ze-(>@|7T8`o4bKX#!?W6hk>V%%Y+Nfr{ePNa=)6rwh)r1F>Mj9NC&_J>WoUoaL5>f=;vnSz;{|h~SLl z^<~pVh`SJ{jR32Ix*Ij|8=xLwdnt-%g*I`eU|L80mxJQptCtQn?6^9C6wkeypv@>1P*UbyLswo96<`oGG3kB+t0l9v=cXU4p0u@UdrQH zp@I~UWk2o^`dRn-i>@ZJn_oI%UGMAxDhgL4Txp1F_z+zG%~Gy$O_eXZ`03{d$QQ09 z;aWait`8q|+f5gYYaXCo!gaTBJ?9N`)fL*qm8)uWuCO~#kS4XG>buhU4B&)D6(M#6 z9xT13nwrT@EL~7d;m|zNrVFd7$S|$g3k+KER=mWarMH1%t>BomUrN$NoWn9708x|q zRgZ1&o$>OF0on&}FTEYl3hkEyBJ)E@=ReE*ile?M2Q%KQVu?t_9fshyDd zIChFZ!~o0u5+QcN-^QthK3q*n$b~*qO(CEGnO|B>wGET`WxycwkKz@`{Blr%XMP12 zuXf@}&=tB$kU;Uo@BbJB88!treVl2v1GZ)d`jTN?vFU07cLLmrQ{IpP`UG$^dktQJ z0s17UpxIAxm1?t}23?_R1!+zHz#9bwpt=~GFW^lA0#Iuhv;@3aKmY(6bY_8opA&E` zz_ko66z~=SPXl-wgNp>bRls!s*CDK$X%+DE0(JxJ#zM}t3HSv8PX~B94mz170^TOz zdVuSBBVhl_rY{P(0pJGSGT2A5>2?7XK*iP92zZBp8v$;Fa5CKjen~)t$W@ewdAfjK z77zeH@LVt8odRwGxQW3H0^TLy7J%5vbS5p}R|Gr@Ak<@+7JL;rwBT;M0xkF&C^k6k z=B4x1Xf=JE^LY#U1_&!Ql`|1TOQv!CM}wCD^iAXjVKzi^4mxVxBRTBsB)oFIKX4wu z1(2^a`Cf)xzhf6X>wAIm*8Od~YdUxz=n8#DkS68Ubnt!#GTf$34=}A3I|sPY>_XyV8DnXqSGB z%kUI^4@7X@L;+ll^4Up3kn}hgVdMP-7s1u!D9EBhPjafd+`bRGLO&3sNo6q9scdR~ z7*w$7DIqfOFWA&PT}=h1=10}k5yJ-XV_+D-Pw)zi%TGb^q5m01Bu+v<3(EW)FOHVb zFN8qAgq{s@eks5`u~`^&i^S+!oBPuFwC}Yxr}GKYeBl?S=TIyNl+dq&ioceML1=4p z141Afgg}l&fa5pFsD(>@8a~O@Bn*{@Bc6dHuLkQ>KKc_%;Kqt|n+3!%T~zyM4J z73`8zxJq@|P6dr9nIKJSJeF;Um+f=`Mj=SXgc^?}j|*k(tw)ed5Zt`-4WIUW1jz(h z2n69ZotZ8DY7+27fY@JkhV3vyWdfcJ5C%_Y*zzD+CSV#My9MlZFLUZtZO-SZg$!Ni|+_qVb8GF|YW5M@ z@ILW}2nYXO`-RVIYY*qlLA9aZ4veu@yE=-^YV_#VYKdb;E*>l^{XwesKChbdC|K|h c^?3UW^~~jGC@ None: @@ -17,4 +19,30 @@ def test_with_header_in_csv(self) -> None: with pytest.raises(TypeError): data = read_csv("tests/tiling/header.csv") ##for row in data: - ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) \ No newline at end of file + ##aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) + + def test_is_current_config(self) -> None: + ckpt = "tests/utils/gdl_current_test.pth.tar" + ckpt_dict = read_checkpoint(ckpt, update=False) + assert is_inference_compatible(ckpt_dict) + + def test_update_gdl_checkpoint(self) -> None: + ckpt = "tests/utils/gdl_pre20_test.pth.tar" + ckpt_dict = read_checkpoint(ckpt, update=False) + assert not is_inference_compatible(ckpt_dict) + ckpt_updated = update_gdl_checkpoint(ckpt_dict) + assert is_inference_compatible(ckpt_updated) + + # grouped to put emphasis on before/after result of updating + assert ckpt_dict['params']['global']['number_of_bands'] == 4 + assert ckpt_updated['params']['dataset']['bands'] == ['red', 'green', 'blue', 'nir'] + + assert ckpt_dict['params']['global']['num_classes'] == 1 + assert ckpt_updated['params']['dataset']['classes_dict'] == {'class1': 1} + + means = [0.0950882, 0.13039997, 0.12815733, 0.25175254] + assert ckpt_dict['params']['training']['normalization']['mean'] == means + assert ckpt_updated['params']['augmentation']['normalization']['mean'] == means + + assert ckpt_dict['params']['training']['augmentation']['clahe_enhance'] is True + assert ckpt_updated['params']['augmentation']['clahe_enhance_clip_limit'] == 0.1 diff --git a/utils/utils.py b/utils/utils.py index 9411b701..e2f66171 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -522,23 +522,35 @@ def print_config( rich.print(tree, file=fp) -def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict: - """ - Utility to update model checkpoints from older versions of GDL to current version - @param checkpoint_params: +def is_inference_compatible(cfg: Union[dict, DictConfig]): + """Checks whether a configuration dictionary contains a config structure compatible with current inference script""" + try: + # don't update if already a recent checkpoint + # checks if major keys for current config exist, especially those that have changed over time + cfg['params']['augmentation'] + cfg['params']['dataset']['classes_dict'] + cfg['params']['dataset']['bands'] + cfg['params']['model']['_target_'] + + # model state dicts + cfg['model_state_dict'] + return True + except KeyError as e: + logging.debug(e) + return False + + +def update_gdl_checkpoint(checkpoint: Union[dict, DictConfig]) -> Dict: + """ + Utility to update model checkpoints from older versions of GDL to current version. + NB: The purpose of this utility is ONLY to allow the use of "old" model in current inference script. + Mostly inference-relevant parameters are update. + @param checkpoint: Dictionary containing weights, optimizer state and saved configuration params from training @return: """ - # covers gdl checkpoints from version <= 2.0.1 - if 'model' in checkpoint_params.keys(): - checkpoint_params['model_state_dict'] = checkpoint_params['model'] - del checkpoint_params['model'] - if 'optimizer' in checkpoint_params.keys(): - checkpoint_params['optimizer_state_dict'] = checkpoint_params['optimizer'] - del checkpoint_params['optimizer'] - # covers gdl checkpoints pre-hydra (<=2.0.0) - bands = ['R', 'G', 'B', 'N'] + bands = {'red': 'R', 'green': 'G', 'blue': 'B', 'nir': 'N'} old2new = { 'manet_pretrained': { '_target_': 'segmentation_models_pytorch.MAnet', 'encoder_name': 'resnext50_32x4d', @@ -567,16 +579,19 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict: 'encoder_weights': 'imagenet' }, } - try: - # don't update if already a recent checkpoint - get_key_def('classes_dict', checkpoint_params['params']['dataset'], expected_type=(dict, DictConfig)) - get_key_def('modalities', checkpoint_params['params']['dataset'], expected_type=Sequence) - get_key_def('model', checkpoint_params['params'], expected_type=(dict, DictConfig)) - return checkpoint_params - except KeyError: - num_classes_ckpt = get_key_def('num_classes', checkpoint_params['params']['global'], expected_type=int) - num_bands_ckpt = get_key_def('number_of_bands', checkpoint_params['params']['global'], expected_type=int) - model_name = get_key_def('model_name', checkpoint_params['params']['global'], expected_type=str) + if not is_inference_compatible(checkpoint): + # covers gdl checkpoints from version <= 2.0.1 + if 'model' in checkpoint.keys(): + checkpoint['model_state_dict'] = checkpoint['model'] + del checkpoint['model'] + try: + num_classes_ckpt = get_key_def('num_classes', checkpoint['params']['global'], expected_type=int) + num_bands_ckpt = get_key_def('number_of_bands', checkpoint['params']['global'], expected_type=int) + model_name = get_key_def('model_name', checkpoint['params']['global'], expected_type=str) + except KeyError as e: + logging.critical(f"\nCouldn't update checkpoint parameters" + f"\nError {type(e)}: {e}") + raise e try: model_ckpt = old2new[model_name] except KeyError as e: @@ -585,17 +600,39 @@ def update_gdl_checkpoint(checkpoint_params: Dict) -> Dict: f"\nError {type(e)}: {e}") raise e # For GDL pre-v2.0.2 - #bands_ckpt = '' - #bands_ckpt = bands_ckpt.join([bands[i] for i in range(num_bands_ckpt)]) - checkpoint_params['params'].update({ + # Move transformation/augmentations hyperparameters + if not "augmentation" in checkpoint["params"].keys(): + checkpoint["params"]["augmentation"] = { + 'normalization': {'mean': [], 'std': []}, + 'clahe_enhance_clip_limit': None + } + try: + means_ckpt = checkpoint['params']['training']['normalization']['mean'] + stds_ckpt = checkpoint['params']['training']['normalization']['std'] + scale_ckpt = checkpoint['params']['global']['scale_data'] + # clahe_enhance was never officially added to GDL, so will default to None if not present + clahe_enhance = get_key_def('clahe_enhance', checkpoint['params']['training']['augmentation'], default=None) + except KeyError as e: # if KeyError on old keys, then we'll assume we have an up-to-date checkpoint + logging.debug(e) + return checkpoint + + checkpoint["params"]["augmentation"]["normalization"]["mean"] = means_ckpt + checkpoint["params"]["augmentation"]["normalization"]["std"] = stds_ckpt + checkpoint["params"]["augmentation"]["scale_data"] = scale_ckpt + checkpoint["params"]["augmentation"]["clahe_enhance_clip_limit"] = 0.1 if clahe_enhance is True else None + + checkpoint['params'].update({'model': model_ckpt}) + + checkpoint['params'].update({ 'dataset': { - 'modalities': [bands[i] for i in range(num_bands_ckpt)], #bands_ckpt, - #"classes_dict": {f"BUIL": 1} + 'bands': [list(bands.keys())[i] for i in range(num_bands_ckpt)], "classes_dict": {f"class{i + 1}": i + 1 for i in range(num_classes_ckpt)} + # Some manually update may be necessary when using old models + # 'bands': ['nir', 'red', 'green'], + # "classes_dict": {f"FORE": 1}, } }) - checkpoint_params['params'].update({'model': model_ckpt}) - return checkpoint_params + return checkpoint def map_wrapper(x):