Skip to content

Commit

Permalink
fix issue at inference without label
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Oct 7, 2022
1 parent da034c9 commit d46abbb
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dataset/aoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def from_dict(cls,
if not {'tif', 'gpkg', 'split'}.issubset(set(aoi_dict.keys())):
raise ValueError(f"Input data should minimally contain the following keys: \n"
f"'tif', 'gpkg', 'split'.")
if not aoi_dict['gpkg']:
if aoi_dict['gpkg'] is None:
logging.warning(f"No ground truth data found for {aoi_dict['tif']}.\n"
f"Only imagery will be processed from now on")
if "aoi_id" not in aoi_dict.keys() or not aoi_dict['aoi_id']:
Expand Down
2 changes: 1 addition & 1 deletion inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def main(params: Union[DictConfig, dict]) -> None:
)

# Dataset params
bands_requested = get_key_def('bands', params['dataset'], default=("red", "blue", "green"), expected_type=Sequence)
bands_requested = get_key_def('bands', params['dataset'], default=[1, 2, 3], expected_type=Sequence)
classes_dict = get_key_def('classes_dict', params['dataset'], expected_type=DictConfig)
num_classes = len(classes_dict)
num_classes = num_classes + 1 if num_classes > 1 else num_classes # multiclass account for background
Expand Down
11 changes: 9 additions & 2 deletions tests/dataset/test_aoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from shapely.geometry import box
from torchgeo.datasets.utils import extract_archive

from dataset.aoi import AOI
from dataset.aoi import AOI, aois_from_csv
from utils.utils import read_csv


Expand Down Expand Up @@ -92,14 +92,21 @@ def test_stac_url_input(self):
os.remove("data/SpaceNet_AOI_2_Las_Vegas-056155973080_01_P001-WV03-R.tif")

def test_missing_label(self):
"""Tests error when missing label file"""
"""Tests error when provided label file is missing"""
# Unpack the test archive, then read the AOI rows from the tiling CSV.
extract_archive(src="tests/data/spacenet.zip")
data = read_csv("tests/tiling/tiling_segmentation_binary-multiband_ci.csv")
for row in data:
# Overwrite the label path read from the CSV with a different file name
# (presumably one that does not exist on disk -- verify against the test data).
row['gpkg'] = "missing_file.gpkg"
# pytest.raises fails the test unless AttributeError is raised inside its context.
with pytest.raises(AttributeError):
aoi = AOI(raster=row['tif'], label=row['gpkg'], split=row['split'])
aoi.close_raster()

def test_no_label(self):
    """Check that an AOI built from a CSV without a label has ``label`` set to None.

    Extracts ``tests/data/new_brunswick_aerial.zip``, builds the AOI list from
    the no-label inference CSV with bands 1, 2 and 3 requested, and asserts
    that the first AOI's ``label`` attribute is None.
    """
    extract_archive(src="tests/data/new_brunswick_aerial.zip")
    aois = aois_from_csv(
        csv_path="tests/inference/inference_segmentation_multiclass_no_label.csv",
        bands_requested=[1, 2, 3],
    )
    first_aoi = aois[0]
    assert first_aoi.label is None

def test_parse_input_raster(self) -> None:
"""Tests parsing for three accepted patterns to reference input raster data with band selection"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
tests/data/new_brunswick_aerial/23322E759967N_clipped_1m_inference.tif,,inference
10 changes: 7 additions & 3 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,13 +320,17 @@ def read_csv(csv_file_name: str) -> Dict:
row_lengths_set.update([len(row)])
if not len(row_lengths_set) == 1:
raise ValueError(f"Rows in csv should be of same length. Got rows with length: {row_lengths_set}")
row = [str(i) or None for i in row] # replace empty strings with None.
row.extend([None] * (4 - len(row))) # pad the row with None values to obtain a row of length == 4

row[0] = to_absolute_path(row[0]) if not is_url(row[0]) else row[0] # Convert relative paths to absolute with hydra's util to_absolute_path()
row[1] = to_absolute_path(row[1]) if not is_url(row[1]) else row[1]

try:
row[1] = str(to_absolute_path(row[1]) if not is_url(row[1]) else row[1])
except TypeError:
row[1] = None
# save all values
list_values.append(
{'tif': str(row[0]), 'gpkg': str(row[1]), 'split': row[2], 'aoi_id': row[3]})
{'tif': str(row[0]), 'gpkg': row[1], 'split': row[2], 'aoi_id': row[3]})
try:
# Try sorting according to dataset name (i.e. group "train", "val" and "test" rows together)
list_values = sorted(list_values, key=lambda k: k['split'])
Expand Down

0 comments on commit d46abbb

Please sign in to comment.