From d46abbbc5bc13508f0dd07c4cf39414e5ebb7e24 Mon Sep 17 00:00:00 2001 From: Mathieu Date: Fri, 7 Oct 2022 10:50:22 -0400 Subject: [PATCH] fix issue at inference without label --- dataset/aoi.py | 2 +- inference_segmentation.py | 2 +- tests/dataset/test_aoi.py | 11 +++++++++-- .../inference_segmentation_multiclass_no_label.csv | 1 + utils/utils.py | 10 +++++++--- 5 files changed, 19 insertions(+), 7 deletions(-) create mode 100644 tests/inference/inference_segmentation_multiclass_no_label.csv diff --git a/dataset/aoi.py b/dataset/aoi.py index 0b7946ba..b0598762 100644 --- a/dataset/aoi.py +++ b/dataset/aoi.py @@ -330,7 +330,7 @@ def from_dict(cls, if not {'tif', 'gpkg', 'split'}.issubset(set(aoi_dict.keys())): raise ValueError(f"Input data should minimally contain the following keys: \n" f"'tif', 'gpkg', 'split'.") - if not aoi_dict['gpkg']: + if aoi_dict['gpkg'] is None: logging.warning(f"No ground truth data found for {aoi_dict['tif']}.\n" f"Only imagery will be processed from now on") if "aoi_id" not in aoi_dict.keys() or not aoi_dict['aoi_id']: diff --git a/inference_segmentation.py b/inference_segmentation.py index 9dbfcec4..57e477b3 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -325,7 +325,7 @@ def main(params: Union[DictConfig, dict]) -> None: ) # Dataset params - bands_requested = get_key_def('bands', params['dataset'], default=("red", "blue", "green"), expected_type=Sequence) + bands_requested = get_key_def('bands', params['dataset'], default=[1, 2, 3], expected_type=Sequence) classes_dict = get_key_def('classes_dict', params['dataset'], expected_type=DictConfig) num_classes = len(classes_dict) num_classes = num_classes + 1 if num_classes > 1 else num_classes # multiclass account for background diff --git a/tests/dataset/test_aoi.py b/tests/dataset/test_aoi.py index a0e3648e..10440c9c 100644 --- a/tests/dataset/test_aoi.py +++ b/tests/dataset/test_aoi.py @@ -10,7 +10,7 @@ from shapely.geometry import box from torchgeo.datasets.utils import extract_archive -from dataset.aoi import AOI +from dataset.aoi import AOI, aois_from_csv from utils.utils import read_csv @@ -92,7 +92,7 @@ def test_stac_url_input(self): os.remove("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif") def test_missing_label(self): - """Tests error when missing label file""" + """Tests error when provided label file is missing""" extract_archive(src="tests/data/spacenet.zip") data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv") for row in data: @@ -100,6 +100,13 @@ def test_missing_label(self): with pytest.raises(AttributeError): aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split']) aoi.close_raster() + + def test_no_label(self): + """Test when no label are provided. Should pass for inference. """ + extract_archive(src="tests/data/new_brunswick_aerial.zip") + csv_path = "tests/inference/inference_segmentation_multiclass_no_label.csv" + aois = aois_from_csv(csv_path=csv_path, bands_requested=[1,2,3]) + assert aois[0].label is None def test_parse_input_raster(self) -> None: """Tests parsing for three accepted patterns to reference input raster data with band selection""" diff --git a/tests/inference/inference_segmentation_multiclass_no_label.csv b/tests/inference/inference_segmentation_multiclass_no_label.csv new file mode 100644 index 00000000..a5940898 --- /dev/null +++ b/tests/inference/inference_segmentation_multiclass_no_label.csv @@ -0,0 +1 @@ +tests/data/new_brunswick_aerial/23322E759967N_clipped_1m_inference.tif,,inference \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py index e2f66171..2bc138f6 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -320,13 +320,17 @@ def read_csv(csv_file_name: str) -> Dict: row_lengths_set.update([len(row)]) if not len(row_lengths_set) == 1: raise ValueError(f"Rows in csv should be of same length. Got rows with length: {row_lengths_set}") + row = [str(i) or None for i in row] # replace empty strings to None. row.extend([None] * (4 - len(row))) # fill row with None values to obtain row of length == 5 + row[0] = to_absolute_path(row[0]) if not is_url(row[0]) else row[0] # Convert relative paths to absolute with hydra's util to_absolute_path() - row[1] = to_absolute_path(row[1]) if not is_url(row[1]) else row[1] - + try: + row[1] = str(to_absolute_path(row[1]) if not is_url(row[1]) else row[1]) + except TypeError: + row[1] = None # save all values list_values.append( - {'tif': str(row[0]), 'gpkg': str(row[1]), 'split': row[2], 'aoi_id': row[3]}) + {'tif': str(row[0]), 'gpkg': row[1], 'split': row[2], 'aoi_id': row[3]}) try: # Try sorting according to dataset name (i.e. group "train", "val" and "test" rows together) list_values = sorted(list_values, key=lambda k: k['split'])