From 682bdc7cb6063368061dbab7c52cc54bf9631121 Mon Sep 17 00:00:00 2001 From: remtav <34774759+remtav@users.noreply.github.com> Date: Mon, 6 Jun 2022 09:05:44 -0400 Subject: [PATCH] Common input pipeline for single- and multi-band imagery and AOI object for input data (#309) * update csv to expect 3 mandatory columns and one optional. See comments in issue #221 * use inference data for binary segmentation in tests/, not data/ * environment.yml: hardcode setuptools version because of pytorch bug * environment.yml: set correct subversion to setuptools * environment.yml: move setuptools from conda section to pip * sampling_segmentation.py: implement AOI class verifications.py: update assert_crs_match function, add validate functions for rasters and vector files * remove support for AWS bucket via boto3 * finish draft of sampling with AOI objects (with basic validation), rather than from raw csv lines * environment.yml: fix and update * environment.yml: add issue link for setuptools * environment.yml: add issue link for setuptools * environment.yml: fix and update * environment.yml: add issue link for setuptools * environment.yml: add issue link for setuptools * sampling_segmentation.py: implement AOI class verifications.py: update assert_crs_match function, add validate functions for rasters and vector files * finish draft of sampling with AOI objects (with basic validation), rather than from raw csv lines * train_segmentation.py: add warning for debugging and skip save checkpoint if val loss is None * tests/data/massachusetts: restore larger format to prevent val_loss=None * tests/data/massachusetts...: switch back to smaller image test_ci_segmentation_binary.yaml: tile images to 32, not 256 test_ci_segmentation_multiclass.yaml: idem train_segmentation.py: raise ValueError for empty train or val dataloader * aoi.py: - create an AOI object with input validation. AOI would be the core input for tiling, training and inference, though only yet implemented for tiling. - add stac item support geoutils.py: - add utils: is_stac_item, stack_vrts() for create artificial multi-band raster from single-bands files test_aoi.py: add first test for parsing raster input from 3 types to a single rasterio.RasterDataset object default.yaml: activate debug functionality for logging test_ci_segmentation_multiclass.yaml: replace 'modalities' with 'bands' key test_ci_segmentation_binary.yaml: idem sample_creation.py: delete utils.py: remove validation from read_csv() function. * inference_segmentation.py: remove read_modalities README.md: start updating * evaluate_segmentation.py: fix bug (remove read_modalities()) dataset/README.md: add documentation on input data configuration and csv format README.md: update * aoi.py: - add write multiband fonction for demo and debugging - move aois_from_csv from sampling_segmentation.py * aoi.py: remove circular import automatically created py Pycharm * fix typos and potential bugs introduced by Pycharm's automatic refactoring * aoi.py: use pre-existing raster validation function sampling_segmentation.py: move validation to aoi object utils.py: finish removing AWS bucket feature verifications.py: - update all data validation functions * inference_segmentation.py: remove bucket parameter in list_input_images * test_aoi.py: use local stac item (prevent timeout error at CI) --- GDL.py | 3 +- README.md | 50 +- .../dataset/test_ci_segmentation_binary.yaml | 2 +- .../test_ci_segmentation_multiclass.yaml | 2 +- config/gdl_config_template.yaml | 4 +- config/hydra/default.yaml | 1 + dataset/README.md | 59 ++ dataset/aoi.py | 333 ++++++++++ {data => dataset}/colormap.csv | 0 evaluate_segmentation.py | 7 +- inference_segmentation.py | 10 +- sample_creation.py | 577 ------------------ sampling_segmentation.py | 74 ++- ...2_Las_Vegas-056155973080_01_P001-WV03.json | 1 + tests/dataset/test_aoi.py | 17 + train_segmentation.py | 1 + utils/geoutils.py | 47 +- utils/logger.py | 1 - utils/utils.py | 92 +-- utils/verifications.py | 143 +++-- 20 files changed, 640 insertions(+), 784 deletions(-) create mode 100644 dataset/README.md create mode 100644 dataset/aoi.py rename {data => dataset}/colormap.csv (100%) delete mode 100644 sample_creation.py create mode 100644 tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json create mode 100644 tests/dataset/test_aoi.py diff --git a/GDL.py b/GDL.py index 2fe5cccf..35396b4e 100644 --- a/GDL.py +++ b/GDL.py @@ -41,7 +41,8 @@ def run_gdl(cfg: DictConfig) -> None: # check if the mode is chosen if type(cfg.mode) is DictConfig: msg = "You need to choose between those modes: {}" - raise logging.critical(msg.format(list(cfg.mode.keys()))) + logging.critical(msg.format(list(cfg.mode.keys()))) + raise ValueError() # save all overwritten parameters logging.info('\nOverwritten parameters in the config: \n' + cfg.general.config_override_dirname) diff --git a/README.md b/README.md index d7527c9d..0e6852cd 100644 --- a/README.md +++ b/README.md @@ -3,19 +3,16 @@ ## **Overview** -The **geo-deep-learning** project stems from an initiative at NRCan's [CCMEO](https://www.nrcan.gc.ca/earth-sciences/geomatics/10776). Its aim is to allow using Convolutional Neural Networks (CNN) with georeferenced data sets. -The overall learning process comprises three broad stages. +The **geo-deep-learning** project stems from an initiative at NRCan's [CCMEO](https://www.nrcan.gc.ca/earth-sciences/geomatics/10776). Its aim is to allow using Convolutional Neural Networks (CNN) with georeferenced datasets. -### Data preparation -The data preparation phase (sampling) allows creating sub-images that will be used for either training, validation or testing. -The first phase of the process is to determine sub-images (samples) to be used for training, validation and, optionally, test. -Images to be used must be of the geotiff type. -Sample locations in each image must be stored in a GeoPackage. +In geo-deep-learning, the learning process comprises two broad stages: sampling and training, followed by inference, which makes use of a trained model to make new predictions on unseen imagery. -[comment]: <> (> Note: A data analysis module can be found [here](./utils/data_analysis.py) and the documentation in [`docs/README.md`](./docs/README.md). Useful for balancing training data.) +### Data sampling (or [tiling](https://torchgeo.readthedocs.io/en/latest/user/glossary.html#term-tiling)) +The data preparation phase creates [chips](https://torchgeo.readthedocs.io/en/latest/user/glossary.html#term-chip) (or patches) that will be used for either training, validation or testing. +The sampling step requires a csv as input with a list of rasters and labels to be used in the subsequent training phase. See [dataset documentation](dataset#input-data). ### Training, along with validation and testing -The training phase is where the neural network learn to use the data prepared in the previous phase to make all the predictions. +The training phase is where the neural network learns to use the data prepared in the previous phase to make all the predictions. The crux of the learning process is the training phase. - Samples labeled "*trn*" as per above are used to train the neural network. @@ -38,18 +35,14 @@ This project comprises a set of commands to be run at a shell command prompt. E > The system can be used on your workstation or cluster. ## **Installation** -Those steps are for your workstation on Ubuntu 18.04 using miniconda. -Set and activate your python environment with the following commands: +To execute scripts in this project, first create and activate your python environment with the following commands: ```shell conda env create -f environment.yml conda activate geo_deep_env ``` -> For Windows OS: -> - Install rasterio, fiona and gdal first, before installing the rest. We've experienced some [installation issues](https://github.com/conda-forge/gdal-feedstock/issues/213), with those libraries. -> - Mlflow should be installed using pip rather than conda, as mentioned [here](https://github.com/mlflow/mlflow/issues/1951) - +> Tested on Ubuntu 20.04 and Windows 10 using miniconda. ## **Running GDL** -This is an example of how to run GDL with hydra in simple steps with the _**massachusetts buildings**_ dataset in the `/data` folder, for segmentation on buildings: +This is an example of how to run GDL with hydra in simple steps with the _**massachusetts buildings**_ dataset in the `tests/data/` folder, for segmentation on buildings: 1. Clone this github repo. ```shell @@ -67,15 +60,14 @@ python GDL.py mode=train python GDL.py mode=inference ``` -> This example is running with the default configuration `./config/gdl_config_template.yaml`, for further examples on running options see the [documentation](config/#Examples). -> You will also fund information on how to change the model or add a new one to GDL. +> This example runs with a default configuration `./config/gdl_config_template.yaml`. For further examples on configuration options see the [configuration documentation](config/#Examples). > If you want to introduce a new task like object detection, you only need to add the code in the main folder and name it `object_detection_sampling.py` for example. -> The principle is to name the code like `task_mode.py` and the `GDL.py` will deal with the rest. +> The principle is to name the code like `{task}_{mode}.py` and the `GDL.py` will deal with the rest. > To run it, you will need to add a new parameter in the command line `python GDL.py mode=sampling task=object_detection` or change the parameter inside the `./config/gdl_config_template.yaml`. ## **Folder Structure** -We suggest a high level structure to organize the images and the code. +We suggest the following high level structure to organize the images and the code. ``` ├── {dataset_name} └── data @@ -128,24 +120,6 @@ _**Don't forget to change the path of the dataset in the config yaml.**_ [comment]: <> ( num_gpus: 2) -[comment]: <> ( BGR_to_RGB: False # <-- must be already in RGB) - -[comment]: <> ( scale_data: [0,1]) - -[comment]: <> ( aux_vector_file:) - -[comment]: <> ( aux_vector_attrib:) - -[comment]: <> ( aux_vector_ids:) - -[comment]: <> ( aux_vector_dist_maps:) - -[comment]: <> ( aux_vector_dist_log:) - -[comment]: <> ( aux_vector_scale:) - -[comment]: <> ( debug_mode: True) - [comment]: <> ( # Module to include the NIR) [comment]: <> ( modalities: RGBN # <-- must be add) diff --git a/config/dataset/test_ci_segmentation_binary.yaml b/config/dataset/test_ci_segmentation_binary.yaml index 22507424..d53ea376 100644 --- a/config/dataset/test_ci_segmentation_binary.yaml +++ b/config/dataset/test_ci_segmentation_binary.yaml @@ -10,7 +10,7 @@ dataset: raw_data_dir: ${general.raw_data_dir} # imagery - modalities: RGB + bands: [R, G, B] # ground truth attribute_field: properties/class diff --git a/config/dataset/test_ci_segmentation_multiclass.yaml b/config/dataset/test_ci_segmentation_multiclass.yaml index 86c750e3..726334bc 100644 --- a/config/dataset/test_ci_segmentation_multiclass.yaml +++ b/config/dataset/test_ci_segmentation_multiclass.yaml @@ -10,7 +10,7 @@ dataset: raw_data_dir: ${general.raw_data_dir} # imagery - modalities: RGB + bands: [R, G, B] # ground truth attribute_field: properties/Quatreclasses diff --git a/config/gdl_config_template.yaml b/config/gdl_config_template.yaml index 0f1c43fe..db180cd6 100644 --- a/config/gdl_config_template.yaml +++ b/config/gdl_config_template.yaml @@ -29,9 +29,9 @@ general: workspace: your_name max_epochs: 2 # for train only min_epochs: 1 # for train only - raw_data_dir: data + raw_data_dir: dataset raw_data_csv: tests/sampling/sampling_segmentation_binary_ci.csv - sample_data_dir: data # where the hdf5 will be saved + sample_data_dir: dataset # where the hdf5 will be saved save_weights_dir: saved_model/${general.project_name} print_config: True # save the config in the log folder diff --git a/config/hydra/default.yaml b/config/hydra/default.yaml index ab86d6ad..dc2a364a 100644 --- a/config/hydra/default.yaml +++ b/config/hydra/default.yaml @@ -4,6 +4,7 @@ run: sweep: dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} subdir: ${hydra.job.num} +verbose: ${debug} # you can set here environment variables that are universal for all users # for system specific variables (like data paths) it's better to use .env file! diff --git a/dataset/README.md b/dataset/README.md new file mode 100644 index 00000000..a1b4d2f5 --- /dev/null +++ b/dataset/README.md @@ -0,0 +1,59 @@ +# Input data +The sampling and inference steps requires a csv referencing input data. An example of input csv can be found in [tests](tests/sampling/sampling_segmentation_binary_ci.csv). +Each row of this csv is considered, in geo-deep-learning terms, to be an [AOI](https://torchgeo.readthedocs.io/en/latest/user/glossary.html#term-area-of-interest-AOI). + +| raster path | vector ground truth path | dataset split | aoi id (optional) | +|---------------------------|--------------------------|---------------|-------------------| +| my_dir/my_geoimagery1.tif | my_dir/my_geogt1.gpkg | trn | Ontario-1 | +| my_dir/my_geoimagery2.tif | my_dir/my_geogt2.gpkg | tst | NewBrunswick-23 | +| ... | ... | ... | ... | + +> The use of aoi id information will be implemented in a near future. It will serve, for example, to print a detailed report of sampling, training and evaluation, or for easier debugging. + +The path to a custom csv must be entered in the [dataset configuration](https://github.com/NRCan/geo-deep-learning/blob/develop/config/dataset/test_ci_segmentation_binary.yaml#L9). See the [configuration documentation](config/README.md) for more information. +Also check the [suggested folder structure](https://github.com/NRCan/geo-deep-learning#folder-structure). + +## Dataset splits +Split in csv should be either "trn", "tst" or "inference". The validation split is automatically created during sampling. It's proportion is set by the [dataset config](https://github.com/NRCan/geo-deep-learning/blob/develop/config/dataset/test_ci_segmentation_binary.yaml#L8). + +## Raster and vector file compatibility +Rasters to be used must be in a format compatible with [rasterio](https://rasterio.readthedocs.io/en/latest/quickstart.html?highlight=supported%20raster%20format#opening-a-dataset-in-reading-mode)/[GDAL](https://gdal.org/drivers/raster/index.html) (ex.: GeoTiff). Similarly, labels (aka annotations) for each image must be stored as polygons in a [Geopandas compatible vector file](Rasters to be used must be in a format compatible with [rasterio](https://rasterio.readthedocs.io/en/latest/quickstart.html?highlight=supported%20raster%20format#opening-a-dataset-in-reading-mode)/[GDAL](https://gdal.org/drivers/raster/index.html) (ex.: GeoTiff). Similarly, labels (aka annotations) for each image must be stored as polygons in a [Geopandas compatible vector file](https://geopandas.org/en/stable/docs/user_guide/io.html#reading-spatial-data) (ex.: GeoPackage). +) (ex.: GeoPackage). + +## Single-band vs multi-band imagery + +To support both single-band and multi-band imagery, the path in the first column of an input csv can be in **one of three formats**: + +### 1. Path to a multi-band image file: +`my_dir/my_multiband_geofile.tif` + +### 2. Path to single-band image files, using only a common string +A path to a list of single-band rasters can be inserted in the csv, but only a the string common to all single-band files should be considered. +The "band specific" string in the file name must be in a [hydra-like interpolation format](https://hydra.cc/docs/1.0/advanced/override_grammar/basic/#primitives), with `${...}` notation. The interpolation string completed during execution by a dataset parameter with a list of desired band identifiers to help resolve the single-band filenames. + +#### Example: + +In [dataset config](../config/dataset/test_ci_segmentation_binary.yaml): + +`bands: [R, G, B]` + +In [input csv](../tests/sampling/sampling_segmentation_binary_ci.csv): + +| raster path | ground truth path | dataset split | +|------------------------------------------------------------|-------------------|---------------| +| my_dir/my_singleband_geofile_band_**${dataset.bands}**.tif | gt.gpkg | trn | + +During execution, this would result in using, **in the same order as bands appear in dataset config**, the following files: +`my_dir/my_singleband_geofile_band_R.tif` +`my_dir/my_singleband_geofile_band_G.tif` +`my_dir/my_singleband_geofile_band_B.tif` + +> To simplify the use of both single-band and multi-band rasters through a unique input pipeline, single-band files are artificially merged as a [virtual raster](https://gdal.org/drivers/raster/vrt.html). + +### 3. Path to a Stac Item +> Only Stac Items referencing **single-band assets** are supported currently. See [our Worldview-2 example](https://datacube-stage.services.geo.ca/api/collections/spacenet-samples/items/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03). + +Bands must be selected by [common name](https://github.com/stac-extensions/eo/#common-band-names) in dataset config: +`bands: ["red", "green", "blue"]` + +> Order matters: `["red", "green", "blue"]` is not equal to `["blue", "green", "red"]` ! \ No newline at end of file diff --git a/dataset/aoi.py b/dataset/aoi.py new file mode 100644 index 00000000..ba9ead79 --- /dev/null +++ b/dataset/aoi.py @@ -0,0 +1,333 @@ +from collections import OrderedDict +from pathlib import Path +from typing import Union, Sequence, Dict, Tuple, List + +import geopandas as gpd +import pyproj +import pystac +import rasterio +from pystac.extensions.eo import ItemEOExtension, Band +from omegaconf import listconfig, ListConfig +from shapely.geometry import box +from solaris.utils.core import _check_rasterio_im_load, _check_gdf_load +from tqdm import tqdm + +from utils.geoutils import stack_vrts, is_stac_item +from utils.logger import get_logger +from utils.utils import read_csv +from utils.verifications import validate_by_geopandas, assert_crs_match, validate_raster, \ + validate_num_bands, validate_features_from_gpkg + +logging = get_logger(__name__) # import logging + + +class SingleBandItemEO(ItemEOExtension): + """ + Single-Band Stac Item with assets by common name. + For info on common names, see https://github.com/stac-extensions/eo#common-band-names + """ + def __init__(self, item: pystac.Item, bands: Sequence = None): + super().__init__(item) + if not is_stac_item(item): + raise TypeError(f"Expected a valid pystac.Item object. Got {type(item)}") + self.item = item + self._assets_by_common_name = None + + if bands is not None and len(bands) == 0: + logging.warning(f"At least one band should be chosen if assets need to be reached") + + # Create band inventory (all available bands) + self.bands_all = [band for band in self.asset_by_common_name.keys()] + + # Make sure desired bands are subset of inventory + if not set(bands).issubset(set(self.bands_all)): + raise ValueError(f"Requested bands ({bands}) should be a subset of available bands ({self.bands_all})") + + # Filter only requested bands + self.bands_requested = {band: self.asset_by_common_name[band] for band in bands} + logging.debug(self.bands_all) + logging.debug(self.bands_requested) + + @property + def asset_by_common_name(self) -> Dict: + """ + Get assets by common band name (only works for assets containing 1 band) + Adapted from: + https://github.com/sat-utils/sat-stac/blob/40e60f225ac3ed9d89b45fe564c8c5f33fdee7e8/satstac/item.py#L75 + @return: + """ + if self._assets_by_common_name is None: + self._assets_by_common_name = OrderedDict() + for name, a_meta in self.item.assets.items(): + bands = [] + if 'eo:bands' in a_meta.extra_fields.keys(): + bands = a_meta.extra_fields['eo:bands'] + if len(bands) == 1: + eo_band = bands[0] + if 'common_name' in eo_band.keys(): + common_name = eo_band['common_name'] + if not Band.band_range(common_name): # Hacky but easiest way to validate common names + raise ValueError(f'Must be one of the accepted common names. Got "{common_name}".') + else: + self._assets_by_common_name[common_name] = {'href': a_meta.href, 'name': name} + if not self._assets_by_common_name: + raise ValueError(f"Common names for assets cannot be retrieved") + return self._assets_by_common_name + + +class AOI(object): + """ + Object containing all data information about a single area of interest + based on https://github.com/stac-extensions/ml-aoi + """ + + def __init__(self, raster: Union[Path, str], + raster_bands_request: List = None, + label: Union[Path, str] = None, + split: str = None, + aoi_id: str = None, + collection: str = None, + raster_num_bands_expected: int = None, + attr_field_filter: str = None, + attr_values_filter: Sequence = None, + write_multiband: bool = False): + # TODO: dict printer to output report on list of aois + """ + @param raster: pathlib.Path or str + Path to source imagery + @param label: pathlib.Path or str + Path to ground truth file. If not provided, AOI is considered only for inference purposes + @param split: str + Name of destination dataset for aoi. Should be 'trn', 'tst' or 'inference' + @param aoi_id: str + Name or id (loosely defined) of area of interest. Used to name output folders. + Multiple AOI instances can bear the same name. + @param collection: str + Name of collection containing AOI. All AOIs in the same collection should never be spatially overlapping + @param raster_num_bands_expected: + Number of bands expected in processed raster (e.g. after combining single-bands files into a VRT) + @param attr_field_filter: str, optional + Name of attribute field used to filter features. If not provided all geometries in ground truth file + will be considered. + @param attr_values_filter: list of ints, optional + The list of attribute values in given attribute field used to filter features from ground truth file. + If not provided, all values in field will be considered + @param write_multiband: bool, optional + If True, a multi-band raster side by side with single-bands rasters as provided in input csv. For debugging purposes. + """ + # Check and parse raster data + if not isinstance(raster, str): + raise TypeError(f"Raster path should be a string.\nGot {raster} of type {type(raster)}") + self.raster_raw_input = raster + if raster_bands_request and not isinstance(raster_bands_request, (List, ListConfig)): + raise ValueError(f"Requested bands should be a list." + f"\nGot {raster_bands_request} of type {type(raster_bands_request)}") + self.raster_bands_request = raster_bands_request + raster_parsed = self.parse_input_raster(csv_raster_str=self.raster_raw_input, + raster_bands_requested=self.raster_bands_request) + # If parsed result is a tuple, then we're dealing with single-band files + if isinstance(raster_parsed, Tuple): + [validate_raster(file) for file in raster_parsed] + self.raster_tuple = raster_parsed + raster_parsed = stack_vrts(raster_parsed) + else: + validate_raster(self.raster_raw_input) + self.raster_tuple = None + + if raster_num_bands_expected: + validate_num_bands(raster_path=raster_parsed, num_bands=raster_num_bands_expected) + + self.raster = _check_rasterio_im_load(str(raster_parsed)) + + if self.raster_tuple and write_multiband: + self.write_multiband_from_singleband_rasters_as_vrt() + + # Check label data + if label: + validate_by_geopandas(label) + self.label_gdf = _check_gdf_load(str(label)) + label_bounds = self.label_gdf.total_bounds + label_bounds_box = box(*label_bounds.tolist()) + raster_bounds_box = box(*list(self.raster.bounds)) + if not label_bounds_box.intersects(raster_bounds_box): + raise ValueError(f"Features in label file {label} do not intersect with bounds of raster file " + f"{self.raster.name}") + validate_features_from_gpkg(label, attr_field_filter) + + self.label = Path(label) + # TODO: unit test for failed CRS match + try: + # TODO: check if this creates overhead. Make data validation optional? + self.crs_match, self.epsg_raster, self.epsg_label = assert_crs_match(self.raster, self.label_gdf) + except pyproj.exceptions.CRSError as e: + logging.warning(f"\nError while checking CRS match between raster and label." + f"\n{e}") + else: + self.label = self.crs_match = self.epsg_raster = self.epsg_label = None + + # Check split string + if split and not isinstance(split, str): + raise ValueError(f"\nDataset split should be a string.\nGot {split}.") + + if label and split not in ['trn', 'tst', 'inference']: + raise ValueError(f"\nWith ground truth, split should be 'trn', 'tst' or 'inference'. \nGot {split}") + # force inference split if no label provided + elif not label and (split != 'inference' or not split): + logging.warning(f"\nNo ground truth provided. Dataset split will be set to 'inference'" + f"\nOriginal split: {split}") + split = 'inference' + self.split = split + + # Check aoi_id string + if aoi_id and not isinstance(aoi_id, str): + raise TypeError(f'AOI name should be a string. Got {aoi_id} of type {type(aoi_id)}') + elif not aoi_id: + aoi_id = self.raster.stem # Defaults to name of image without suffix + self.aoi_id = aoi_id + + # Check collection string + if collection and not isinstance(collection, str): + raise TypeError(f'Collection name should be a string. Got {collection} of type {type(collection)}') + self.aoi_id = aoi_id + + # If ground truth is provided, check attribute field + if label and attr_field_filter and not isinstance(attr_field_filter, str): + raise TypeError(f'Attribute field name should be a string.\n' + f'Got {attr_field_filter} of type {type(attr_field_filter)}') + self.attr_field_filter = attr_field_filter + + # If ground truth is provided, check attribute values to filter from + if label and attr_values_filter and not isinstance(attr_values_filter, (list, listconfig.ListConfig)): + raise TypeError(f'Attribute values should be a list.\n' + f'Got {attr_values_filter} of type {type(attr_values_filter)}') + self.attr_values_filter = attr_values_filter + logging.debug(self) + + @classmethod + def from_dict(cls, + aoi_dict, + bands_requested: List = None, + attr_field_filter: str = None, + attr_values_filter: list = None): + """Instanciates an AOI object from an input-data dictionary as expected by geo-deep-learning""" + if not isinstance(aoi_dict, dict): + raise TypeError('Input data should be a dictionary.') + # TODO: change dataset for split + if not {'tif', 'gpkg', 'split'}.issubset(set(aoi_dict.keys())): + raise ValueError(f"Input data should minimally contain the following keys: \n" + f"'tif', 'gpkg', 'split'.") + if not aoi_dict['gpkg']: + logging.warning(f"No ground truth data found for {aoi_dict['tif']}.\n" + f"Only imagery will be processed from now on") + if "aoi_id" not in aoi_dict.keys() or not aoi_dict['aoi_id']: + aoi_dict['aoi_id'] = Path(aoi_dict['tif']).stem + aoi_dict['attribute_name'] = attr_field_filter + new_aoi = cls( + raster=aoi_dict['tif'], + raster_bands_request=bands_requested, + label=aoi_dict['gpkg'], + split=aoi_dict['split'], + attr_field_filter=attr_field_filter, + attr_values_filter=attr_values_filter, + aoi_id=aoi_dict['aoi_id'] + ) + return new_aoi + + def __str__(self): + return ( + f"\nAOI ID: {self.aoi_id}" + f"\n\tRaster: {self.raster.name}" + f"\n\tLabel: {self.label}" + f"\n\tCRS match: {self.crs_match}" + f"\n\tSplit: {self.split}" + f"\n\tAttribute field filter: {self.attr_field_filter}" + f"\n\tAttribute values filter: {self.attr_values_filter}" + ) + + # TODO def to_dict() + # return a dictionary containing all important attributes of AOI (ex.: to print a report or output csv) + + # TODO def raster_stats() + # return a dictionary with mean and std of raster, per band + + def write_multiband_from_singleband_rasters_as_vrt(self): + """Writes a multiband raster to file from a pre-built VRT. For debugging and demoing""" + if not self.raster.driver == 'VRT' or not self.raster_tuple or not "${dataset.bands}" in self.raster_raw_input: + logging.warning(f"To write a multi-band raster from single-band files, a VRT must be provided." + f"\nGot {self.raster.meta}") + return + + out_tif_path = self.raster_raw_input.replace("${dataset.bands}", ''.join(self.raster_bands_request)) + out_meta = self.raster.meta.copy() + out_meta.update({"driver": "GTiff", + "count": self.raster.count}) + with rasterio.open(out_tif_path, "w", **out_meta) as dest: + logging.debug(f"Writing multi-band raster to {out_tif_path}") + out_img = self.raster.read() + dest.write(out_img) + + @staticmethod + def parse_input_raster(csv_raster_str: str, raster_bands_requested: List) -> Union[str, Tuple]: + """ + From input csv, determine if imagery is + 1. A Stac Item with single-band assets (multi-band assets not implemented) + 2. Single-band imagery as path or url with hydra-like interpolation for band identification + 3. Multi-band path or url + @param csv_raster_str: + input imagery to parse + @param raster_bands_requested: + dataset configuration parameters + @return: + """ + if is_stac_item(csv_raster_str): + item = SingleBandItemEO(item=pystac.Item.from_file(csv_raster_str), bands=raster_bands_requested) + raster = [Path(value['href']) for value in item.bands_requested.values()] + return tuple(raster) + elif "${dataset.bands}" in csv_raster_str: + if not isinstance(raster_bands_requested, (List, ListConfig)) or len(raster_bands_requested) == 0: + raise ValueError(f"\nRequested bands should a list of bands. " + f"\nGot {raster_bands_requested} of type {type(raster_bands_requested)}") + raster = [Path(csv_raster_str.replace("${dataset.bands}", band)) for band in raster_bands_requested] + return tuple(raster) + else: + try: + validate_raster(csv_raster_str) + return Path(csv_raster_str) + except (FileNotFoundError, rasterio.RasterioIOError, TypeError) as e: + logging.critical(f"Couldn't parse input raster. Got {csv_raster_str}") + raise e + + +def aois_from_csv(csv_path: Union[str, Path], bands_requested: List = None, attr_field_filter: str = None, attr_values_filter: str = None): + """ + Creates list of AOIs by parsing a csv file referencing input data + @param csv_path: + path to csv file containing list of input data. See README for details on expected structure of csv. + @param bands_requested: + List of bands to select from inputted imagery. Applies only to single-band input imagery. + @param attr_values_filter: + Attribute filed to filter features from + @param attr_field_filter: + Attribute values (for given attribute field) for features to keep + Returns: a list of AOIs objects + """ + aois = [] + data_list = read_csv(csv_path) + logging.info(f'\n\tSuccessfully read csv file: {Path(csv_path).name}\n' + f'\tNumber of rows: {len(data_list)}\n' + f'\tCopying first row:\n{data_list[0]}\n') + for i, aoi_dict in tqdm(enumerate(data_list), desc="Creating AOI's"): + try: + new_aoi = AOI.from_dict( + aoi_dict=aoi_dict, + bands_requested=bands_requested, + attr_field_filter=attr_field_filter, + attr_values_filter=attr_values_filter + ) + logging.debug(new_aoi) + aois.append(new_aoi) + except FileNotFoundError as e: + logging.critical(f"{e}\nGround truth file may not exist or is empty.\n" + f"Failed to create AOI:\n{aoi_dict}\n" + f"Index: {i}") + return aois diff --git a/data/colormap.csv b/dataset/colormap.csv similarity index 100% rename from data/colormap.csv rename to dataset/colormap.csv diff --git a/evaluate_segmentation.py b/evaluate_segmentation.py index e75515b2..2219b93a 100644 --- a/evaluate_segmentation.py +++ b/evaluate_segmentation.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd import rasterio -from hydra.utils import get_original_cwd from mlflow import log_metrics from shapely.geometry import Polygon from tqdm import tqdm @@ -15,7 +14,7 @@ from utils.geoutils import clip_raster_with_gpkg, vector_to_raster from utils.metrics import ComputePixelMetrics -from utils.utils import get_key_def, list_input_images, read_modalities +from utils.utils import get_key_def, list_input_images from utils.logger import get_logger from utils.verifications import validate_num_classes, assert_crs_match @@ -79,8 +78,8 @@ def main(params): """ start_seg = time.time() state_dict = get_key_def('state_dict_path', params['inference'], to_path=True, validate_path_exists=True) - modalities = read_modalities(get_key_def('modalities', params['dataset'], expected_type=str)) - num_bands = len(modalities) + bands_requested = get_key_def('bands', params['dataset'], default=None, expected_type=Sequence) + num_bands = len(bands_requested) working_folder = state_dict.parent.joinpath(f'inference_{num_bands}bands') img_dir_or_csv = get_key_def('img_dir_or_csv_file', params['inference'], expected_type=str, to_path=True, validate_path_exists=True) diff --git a/inference_segmentation.py b/inference_segmentation.py index aad521f3..609b7a1c 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -24,7 +24,7 @@ from models.model_choice import define_model, read_checkpoint from utils import augmentation from utils.utils import get_device_ids, get_key_def, \ - list_input_images, add_metadata_from_raster_to_sample, _window_2D, read_modalities, set_device + list_input_images, add_metadata_from_raster_to_sample, _window_2D, set_device from utils.verifications import validate_input_imagery # Set the logging file @@ -327,11 +327,11 @@ def main(params: Union[DictConfig, dict]) -> None: ) # Dataset params - modalities = get_key_def('modalities', params['dataset'], default=("red", "blue", "green"), expected_type=Sequence) + bands_requested = get_key_def('bands', params['dataset'], default=("red", "blue", "green"), expected_type=Sequence) classes_dict = get_key_def('classes_dict', params['dataset'], expected_type=DictConfig) num_classes = len(classes_dict) num_classes = num_classes + 1 if num_classes > 1 else num_classes # multiclass account for background - num_bands = len(modalities) + num_bands = len(bands_requested) working_folder = state_dict.parent.joinpath(f'inference_{num_bands}bands') logging.info("\nThe state dict path directory used '{}'".format(working_folder)) @@ -385,12 +385,12 @@ def main(params: Union[DictConfig, dict]) -> None: ) # GET LIST OF INPUT IMAGES FOR INFERENCE - list_img = list_input_images(img_dir_or_csv, None, glob_patterns=["*.tif", "*.TIF"]) + list_img = list_input_images(img_dir_or_csv, glob_patterns=["*.tif", "*.TIF"]) # VALIDATION: anticipate problems with imagery before entering main for loop for info in tqdm(list_img, desc='Validating imagery'): is_valid = validate_input_imagery(info['tif'], num_bands=num_bands, extended=debug) - # TODO: exclude invalid imagery at inference (prevent execution break) + # TODO: address with issue #310 logging.info('\nSuccessfully validated imagery') # LOOP THROUGH LIST OF INPUT IMAGES diff --git a/sample_creation.py b/sample_creation.py deleted file mode 100644 index 7563f3a0..00000000 --- a/sample_creation.py +++ /dev/null @@ -1,577 +0,0 @@ -import argparse -from datetime import datetime -import os -import numpy as np - -np.random.seed(1234) # Set random seed for reproducibility -import warnings -import rasterio -import shutil -import time -import json - -from pathlib import Path -from tqdm import tqdm -from collections import Counter -from typing import List -from ruamel_yaml import YAML - -from utils.create_dataset import create_files_and_datasets -from utils.utils import get_key_def, pad, pad_diff, add_metadata_from_raster_to_sample -from utils.geoutils import vector_to_raster, clip_raster_with_gpkg -from utils.verifications import assert_crs_match, validate_num_classes -from rasterio.windows import Window -from rasterio.plot import reshape_as_image - - -def read_parameters(param_file): - """Read and return parameters in .yaml file - Args: - param_file: Full file path of the parameters file - Returns: - YAML (Ruamel) CommentedMap dict-like object - """ - yaml = YAML() - with open(param_file) as yamlfile: - params = yaml.load(yamlfile) - return params - - -def validate_class_prop_dict(actual_classes_dict, config_dict): - """ - Populate dictionary containing class values found in vector data with values (thresholds) from sample/class_prop - parameter in config file - - actual_classes_dict: dict - Dictionary where each key is a class found in vector data. Value is not relevant (should be 0) - - config_dict: - Dictionary with class ids (keys and thresholds (values) from config file - - """ - # Validation of class proportion parameters (assert types). - if not isinstance(config_dict, dict): - warnings.warn(f"Class_proportion parameter should be a dictionary. Got type {type(config_dict)}. " - f"Ignore if parameter was omitted)") - return None - - for key, value in config_dict.items(): - try: - assert isinstance(key, str) - int(key) - except (ValueError, AssertionError): - f"Class should be a string castable as an integer. Got {key} of type {type(key)}" - assert isinstance(value, int), f"Class value should be an integer, got {value} of type {type(value)}" - - # Populate actual classes dictionary with values from config - for key, value in config_dict.items(): - if int(key) in actual_classes_dict.keys(): - actual_classes_dict[int(key)] = value - else: - warnings.warn(f"Class {key} not found in provided vector data.") - - return actual_classes_dict.copy() - - -def getFeatures(gdf): - """Function to parse features from GeoDataFrame in such a manner that rasterio wants them""" - import json - return [json.loads(gdf.to_json())['features'][0]['geometry']] - -def process_raster_img(rst_pth, gpkg_pth): - with rasterio.open(rst_pth) as src: - rst_pth = clip_raster_with_gpkg(src, gpkg_pth) - # TODO: Return clipped raster handle - return rst_pth, src - - -def reorder_bands(a: List[str], b: List[str]): - read_band_order = [] - for band in a: - if band in b: - read_band_order.insert(a.index(band) + 1, b.index(band) + 1) - # print(f'{a.index(band)},{band}, {b.index(band)}') - return read_band_order - - -def gen_img_samples(rst_pth, tile_size, dist_samples, *band_order): - with rasterio.open(rst_pth) as src: - for row in range(0, src.height, dist_samples): - for column in range(0, src.width, dist_samples): - window = Window.from_slices(slice(row, row + tile_size), - slice(column, column + tile_size)) - if band_order: - window_array = reshape_as_image(src.read(band_order[0], window=window)) - else: - window_array = reshape_as_image(src.read(window=window)) - - if window_array.shape[0] < tile_size or window_array.shape[1] < tile_size: - padding = pad_diff(window_array.shape[0], window_array.shape[1], tile_size, tile_size) - window_array = pad(window_array, padding, fill=np.nan) - - yield window_array - - -def process_vector_label(rst_pth, gpkg_pth, ids): - if rst_pth is not None: - with rasterio.open(rst_pth) as src: - np_label = vector_to_raster(vector_file=gpkg_pth, - input_image=src, - out_shape=(src.height, src.width), - attribute_name='properties/Quatreclasses', - fill=0, - attribute_values=ids, - merge_all=True, - ) - return np_label - - -def gen_label_samples(np_label, dist_samples, tile_size): - h, w = np_label.shape - for row in range(0, h, dist_samples): - for column in range(0, w, dist_samples): - target = np_label[row:row + tile_size, column:column + tile_size] - target_row = target.shape[0] - target_col = target.shape[1] - if target_row < tile_size or target_col < tile_size: - padding = pad_diff(target_row, target_col, tile_size, - tile_size) # array, actual height, actual width, desired size - target = pad(target, padding, fill=-1) - indices = (row, column) - yield target, indices - - -def minimum_annotated_percent(target_background_percent, min_annotated_percent): - if not min_annotated_percent: - return True - elif float(target_background_percent) <= 100 - min_annotated_percent: - return True - - return False - - -def append_to_dataset(dataset, sample): - """ - Append a new sample to a provided dataset. The dataset has to be expanded before we can add value to it. - :param dataset: - :param sample: data to append - :return: Index of the newly added sample. - """ - old_size = dataset.shape[0] # this function always appends samples on the first axis - dataset.resize(old_size + 1, axis=0) - dataset[old_size, ...] = sample - return old_size - - -def class_proportion(target, sample_size: int, class_min_prop: dict): - if not class_min_prop: - return True - sample_total = sample_size ** 2 - for key, value in class_min_prop.items(): - if key not in np.unique(target): - target_prop_classwise = 0 - else: - target_prop_classwise = (round((np.bincount(target.clip(min=0).flatten())[key] / sample_total) * 100, 1)) - if target_prop_classwise < value: - return False - return True - - -def add_to_datasets(dataset, - samples_file, - val_percent, - val_sample_file, - data, - target, - sample_metadata, - metadata_idx, - dict_classes): - """ Add sample to Hdf5 (trn, val or tst) and computes pixel classes(%). """ - val = False - if dataset == 'trn': - random_val = np.random.randint(1, 100) - if random_val > val_percent: - pass - else: - val = True - samples_file = val_sample_file - append_to_dataset(samples_file["sat_img"], data) - append_to_dataset(samples_file["map_img"], target) - append_to_dataset(samples_file["sample_metadata"], repr(sample_metadata)) - append_to_dataset(samples_file["meta_idx"], metadata_idx) - - # adds pixel count to pixel_classes dict for each class in the image - for key, value in enumerate(np.bincount(target.clip(min=0).flatten())): - cls_keys = dict_classes.keys() - if key in cls_keys: - dict_classes[key] += value - elif key not in cls_keys and value > 0: - raise ValueError(f"A class value was written ({key}) that was not defined in the classes ({cls_keys}).") - - return val - - -def sample_prep(src, data, target, indices, gpkg_classes, sample_size, sample_type, samples_count, samples_file, - num_classes, - val_percent, - val_sample_file, - min_annot_perc=None, - class_prop=None, - dontcare=-1 - ): - added_samples = 0 - excl_samples = 0 - pixel_classes = {key: 0 for key in gpkg_classes} - background_val = 0 - pixel_classes[background_val] = 0 - class_prop = validate_class_prop_dict(pixel_classes, class_prop) - pixel_classes[dontcare] = 0 - - image_metadata = add_metadata_from_raster_to_sample(sat_img_arr=data, - raster_handle=src, - meta_map={}, - raster_info={}) - # Save label's per class pixel count to image metadata - image_metadata['source_label_bincount'] = {class_num: count for class_num, count in - enumerate(np.bincount(target.clip(min=0).flatten())) - if count > 0} # TODO: add this to add_metadata_from[...] function? - - if sample_type == 'trn': - idx_samples = samples_count['trn'] - append_to_dataset(val_sample_file["metadata"], repr(image_metadata)) - elif sample_type == 'tst': - idx_samples = samples_count['tst'] - else: - raise ValueError(f"Sample type must be trn or tst. Provided type is {sample_type}") - - idx_samples_v = samples_count['val'] - # Adds raster metadata to the dataset. All samples created by tiling below will point to that metadata by index - metadata_idx = append_to_dataset(samples_file["metadata"], repr(image_metadata)) - u, count = np.unique(target, return_counts=True) - # print('class:', u, 'count:', count) - target_background_percent = round(count[0] / np.sum(count) * 100 if 0 in u else 0, 1) - sample_metadata = {'sample_indices': indices} - val = False - if minimum_annotated_percent(target_background_percent, min_annot_perc) and \ - class_proportion(target, sample_size, class_prop): - val = add_to_datasets(dataset=sample_type, - samples_file=samples_file, - val_percent=val_percent, - val_sample_file=val_sample_file, - data=data, - target=target, - sample_metadata=sample_metadata, - metadata_idx=metadata_idx, - dict_classes=pixel_classes) - if val: - idx_samples_v += 1 - else: - idx_samples += 1 - added_samples += 1 - else: - excl_samples += 1 - - target_class_num = np.max(u) - if num_classes < target_class_num: - num_classes = target_class_num - - sample_type_ = 'val' if val else sample_type - # assert added_samples > 0, "No sample added for current raster. Problems may occur with use of metadata" - - if sample_type == 'tst': - samples_count['tst'] = idx_samples - else: - samples_count['trn'] = idx_samples - samples_count['val'] = idx_samples_v - - return samples_count, num_classes, pixel_classes - - -def class_pixel_ratio(pixel_classes: dict, source_data: str, file_path: str): - with open(file_path, 'a+') as f: - pixel_total = sum(pixel_classes.values()) - print(f'\n****{source_data}****\n', file=f) - for i in pixel_classes: - prop = round((pixel_classes[i] / pixel_total) * 100, 1) if pixel_total > 0 else 0 - print(f'{source_data}_class', i, ':', prop, '%', file=f) - print(f'\n****{source_data}****\n', file=f) - - -def main(params): - """ - Dataset preparation (trn, val, tst). - :param params: (dict) Parameters found in the yaml config file. - - """ - assert params['global']['task'] == 'segmentation', \ - f"sample_creation.py isn't necessary when performing classification tasks" - num_classes = get_key_def('num_classes', params['global'], expected_type=int) - num_bands = get_key_def('number_of_bands', params['global'], expected_type=int) - debug = get_key_def('debug_mode', params['global'], False) - targ_ids = get_key_def('target_ids', params['sample'], None, expected_type=List) - - # SET BASIC VARIABLES AND PATHS. CREATE OUTPUT FOLDERS. - val_percent = params['sample']['val_percent'] - samples_size = params["global"]["samples_size"] - overlap = params["sample"]["overlap"] - dist_samples = round(samples_size * (1 - (overlap / 100))) - min_annot_perc = get_key_def('min_annotated_percent', params['sample']['sampling_method'], None, expected_type=int) - - list_params = params['read_img'] - source_pan = get_key_def('pan', list_params['source'], default=False, expected_type=bool) - source_mul = get_key_def('mul', list_params['source'], default=False, expected_type=bool) - mul_band_order = get_key_def('mulband', list_params['source'], default=[], expected_type=list) - prep_band = get_key_def('band', list_params['prep'], default=[], expected_type=list) - tst_set = get_key_def('benchmark', list_params, default=[], expected_type=list) - in_pth = get_key_def('input_file', list_params, default='data_file.json', expected_type=str) - sensor_lst = get_key_def('sensorID', list_params, default=['GeoEye1', 'QuickBird2' 'WV2', 'WV3', 'WV4'], - expected_type=list) - month_range = get_key_def('month_range', list_params, default=list(range(1, 12 + 1)), expected_type=list) - root_folder = Path(get_key_def('root_img_folder', list_params, default='', expected_type=str)) - gpkg_status = 'all' - - data_path = Path(params['global']['data_path']) - Path.mkdir(data_path, exist_ok=True, parents=True) - if not data_path.is_dir(): - raise FileNotFoundError(f'Could not locate data path {data_path}') - - # mlflow logging - experiment_name = get_key_def('mlflow_experiment_name', params['global'], default='gdl-training', expected_type=str) - samples_folder_name = (f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands' - f'_{experiment_name}') - samples_folder = data_path.joinpath(samples_folder_name) - if samples_folder.is_dir(): - if debug: - # Move existing data folder with a random suffix. - last_mod_time_suffix = datetime.fromtimestamp(samples_folder.stat().st_mtime).strftime('%Y%m%d-%H%M%S') - shutil.move(samples_folder, data_path.joinpath(f'{str(samples_folder)}_{last_mod_time_suffix}')) - else: - raise FileExistsError(f'Data path exists: {samples_folder}. Remove it or use a different experiment_name.') - - Path.mkdir(samples_folder, exist_ok=False) # TODO: what if we want to append samples to existing hdf5? - trn_hdf5, val_hdf5, tst_hdf5 = create_files_and_datasets(samples_size=samples_size, - number_of_bands=num_bands, - samples_folder=samples_folder, - params=params) - - class_prop = get_key_def('class_proportion', params['sample']['sampling_method'], None, expected_type=dict) - dontcare = get_key_def("ignore_index", params["training"], -1) - number_samples = {'trn': 0, 'val': 0, 'tst': 0} - number_classes = 0 - - pixel_pan_counter = Counter() - pixel_mul_counter = Counter() - pixel_prep_counter = Counter() - filename = samples_folder.joinpath('class_distribution.txt') - - with open(Path(in_pth), 'r') as fin: - dict_images = json.load(fin) - - for i_dict in tqdm(dict_images['all_images'], desc=f'Writing samples to {samples_folder}'): - if i_dict['sensorID'] in sensor_lst and \ - datetime.strptime(i_dict['date']['yyyy/mm/dd'], '%Y/%m/%d').month in month_range: - - if source_pan: - if not len(i_dict['pan_img']) == 0 and i_dict['gpkg']: - if gpkg_status == 'all': - if 'corr' or 'prem' in i_dict['gpkg'].keys(): - gpkg = root_folder.joinpath(list(i_dict['gpkg'].values())[0]) - gpkg_classes = validate_num_classes(gpkg, num_classes, - 'properties/Quatreclasses', - dontcare, - targ_ids) - for img_pan in i_dict['pan_img']: - img_pan = root_folder.joinpath(img_pan) - assert_crs_match(img_pan, gpkg) - rst_pth, r_ = process_raster_img(img_pan, gpkg) - np_label = process_vector_label(rst_pth, gpkg, targ_ids) - if np_label is not None: - if Path(gpkg).stem in tst_set: - sample_type = 'tst' - out_file = tst_hdf5 - else: - sample_type = 'trn' - out_file = trn_hdf5 - val_file = val_hdf5 - src = r_ - pan_label_gen = gen_label_samples(np_label, dist_samples, samples_size) - pan_img_gen = gen_img_samples(rst_pth, samples_size, dist_samples) - else: - continue - for pan_img, pan_label in zip(pan_img_gen, pan_label_gen): - number_samples, number_classes, class_pixels_pan = sample_prep(src, pan_img, pan_label[0], - pan_label[1], gpkg_classes, - samples_size, sample_type, - number_samples, out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_pan_counter.update(class_pixels_pan) - - if source_mul: - if not len(i_dict['mul_img']) == 0 and i_dict['gpkg']: - band_order = reorder_bands(i_dict['mul_band'], mul_band_order) - if gpkg_status == 'all': - if 'corr' or 'prem' in i_dict['gpkg'].keys(): - gpkg = root_folder.joinpath(list(i_dict['gpkg'].values())[0]) - gpkg_classes = validate_num_classes(gpkg, num_classes, - 'properties/Quatreclasses', - dontcare, - targ_ids) - for img_mul in i_dict['mul_img']: - img_mul = root_folder.joinpath(img_mul) - assert_crs_match(img_mul, gpkg) - rst_pth, r_ = process_raster_img(img_mul, gpkg) - np_label = process_vector_label(rst_pth, gpkg, targ_ids) - if np_label is not None: - if Path(gpkg).stem in tst_set: - sample_type = 'tst' - out_file = tst_hdf5 - else: - sample_type = 'trn' - out_file = trn_hdf5 - val_file = val_hdf5 - src = r_ - - mul_label_gen = gen_label_samples(np_label, dist_samples, samples_size) - mul_img_gen = gen_img_samples(rst_pth, samples_size, dist_samples, band_order) - else: - continue - for mul_img, mul_label in zip(mul_img_gen, mul_label_gen): - number_samples, number_classes, class_pixels_mul = sample_prep(src, mul_img, mul_label[0], - mul_label[1], gpkg_classes, - samples_size, sample_type, - number_samples, out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_mul_counter.update(class_pixels_mul) - - if prep_band: - bands_gen_list = [] - if set(prep_band).issubset({'R', 'G', 'B', 'N'}): - for ib in prep_band: - if i_dict[f'{ib}_band'] and i_dict['gpkg']: - i_dict[f'{ib}_band'] = root_folder.joinpath(i_dict[f'{ib}_band']) - if gpkg_status == 'all': - if 'corr' or 'prem' in i_dict['gpkg'].keys(): - gpkg = root_folder.joinpath(list(i_dict['gpkg'].values())[0]) - gpkg_classes = validate_num_classes(gpkg, num_classes, - 'properties/Quatreclasses', - dontcare, - targ_ids) - assert_crs_match(i_dict[f'{ib}_band'], gpkg) - rst_pth, r_ = process_raster_img(i_dict[f'{ib}_band'], gpkg) - np_label = process_vector_label(rst_pth, gpkg, targ_ids) - prep_img_gen = gen_img_samples(rst_pth, samples_size, dist_samples) - bands_gen_list.append(prep_img_gen) - - if np_label is not None: - if Path(gpkg).stem in tst_set: - sample_type = 'tst' - out_file = tst_hdf5 - else: - sample_type = 'trn' - out_file = trn_hdf5 - val_file = val_hdf5 - src = r_ - prep_label_gen = gen_label_samples(np_label, dist_samples, samples_size) - if len(prep_band) and len(bands_gen_list) == 1: - for b1, prep_label in zip(bands_gen_list[0], prep_label_gen): - prep_img = b1 - number_samples, number_classes, class_pixels_prep = sample_prep(src, prep_img, - prep_label[0], - prep_label[1], - gpkg_classes, - samples_size, - sample_type, - number_samples, - out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_prep_counter.update(class_pixels_prep) - - elif len(prep_band) and len(bands_gen_list) == 2: - for b1, b2, prep_label in zip(*bands_gen_list, prep_label_gen): - prep_img = np.dstack(np.array([b1, b2])) - number_samples, number_classes, class_pixels_prep = sample_prep(src, prep_img, - prep_label[0], - prep_label[1], - gpkg_classes, - samples_size, - sample_type, - number_samples, - out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_prep_counter.update(class_pixels_prep) - - elif len(prep_band) and len(bands_gen_list) == 3: - for b1, b2, b3, prep_label in zip(*bands_gen_list, prep_label_gen): - prep_img = np.dstack(np.array([b1, b2, b3])) - number_samples, number_classes, class_pixels_prep = sample_prep(src, prep_img, - prep_label[0], - prep_label[1], - gpkg_classes, - samples_size, - sample_type, - number_samples, - out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_prep_counter.update(class_pixels_prep) - - elif len(prep_band) and len(bands_gen_list) == 4: - for b1, b2, b3, b4, prep_label in zip(*bands_gen_list, prep_label_gen): - prep_img = np.dstack(np.array([b1, b2, b3, b4])) - number_samples, number_classes, class_pixels_prep = sample_prep(src, prep_img, - prep_label[0], - prep_label[1], - gpkg_classes, - samples_size, - sample_type, - number_samples, - out_file, - number_classes, - val_percent, val_file, - min_annot_perc, - class_prop=class_prop, - dontcare=dontcare) - pixel_prep_counter.update(class_pixels_prep) - else: - continue - else: - continue - trn_hdf5.close() - val_hdf5.close() - tst_hdf5.close() - - class_pixel_ratio(pixel_pan_counter, 'pan_source', filename) - class_pixel_ratio(pixel_mul_counter, 'mul_source', filename) - class_pixel_ratio(pixel_prep_counter, 'prep_source', filename) - print("Number of samples created: ", number_samples, number_classes) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Sample preparation') - parser.add_argument('ParamFile', metavar='DIR', - help='Path to training parameters stored in yaml') - args = parser.parse_args() - params = read_parameters(args.ParamFile) - start_time = time.time() - tqdm.write(f'\n\nStarting images to samples preparation with {args.ParamFile}\n\n') - main(params) - print("Elapsed time:{}".format(time.time() - start_time)) diff --git a/sampling_segmentation.py b/sampling_segmentation.py index d94d9008..a29467c1 100644 --- a/sampling_segmentation.py +++ b/sampling_segmentation.py @@ -3,22 +3,22 @@ import rasterio import numpy as np +from solaris.utils.core import _check_rasterio_im_load from tqdm import tqdm from pathlib import Path from datetime import datetime from omegaconf import DictConfig, open_dict -# Our modules +from dataset.aoi import aois_from_csv from utils.logger import get_logger from utils.geoutils import vector_to_raster from utils.readers import image_reader_as_array from utils.create_dataset import create_files_and_datasets, append_to_dataset from utils.utils import ( - get_key_def, pad, pad_diff, read_csv, add_metadata_from_raster_to_sample, get_git_hash, - read_modalities, + get_key_def, pad, pad_diff, add_metadata_from_raster_to_sample, get_git_hash, ) from utils.verifications import ( - validate_num_classes, assert_crs_match, validate_features_from_gpkg, validate_input_imagery + validate_num_classes ) # Set the logging file logging = get_logger(__name__) # import logging @@ -364,8 +364,8 @@ def main(cfg: DictConfig) -> None: """ # PARAMETERS num_classes = len(cfg.dataset.classes_dict.keys()) - num_bands = len(cfg.dataset.modalities) - modalities = read_modalities(cfg.dataset.modalities) # TODO add the Victor module to manage the modalities + bands_requested = get_key_def('bands', cfg['dataset'], default=None, expected_type=Sequence) + num_bands = len(bands_requested) debug = cfg.debug # RAW DATA PARAMETERS @@ -390,7 +390,7 @@ def main(cfg: DictConfig) -> None: logging.critical( f'Data path exists: {samples_dir}. Remove it or use a different experiment_name.' ) - raise FileExistsError() + raise FileExistsError(f'Data path exists: {samples_dir}. Remove it or use a different experiment_name.') Path.mkdir(samples_dir, exist_ok=False) # TODO: what if we want to append samples to existing hdf5? # LOGGING PARAMETERS TODO see logging yaml @@ -428,7 +428,12 @@ def main(cfg: DictConfig) -> None: with open_dict(cfg): cfg.general.git_hash = get_git_hash() - list_data_prep = read_csv(csv_file) + list_data_prep = aois_from_csv( + csv_path=csv_file, + bands_requested=bands_requested, + attr_field_filter=attribute_field, + attr_values_filter=attr_vals + ) # IF DEBUG IS ACTIVATE if debug: @@ -436,24 +441,14 @@ def main(cfg: DictConfig) -> None: f'\nDebug mode activated. Some debug features may mobilize extra disk space and cause delays in execution.' ) - # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg and (2) check CRS match (tif and gpkg) + # VALIDATION: (1) Assert num_classes parameters == num actual classes in gpkg valid_gpkg_set = set() - for info in tqdm(list_data_prep, position=0): - validate_input_imagery(info['tif'], num_bands) - if info['gpkg'] not in valid_gpkg_set: + for aoi in tqdm(list_data_prep, position=0): + if aoi.label not in valid_gpkg_set: gpkg_classes = validate_num_classes( - info['gpkg'], num_classes, attribute_field, dontcare, attribute_values=attr_vals, + aoi.label, num_classes, attribute_field, dontcare, attribute_values=attr_vals, ) - assert_crs_match(info['tif'], info['gpkg']) - valid_gpkg_set.add(info['gpkg']) - - if debug: - # VALIDATION (debug only): Checking validity of features in vector files - for info in tqdm(list_data_prep, position=0, desc=f"Checking validity of features in vector files"): - # TODO: make unit to test this with invalid features. - invalid_features = validate_features_from_gpkg(info['gpkg'], attribute_field) - if invalid_features: - logging.critical(f"{info['gpkg']}: Invalid geometry object(s) '{invalid_features}'") + valid_gpkg_set.add(aoi.label) number_samples = {'trn': 0, 'val': 0, 'tst': 0} number_classes = 0 @@ -475,27 +470,27 @@ def main(cfg: DictConfig) -> None: f"\nPreparing samples \n Samples_size: {samples_size} \n Overlap: {overlap} " f"\n Validation set: {val_percent} % of created training samples" ) - for info in tqdm(list_data_prep, position=0, leave=False): + for aoi in tqdm(list_data_prep, position=0, leave=False): try: - logging.info(f"\nReading as array: {info['tif']}") - with rasterio.open(info['tif'], 'r') as raster: + logging.info(f"\nReading as array: {aoi.raster.name}") + with _check_rasterio_im_load(aoi.raster) as raster: # 1. Read the input raster image np_input_image, raster, dataset_nodata = image_reader_as_array( input_image=raster, - clip_gpkg=info['gpkg'] + #FIXME: remove clip_gpkg=aoi.label ) # 2. Burn vector file in a raster file - logging.info(f"\nRasterizing vector file (attribute: {attribute_field}): {info['gpkg']}") + logging.info(f"\nRasterizing vector file (attribute: {attribute_field}): {aoi.label}") try: - np_label_raster = vector_to_raster(vector_file=info['gpkg'], + np_label_raster = vector_to_raster(vector_file=aoi.label, input_image=raster, out_shape=np_input_image.shape[:2], attribute_name=attribute_field, fill=background_val, attribute_values=attr_vals) # background value in rasterized vector. except ValueError: - logging.error(f"No vector features found for {info['gpkg']} with provided configuration." + logging.error(f"No vector features found for {aoi.label} with provided configuration." f"Will skip to next AOI.") continue @@ -509,7 +504,7 @@ def main(cfg: DictConfig) -> None: out_meta.update({"driver": "GTiff", "height": np_image_debug.shape[1], "width": np_image_debug.shape[2]}) - out_tif = samples_dir / f"{Path(info['tif']).stem}_clipped.tif" + out_tif = samples_dir / f"{Path(aoi.raster.name).stem}_clipped.tif" logging.debug(f"Writing clipped raster to {out_tif}") with rasterio.open(out_tif, "w", **out_meta) as dest: dest.write(np_image_debug) @@ -520,7 +515,7 @@ def main(cfg: DictConfig) -> None: "height": np_label_debug.shape[1], "width": np_label_debug.shape[2], 'count': 1}) - out_tif = samples_dir / f"{Path(info['gpkg']).stem}_clipped.tif" + out_tif = samples_dir / f"{Path(aoi.label).stem}_clipped.tif" logging.debug(f"\nWriting final rasterized gpkg to {out_tif}") with rasterio.open(out_tif, "w", **out_meta) as dest: dest.write(np_label_debug) @@ -529,17 +524,16 @@ def main(cfg: DictConfig) -> None: if mask_reference: np_label_raster = mask_image(np_input_image, np_label_raster) - if info['dataset'] == 'trn': + if aoi.split == 'trn': out_file = trn_hdf5 - elif info['dataset'] == 'tst': + elif aoi.split == 'tst': out_file = tst_hdf5 else: - raise ValueError(f"\nDataset value must be trn or tst. Provided value is {info['dataset']}") + raise ValueError(f"\nDataset value must be trn or tst. Provided value is {aoi.split}") val_file = val_hdf5 metadata = add_metadata_from_raster_to_sample(sat_img_arr=np_input_image, - raster_handle=raster, - raster_info=info) + raster_handle=raster) # Save label's per class pixel count to image metadata metadata['source_label_bincount'] = {class_num: count for class_num, count in enumerate(np.bincount(np_label_raster.clip(min=0).flatten())) @@ -556,7 +550,7 @@ def main(cfg: DictConfig) -> None: samples_file=out_file, val_percent=val_percent, val_sample_file=val_file, - dataset=info['dataset'], + dataset=aoi.split, pixel_classes=pixel_classes, dontcare=dontcare, image_metadata=metadata, @@ -567,8 +561,8 @@ def main(cfg: DictConfig) -> None: # logging.info(f'\nNumber of samples={number_samples}') out_file.flush() except OSError: - logging.exception(f'\nAn error occurred while preparing samples with "{Path(info["tif"]).stem}" (tiff) and ' - f'{Path(info["gpkg"]).stem} (gpkg).') + logging.exception(f'\nAn error occurred while preparing samples with "{Path(aoi.raster.name).stem}" (tiff) and ' + f'{Path(aoi.label).stem} (gpkg).') continue trn_hdf5.close() diff --git a/tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json b/tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json new file mode 100644 index 00000000..cb977ba9 --- /dev/null +++ b/tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json @@ -0,0 +1 @@ +{"id":"SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03","stac_version":"1.0.0","stac_extensions":["https://stac-extensions.github.io/projection/v1.0.0/schema.json","https://stac-extensions.github.io/eo/v1.0.0/schema.json"],"type":"Feature","geometry":{"type":"Polygon","coordinates":[[[-115.28540582000387,36.241524691687786],[-115.28491559145193,36.26395226444633],[-115.28755456238802,36.263990086821984],[-115.30751080120687,36.26427421837858],[-115.30769213475851,36.25587054603559],[-115.307994578854,36.24184638306735],[-115.28540582000387,36.241524691687786]]]},"bbox":[-115.307994578854,36.241524691687786,-115.28491559145193,36.26427421837858],"links":[{"href":"https://datacube-stage.services.geo.ca/api/collections/spacenet-samples/items/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03","rel":"self","type":"application/json"},{"href":"https://datacube-stage.services.geo.ca/api/collections/spacenet-samples","rel":"parent","type":"application/json","title":"SpaceNet Samples / Échantillons SpaceNet"},{"href":"../collection.json","rel":"collection","type":"application/json","title":"SpaceNet Samples / Échantillons SpaceNet"},{"href":"https://datacube-stage.services.geo.ca/api","rel":"root","type":"application/json","title":"Welcome to Franklin"}],"assets":{"N":{"eo:bands":[{"name":"n","common_name":"nir","description":"sensor:WV03, min_wavelength_nm: 770, max_wavelength_nm: 895, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.8325,"full_width_half_max":0.125}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-N.tif","title":"Near infrared band","description":"COG - Near infrared Single spectral band / COG - Bande spectrale unique Proche infrarouge","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"Y":{"eo:bands":[{"name":"y","common_name":"yellow","description":"sensor:WV03, min_wavelength_nm: 585, max_wavelength_nm: 625, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.605,"full_width_half_max":0.040000000000000036}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-Y.tif","title":"Yellow band","description":"COG - Yellow Single spectral band / COG - Bande spectrale unique Jaune","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"G":{"eo:bands":[{"name":"g","common_name":"green","description":"sensor:WV03, min_wavelength_nm: 510, max_wavelength_nm: 580, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.5449999999999999,"full_width_half_max":0.06999999999999995}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-G.tif","title":"Green band","description":"COG - Green Single spectral band / COG - Bande spectrale unique Vert","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"METADATA":{"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-METADATA.xml","title":"Metadata file from satellite provider","roles":["metadata"],"type":"application/xml"},"B":{"eo:bands":[{"name":"b","common_name":"blue","description":"sensor:WV03, min_wavelength_nm: 450, max_wavelength_nm: 510, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.48,"full_width_half_max":0.06}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-B.tif","title":"Blue band","description":"COG - Blue Single spectral band / COG - Bande spectrale unique Bleu","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"C":{"eo:bands":[{"name":"c","common_name":"coastal","description":"sensor:WV03, min_wavelength_nm: 400, max_wavelength_nm: 450, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.42500000000000004,"full_width_half_max":0.04999999999999999}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-C.tif","title":"Coastal band","description":"COG - Coastal Single spectral band / COG - Bande spectrale unique Côtier","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"N2":{"eo:bands":[{"name":"n2","common_name":"nir09","description":"sensor:WV03, min_wavelength_nm: 860, max_wavelength_nm: 1040, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.95,"full_width_half_max":0.18000000000000005}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-N2.tif","title":"Near infrared 2 band","description":"COG - Near infrared 2 Single spectral band / COG - Bande spectrale unique Proche infrarouge 2","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"RE":{"eo:bands":[{"name":"re","common_name":"rededge","description":"sensor:WV03, min_wavelength_nm: 705, max_wavelength_nm: 745, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.725,"full_width_half_max":0.040000000000000036}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-RE.tif","title":"Red Edge - single spectral band","description":"COG - Red Edge Single spectral band / COG - Bande spectrale unique Bordure rouge","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"},"R":{"eo:bands":[{"name":"r","common_name":"red","description":"sensor:WV03, min_wavelength_nm: 630, max_wavelength_nm: 690, orthorectified, pansharpened, downsampled to 8 bit","center_wavelength":0.6599999999999999,"full_width_half_max":0.05999999999999994}],"href":"http://datacube-stage-data-public.s3.ca-central-1.amazonaws.com/store/imagery/optical/spacenet-samples/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif","title":"Red band","description":"COG - Red Single spectral band / COG - Bande spectrale unique Rouge","roles":["data"],"type":"image/tiff; application=geotiff; profile=cloud-optimized"}},"collection":"spacenet-samples","properties":{"validation:attemptedExtensions":["https://stac-extensions.github.io/projection/v1.0.0/schema.json"],"validation:errors":[],"proj:epsg":32611,"collection":"spacenet-samples","proj:shape":[9130,7449],"proj:geometry":{"type":"Polygon","coordinates":[[[654066.7035330499,4012100.6890192926],[654066.7035330499,4014589.457639933],[652036.1627384004,4014589.457639933],[652036.1627384004,4012100.6890192926],[654066.7035330499,4012100.6890192926]]]},"eo:cloud_cover":0.0,"proj:transform":[652036.1627384004,0.2725924009463665,0.0,4014589.457639933,0.0,-0.2725924009463665],"datetime":"2015-10-22T18:36:56Z","created":"2022-03-10T15:52:36Z","updated":"2022-04-05T15:55:31Z"}} \ No newline at end of file diff --git a/tests/dataset/test_aoi.py b/tests/dataset/test_aoi.py new file mode 100644 index 00000000..7c68c193 --- /dev/null +++ b/tests/dataset/test_aoi.py @@ -0,0 +1,17 @@ +from dataset.aoi import AOI + + +class Test_AOI(object): + def test_parse_input_raster(self) -> None: + raster_raw = { + "tests/data/spacenet/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03.json": [ + "red", "green", "blue"], + "tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped_${dataset.bands}.tif": ["R", "G", "B"], + "tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif": None, + } + for raster_raw, bands_requested in raster_raw.items(): + raster_parsed = AOI.parse_input_raster(csv_raster_str=raster_raw, raster_bands_requested=bands_requested) + print(raster_parsed) + +# TODO: SingleBandItem +# test raise ValueError if request more than available bands \ No newline at end of file diff --git a/train_segmentation.py b/train_segmentation.py index 023f807e..4b673b3b 100644 --- a/train_segmentation.py +++ b/train_segmentation.py @@ -168,6 +168,7 @@ def get_num_samples(samples_path, params): weights = [] samples_weight = None for i in ['trn', 'val', 'tst']: + logging.debug(f"Reading {samples_path}/{i}_samples.hdf5...") if get_key_def(f"num_{i}_samples", params['training'], None) is not None: num_samples[i] = get_key_def(f"num_{i}_samples", params['training']) with h5py.File(samples_path.joinpath(f"{i}_samples.hdf5"), 'r') as hdf5_file: diff --git a/utils/geoutils.py b/utils/geoutils.py index db38c9c9..b698ede4 100644 --- a/utils/geoutils.py +++ b/utils/geoutils.py @@ -6,9 +6,14 @@ import fiona import os + +import pystac import rasterio +from rasterio import MemoryFile from rasterio.features import is_valid_geom from rasterio.mask import mask +from rasterio.shutil import copy as riocopy +import xml.etree.ElementTree as ET logger = logging.getLogger(__name__) @@ -97,7 +102,7 @@ def clip_raster_with_gpkg(raster, gpkg, debug=False): dest.write(out_img) return out_tif except ValueError as e: # if gpkg's extent outside raster: "ValueError: Input shapes do not overlap raster." - logging.error(f"e\n {raster.name}\n{gpkg}") + logging.error(f"{e}\n {raster.name}\n{gpkg}") def vector_to_raster(vector_file, input_image, out_shape, attribute_name, fill=0, attribute_values=None, merge_all=True): @@ -194,3 +199,43 @@ def get_key_recursive(key, config): assert len(key) > 1, "missing keys to index metadata subdictionaries" return get_key_recursive(key[1:], val) return int(val) + + +def is_stac_item(path: str) -> bool: + """Checks if an input string or object is a valid stac item""" + if isinstance(path, pystac.Item): + return True + else: + try: + pystac.Item.from_file(str(path)) + return True + except (FileNotFoundError, pystac.STACTypeError, UnicodeDecodeError): + return False + + +def stack_vrts(srcs, band=1): + """ + Stacks multiple single-band raster into a single multiband virtual raster + Source: https://gis.stackexchange.com/questions/392695/is-it-possible-to-build-a-vrt-file-from-multiple-files-with-rasterio + @param srcs: + List of paths/urls to single-band raster + @param band: + TODO + @return: + RasterDataset object containing VRT + """ + vrt_bands = [] + for srcnum, src in enumerate(srcs, start=1): + with rasterio.open(src) as ras, MemoryFile() as mem: + riocopy(ras, mem.name, driver='VRT') + vrt_xml = mem.read().decode('utf-8') + vrt_dataset = ET.fromstring(vrt_xml) + for bandnum, vrt_band in enumerate(vrt_dataset.iter('VRTRasterBand'), start=1): + if bandnum == band: + vrt_band.set('band', str(srcnum)) + vrt_bands.append(vrt_band) + vrt_dataset.remove(vrt_band) + for vrt_band in vrt_bands: + vrt_dataset.append(vrt_band) + + return ET.tostring(vrt_dataset).decode('UTF-8') diff --git a/utils/logger.py b/utils/logger.py index 2c94f9dd..92f85426 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -1,5 +1,4 @@ import logging -import os from pathlib import Path from typing import Union diff --git a/utils/utils.py b/utils/utils.py index 1a709eda..a47c3ddc 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -327,7 +327,6 @@ def checkpoint_url_download(url: str): def list_input_images(img_dir_or_csv: Path, - bucket_name: str = None, glob_patterns: List = None): """ Create list of images from given directory or csv file. @@ -340,41 +339,31 @@ def list_input_images(img_dir_or_csv: Path, returns list of dictionaries where keys are "tif" and values are paths to found images. "meta" key is also added if input is csv and second column contains a metadata file. Then, value is path to metadata file. """ - if bucket_name: - s3 = boto3.resource('s3') - bucket = s3.Bucket(bucket_name) - if img_dir_or_csv.suffix == '.csv': - bucket.download_file(str(img_dir_or_csv), 'img_csv_file.csv') - list_img = read_csv('img_csv_file.csv') - else: - raise NotImplementedError( - 'Specify a csv file containing images for inference. Directory input not implemented yet') + if img_dir_or_csv.suffix == '.csv': + list_img = read_csv(img_dir_or_csv) + elif is_url(str(img_dir_or_csv)): + list_img = [] + img = {'tif': img_dir_or_csv} + list_img.append(img) else: - if img_dir_or_csv.suffix == '.csv': - list_img = read_csv(img_dir_or_csv) - elif is_url(str(img_dir_or_csv)): - list_img = [] - img = {'tif': img_dir_or_csv} - list_img.append(img) + img_dir = img_dir_or_csv + if not img_dir.is_dir(): + raise NotADirectoryError(f'Could not find directory/file "{img_dir_or_csv}"') + + list_img_paths = set() + if img_dir.is_dir(): + for glob_pattern in glob_patterns: + if not isinstance(glob_pattern, str): + raise TypeError(f'Invalid glob pattern: "{glob_pattern}"') + list_img_paths.update(sorted(img_dir.glob(glob_pattern))) else: - img_dir = img_dir_or_csv - if not img_dir.is_dir(): - raise NotADirectoryError(f'Could not find directory/file "{img_dir_or_csv}"') - - list_img_paths = set() - if img_dir.is_dir(): - for glob_pattern in glob_patterns: - if not isinstance(glob_pattern, str): - raise TypeError(f'Invalid glob pattern: "{glob_pattern}"') - list_img_paths.update(sorted(img_dir.glob(glob_pattern))) - else: - list_img_paths.update([img_dir]) - list_img = [] - for img_path in list_img_paths: - img = {'tif': img_path} - list_img.append(img) - if not len(list_img) >= 0: - raise ValueError(f'No .tif files found in {img_dir_or_csv}') + list_img_paths.update([img_dir]) + list_img = [] + for img_path in list_img_paths: + img = {'tif': img_path} + list_img.append(img) + if not len(list_img) >= 0: + raise ValueError(f'No .tif files found in {img_dir_or_csv}') return list_img @@ -399,21 +388,14 @@ def read_csv(csv_file_name: str) -> Dict: raise ValueError(f"Rows in csv should be of same length. Got rows with length: {row_lengths_set}") row.extend([None] * (4 - len(row))) # fill row with None values to obtain row of length == 5 row[0] = to_absolute_path(row[0]) # Convert relative paths to absolute with hydra's util to_absolute_path() - if not Path(row[0]).is_file(): - logging.critical(f"Raster not found: {row[0]}. This data will be removed from input data list" - f"since all geo-deep-learning modules require imagery.") - continue row[1] = to_absolute_path(row[1]) - if not Path(row[1]).is_file(): - logging.critical(f"Ground truth not found: {row[1]}") - if row[2] and not row[2] in ['trn', 'tst']: - logging.critical(f'Invalid dataset split: {row[2]}. Expected "trn" or "tst".') + # save all values list_values.append( - {'tif': str(row[0]), 'gpkg': str(row[1]), 'dataset': row[2], 'aoi_id': row[3]}) + {'tif': str(row[0]), 'gpkg': str(row[1]), 'split': row[2], 'aoi_id': row[3]}) try: # Try sorting according to dataset name (i.e. group "train", "val" and "test" rows together) - list_values = sorted(list_values, key=lambda k: k['dataset']) + list_values = sorted(list_values, key=lambda k: k['split']) except TypeError: log.warning('Unable to sort csv rows') return list_values @@ -421,7 +403,7 @@ def read_csv(csv_file_name: str) -> Dict: def add_metadata_from_raster_to_sample(sat_img_arr: np.ndarray, raster_handle: dict, - raster_info: dict + raster_info: dict = None ) -> dict: """ :param sat_img_arr: source image as array (opened with rasterio.read) @@ -524,26 +506,6 @@ def ordereddict_eval(str_to_eval: str): return str_to_eval -def read_modalities(modalities: str) -> list: - """ - Function that read the modalities from the yaml and convert it to a list - of all the bands specified. - - ------- - :param modalities: (str) A string composed of all the bands of the images. - - ------- - :returns: A list of all the bands of the images. - """ - if str(modalities).find('IR') != -1: - ir_position = str(modalities).find('IR') - modalities = list(str(modalities).replace('IR', '')) - modalities.insert(ir_position, 'IR') - else: - modalities = list(str(modalities)) - return modalities - - def getpath(d, path): """ TODO diff --git a/utils/verifications.py b/utils/verifications.py index 2f5175ce..ce883a3c 100644 --- a/utils/verifications.py +++ b/utils/verifications.py @@ -1,16 +1,23 @@ +import os from pathlib import Path from typing import Union, List import fiona +import geopandas as gpd import numpy as np import rasterio +from fiona._err import CPLE_OpenFailedError +from fiona.errors import DriverError from rasterio.features import is_valid_geom +from solaris.utils.core import _check_rasterio_im_load, _check_gdf_load, _check_crs from tqdm import tqdm from utils.geoutils import lst_ids, get_key_recursive import logging +from utils.utils import is_url + logger = logging.getLogger(__name__) @@ -20,6 +27,7 @@ def validate_num_classes(vector_file: Union[str, Path], ignore_index: int, attribute_values: List): """Check that `num_classes` is equal to number of classes detected in the specified attribute for each GeoPackage. + # FIXME: use geopandas FIXME: this validation **will not succeed** if a Geopackage contains only a subset of `num_classes` (e.g. 3 of 4). Args: :param vector_file: full file path of the vector image @@ -68,62 +76,59 @@ def validate_num_classes(vector_file: Union[str, Path], return num_classes_ -def validate_raster(raster_path: Union[str, Path], extended: bool = False) -> bool: +def validate_raster(raster: Union[str, Path, rasterio.DatasetReader], extended: bool = False) -> None: """ Checks if raster is valid, i.e. not corrupted (based on metadata, or actual byte info if under size threshold) - @param raster_path: Path to raster to be validated + @param raster: Path to raster to be validated @param extended: if True, raster data will be entirely read to detect any problem @return: if raster is valid, returns True, else False (with logging.critical) """ - if not raster_path: - return False + if not raster: + raise FileNotFoundError(f"No raster provided. Got: {raster}") try: - raster_path = Path(raster_path) if isinstance(raster_path, str) else raster_path + raster = Path(raster) if isinstance(raster, str) and not is_url(raster) else raster except TypeError as e: - logging.critical(f"Invalid raster.\nRaster path: {raster_path}\n{e}") - return False + logging.critical(f"Invalid raster.\nRaster path: {raster}\n{e}") + raise e try: - logging.debug(f'Raster to validate: {raster_path}\n' - f'Size: {raster_path.stat().st_size}\n' + logging.debug(f'Raster to validate: {raster}\n' + f'Size: {raster.stat().st_size}\n' f'Extended check: {extended}') - with rasterio.open(raster_path, 'r') as raster: + with rasterio.open(raster, 'r') as raster: if not raster.meta['dtype'] in ['uint8', 'uint16']: # will trigger exception if invalid raster logging.warning(f"Only uint8 and uint16 are supported in current version.\n" - f"Datatype {raster.meta['dtype']} for {raster.name} may cause problems.") + f"Datatype {raster.meta['dtype']} for {raster.aoi_id} may cause problems.") if extended: - logging.debug(f'Will perform extended check.\nWill read first band: {raster_path}') - with rasterio.open(raster_path, 'r') as raster: + logging.debug(f'Will perform extended check.\nWill read first band: {raster}') + with rasterio.open(raster, 'r') as raster: raster_np = raster.read(1) logging.debug(raster_np.shape) if not np.any(raster_np): - logging.critical(f"Raster data filled with zero values.\nRaster path: {raster_path}") + logging.critical(f"Raster data filled with zero values.\nRaster path: {raster}") return False except FileNotFoundError as e: - logging.critical(f"Could not locate raster file.\nRaster path: {raster_path}\n{e}") - return False - except rasterio.errors.RasterioIOError as e: - logging.critical(f"Invalid raster.\nRaster path: {raster_path}\n{e}") - return False - return True + logging.critical(f"Could not locate raster file.\nRaster path: {raster}\n{e}") + raise e + except (rasterio.errors.RasterioIOError, TypeError) as e: + logging.critical(f"\nRasterio can't open the provided raster: {raster}\n{e}") + raise e -def validate_num_bands(raster_path: Union[str, Path], num_bands: int) -> bool: +def validate_num_bands(raster_path: Union[str, Path], num_bands: int) -> None: """ Checks match between expected and actual number of bands @param raster_path: Path to raster to be validated @param num_bands: Number of bands expected @return: if expected and actual number of bands match, returns True, else False (with logging.critical) """ - with rasterio.open(raster_path, 'r') as raster: - input_band_count = raster.meta['count'] + raster = _check_rasterio_im_load(raster_path) + input_band_count = raster.meta['count'] if not input_band_count == num_bands: logging.critical(f"The number of bands expected doesn't match number of bands in input image.\n" f"Expected: {num_bands} bands\n" f"Got: {input_band_count} bands\n" - f"Raster path: {raster_path}") - return False - else: - return True + f"Raster path: {raster.name}") + raise ValueError() def validate_input_imagery(raster_path: Union[str, Path], num_bands: int, extended: bool = False) -> bool: @@ -134,48 +139,90 @@ def validate_input_imagery(raster_path: Union[str, Path], num_bands: int, extend @param num_bands: Number of bands expected @return: """ - if not validate_raster(raster_path, extended): + try: + validate_raster(raster_path, extended) + except Exception as e: # TODO: address with issue #310 return False - if not validate_num_bands(raster_path, num_bands): + try: + validate_num_bands(raster_path, num_bands) + except Exception as e: return False return True -def assert_crs_match(raster_path: Union[str, Path], gpkg_path: Union[str, Path]): +def assert_crs_match( + raster: Union[str, Path, rasterio.DatasetReader], + label: Union[str, Path, gpd.GeoDataFrame]): """ Assert Coordinate reference system between raster and gpkg match. - :param raster_path: (str or Path) path to raster file - :param gpkg_path: (str or Path) path to gpkg file + :param raster: (str or Path) path to raster file + :param label: (str or Path) path to gpkg file """ - with fiona.open(gpkg_path, 'r') as src: - gpkg_crs = src.crs - - with rasterio.open(raster_path, 'r') as raster: - raster_crs = raster.crs - - if not gpkg_crs == raster_crs: - logging.warning(f"CRS mismatch: \n" - f"TIF file \"{raster_path}\" has {raster_crs} CRS; \n" - f"GPKG file \"{gpkg_path}\" has {src.crs} CRS.") - + raster = _check_rasterio_im_load(raster) + raster_crs = raster.crs + gt = _check_gdf_load(label) + gt_crs = gt.crs -def validate_features_from_gpkg(gpkg: Union[str, Path], attribute_name: str): + epsg_gt = _check_crs(gt_crs.to_epsg()) + try: + if raster_crs.is_epsg_code: + epsg_raster = _check_crs(raster_crs.to_epsg()) + else: + logging.warning(f"Cannot parse epsg code from raster's crs '{raster.name}'") + return False, raster_crs, gt_crs + + if epsg_raster != epsg_gt: + logging.error(f"CRS mismatch: \n" + f"TIF file \"{raster}\" has {epsg_raster} CRS; \n" + f"GPKG file \"{label}\" has {epsg_gt} CRS.") + return False, raster_crs, gt_crs + else: + return True, raster_crs, gt_crs + except AttributeError as e: + logging.critical(f'Problem reading crs from image or label.') + logging.critical(e) + return False, raster_crs, gt_crs + + +def validate_features_from_gpkg(label: Union[str, Path], attribute_name: str): """ Validate features in gpkg file - :param gpkg: (str or Path) path to gpkg file + :param label: (str or Path) path to gpkg file :param attribute_name: name of the value field representing the required classes in the vector image file """ + # FIXME: use geopandas # TODO: test this with invalid features. invalid_features_list = [] # Validate vector features to burn in the raster image - with fiona.open(gpkg, 'r') as src: # TODO: refactor as independent function + with fiona.open(label, 'r') as src: lst_vector = [vector for vector in src] shapes = lst_ids(list_vector=lst_vector, attr_name=attribute_name) for index, item in enumerate(tqdm([v for vecs in shapes.values() for v in vecs], leave=False, position=1)): + feature_id = lst_vector[index]["id"] # geom must be a valid GeoJSON geometry type and non-empty geom, value = item geom = getattr(geom, '__geo_interface__', None) or geom if not is_valid_geom(geom): - if lst_vector[index]["id"] not in invalid_features_list: # ignore if feature is already appended - invalid_features_list.append(lst_vector[index]["id"]) + if feature_id not in invalid_features_list: # ignore if feature is already appended + if index == 0: + logging.critical(f"Label file contains at least one invalid feature: {label}") + invalid_features_list.append(feature_id) + logging.critical(f"Invalid geometry object: '{feature_id}'") return invalid_features_list + + +def validate_by_geopandas(label: Union[Path, str]): + # TODO: unit test for valid/invalid label file + """Check if `label` is readable by geopandas, if not, log and raise error.""" + # adapted from https://github.com/CosmiQ/solaris/blob/main/solaris/utils/core.py#L52 + if not Path(label).is_file() or os.stat(label).st_size == 0: + raise FileNotFoundError(f'{label} is not a valid file') + try: + return gpd.read_file(label) + except (DriverError, CPLE_OpenFailedError) as e: + logging.error( + f"GeoDataFrame couldn't be loaded: either {label} isn't a valid" + " path or it isn't a valid vector file. Returning an empty" + " GeoDataFrame." + ) + raise e