Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

553 feature refactor segmentationdataset class #554

Merged
152 changes: 90 additions & 62 deletions dataset/create_dataset.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is clean, thanks!

Original file line number Diff line number Diff line change
@@ -1,119 +1,147 @@
import numpy as np
from pathlib import Path
from typing import Any, Dict, cast
import sys
from pathlib import Path
from typing import Any, Dict, List, cast

from rasterio.windows import from_bounds
import kornia as K
import numpy as np
import pandas as pd
import rasterio
import torch
from affine import Affine
from osgeo import ogr
# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__()
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from rasterio.plot import reshape_as_image
from rasterio.vrt import WarpedVRT
from rasterio.windows import from_bounds
from torch.utils.data import Dataset
from torchgeo.datasets import GeoDataset
from rasterio.vrt import WarpedVRT
from torchgeo.datasets.utils import BoundingBox
import torch
from osgeo import ogr

from utils.logger import get_logger

# These two import statements prevent exception when using eval(metadata) in SegmentationDataset()'s __init__()
from rasterio.crs import CRS
from affine import Affine

# Set the logging file
logging = get_logger(__name__) # import logging


def append_to_dataset(dataset, sample):
"""
Append a new sample to a provided dataset. The dataset has to be expanded before we can add value to it.
:param dataset:
:param sample: data to append
:return: Index of the newly added sample.
"""
old_size = dataset.shape[0] # this function always appends samples on the first axis
dataset.resize(old_size + 1, axis=0)
dataset[old_size, ...] = sample
return old_size


class SegmentationDataset(Dataset):
"""Semantic segmentation dataset based on input csvs listing pairs of imagery and ground truth patches as .tif."""

def __init__(self,
dataset_list_path,
dataset_type,
num_bands,
dontcare=None,
max_sample_count=None,
radiom_transform=None,
geom_transform=None,
totensor_transform=None,
debug=False):
# note: if 'max_sample_count' is None, then it will be read from the dataset at runtime
self.max_sample_count = max_sample_count
self.dataset_type = dataset_type
self.num_bands = num_bands
self.radiom_transform = radiom_transform
self.geom_transform = geom_transform
self.totensor_transform = totensor_transform
self.debug = debug
self.dontcare = dontcare
self.list_path = dataset_list_path

if not Path(self.list_path).is_file():
logging.error(f"Couldn't locate dataset list file: {self.list_path}.\n"
f"If purposely omitting test set, this error can be ignored")
self.max_sample_count = 0
else:
with open(self.list_path, 'r') as datafile:
datalist = datafile.readlines()
if self.max_sample_count is None:
self.max_sample_count = len(datalist)

self.assets = self._load_data()

def __len__(self):
return self.max_sample_count

return len(self.assets)
def __getitem__(self, index):
with open(self.list_path, 'r') as datafile:
datalist = datafile.readlines()
data_line = datalist[index]
with rasterio.open(data_line.split(';')[0], 'r') as sat_handle:
sat_img = reshape_as_image(sat_handle.read())
metadata = sat_handle.meta
with rasterio.open(data_line.split(';')[1].rstrip('\n'), 'r') as label_handle:
map_img = reshape_as_image(label_handle.read())
map_img = map_img[..., 0]

assert self.num_bands <= sat_img.shape[-1]

if isinstance(metadata, np.ndarray) and len(metadata) == 1:
metadata = metadata[0]
elif isinstance(metadata, bytes):
metadata = metadata.decode('UTF-8')
try:
metadata = eval(metadata)
except TypeError:
pass


sat_img, metadata = self._load_image(index)
map_img = self._load_label(index)

if isinstance(metadata, np.ndarray) and len(metadata) == 1:
metadata = metadata[0]
elif isinstance(metadata, bytes):
metadata = metadata.decode('UTF-8')
try:
metadata = eval(metadata)
except TypeError:
pass

sample = {"image": sat_img, "mask": map_img, "metadata": metadata, "list_path": self.list_path}

if self.radiom_transform: # radiometric transforms should always precede geometric ones
# radiometric transforms should always precede geometric ones
if self.radiom_transform:
sample = self.radiom_transform(sample)
if self.geom_transform: # rotation, geometric scaling, flip and crop. Will also put channels first and convert to torch tensor from numpy.
# rotation, geometric scaling, flip and crop.
# Will also put channels first and convert to torch tensor from numpy.
if self.geom_transform:
sample = self.geom_transform(sample)

sample = self.totensor_transform(sample)
if self.totensor_transform:
sample = self.totensor_transform(sample)

if self.debug:
# assert no new class values in map_img
initial_class_ids = set(np.unique(map_img))
final_class_ids = set(np.unique(sample["mask"].numpy()))
if self.dontcare is not None:
initial_class_ids.add(self.dontcare)
final_class_ids = set(np.unique(sample['mask'].numpy()))
if not final_class_ids.issubset(initial_class_ids):
logging.debug(f"WARNING: Class ids for label before and after augmentations don't match. "
f"Ignore if overwritting ignore_index in ToTensorTarget")
logging.warning(f"\nWARNING: Class values for label before and after augmentations don't match."
f"\nUnique values before: {initial_class_ids}"
f"\nUnique values after: {final_class_ids}"
f"\nIgnore if some augmentations have padded with dontcare value.")
sample['index'] = index

return sample

def _load_data(self) -> List[str]:
"""Load the filepaths to images and labels

Returns:
List[str]: a list of filepaths to train/test data
"""
df = pd.read_csv(self.list_path, sep=';', header=None, usecols=[i for i in range(2)])
assets = [{"image": x, "label": y} for x, y in zip(df[0], df[1])]

return assets

def _load_image(self, index: int):
""" Load image

Args:
index: poosition of image

Returns:
image array and metadata
"""
image_path = self.assets[index]["image"]
with rasterio.open(image_path, 'r') as image_handle:
image = reshape_as_image(image_handle.read())
metadata = image_handle.meta
assert self.num_bands <= image.shape[-1]

return image, metadata

def _load_label(self, index: int):
""" Load label

Args:
index: poosition of label

Returns:
label array and metadata
"""
label_path = self.assets[index]["label"]

with rasterio.open(label_path, 'r') as label_handle:
label = reshape_as_image(label_handle.read())
label = label[..., 0]

return label


class DRDataset(GeoDataset):
def __init__(self, dr_ds: DatasetReader) -> None:
Expand Down
12 changes: 8 additions & 4 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,21 @@ def segmentation(param,

"""
sample = {"image": None, "mask": None, 'metadata': None}
start_seg = time.time()
# start_seg = time.time()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep this? (same for the end time)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will clean it out

print_log = True if logging.level == 20 else False # 20 is INFO
model.eval() # switch to evaluate mode

# initialize test time augmentation
transforms = tta.aliases.d4_transform()
# transforms = tta.aliases.d4_transform()
transforms = tta.Compose([])
tf_len = len(transforms)
h_padded, w_padded = input_image.height + chunk_size, input_image.width + chunk_size
patch_list = generate_patch_list(w_padded, h_padded, chunk_size, use_hanning)

fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes))
img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size)
single_class_mode = False if num_classes > 1 else True
start_time = time.time()
for sub_image, h_idxs, w_idxs, hann_win in tqdm(
img_gen, position=0, leave=True, desc='Inferring on patches',
total=len(patch_list)
Expand Down Expand Up @@ -231,8 +233,10 @@ def segmentation(param,

pred_heatmap[row:row + chunk_size, col:col + chunk_size, :] = arr1.astype(heatmap_dtype)

end_seg = time.time() - start_seg
logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60))
# end_seg = time.time() - start_seg
end_time = time.time() - start_time
# logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60))
logging.info('Segmentation Completed in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))

if debug:
logging.debug(f'Bin count of final output: {np.unique(pred_heatmap, return_counts=True)}')
Expand Down
Binary file added tests/data/tiles/tiled_image_1.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_image_2.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_label_1.tif
Binary file not shown.
Binary file added tests/data/tiles/tiled_label_2.tif
Binary file not shown.
2 changes: 2 additions & 0 deletions tests/data/tiles/tiles.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tests/data/tiles/tiled_image_1.tif;tests/data/tiles/tiled_label_1.tif
tests/data/tiles/tiled_image_2.tif;tests/data/tiles/tiled_label_2.tif
55 changes: 45 additions & 10 deletions tests/dataset/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,53 @@
from typing import List
from tempfile import NamedTemporaryFile
from typing import List

import pytest
import rasterio
from rasterio.io import DatasetReader
from rasterio.crs import CRS
from torchgeo.datasets.utils import extract_archive
from torchgeo.datasets.utils import BoundingBox
from osgeo import ogr
from _pytest.fixtures import SubRequest
import torch

from dataset.create_dataset import DRDataset, GDLVectorDataset

from _pytest.fixtures import SubRequest
from osgeo import ogr
from rasterio.crs import CRS
from rasterio.io import DatasetReader
from torchgeo.datasets.utils import BoundingBox, extract_archive

from dataset.create_dataset import (DRDataset, GDLVectorDataset,
SegmentationDataset)


class TestSegmentationDataset:
@pytest.fixture
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

def data(self):
dataset_list_path = 'tests/data/tiles/tiles.csv'
num_bands = 3
dataset = SegmentationDataset(dataset_list_path, num_bands)
return dataset

def test_len(self, data):
expected_length = 2
assert len(data) == expected_length

def test_getitem(self, data):
sample = data[0]
assert "image" in sample
assert 'mask' in sample
assert 'metadata' in sample
assert 'list_path' in sample

def test_load_data(self, data):
# Test that _load_data returns the expected number of assets
assets = data._load_data()
assert len(assets) == len(data)

def test_load_image(self, data):
# Test that _load_image returns an image and metadata
image, metadata = data._load_image(0)
assert image is not None
assert metadata is not None

def test_load_label(self, data):
# Test that _load_label returns a label
label = data._load_label(0)
assert label is not None

class TestDRDataset:
@pytest.fixture(params=["tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif",
Expand Down
2 changes: 1 addition & 1 deletion train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def create_dataloader(patches_folder: Path,
# TODO: should user point to the paths of these csvs directly?
dataset_file, _ = Tiler.make_dataset_file_name(experiment_name, min_annot_perc, subset, attr_vals)
dataset_filepath = patches_folder / dataset_file
datasets.append(dataset_constr(dataset_filepath, subset, num_bands,
datasets.append(dataset_constr(dataset_filepath, num_bands,
max_sample_count=num_patches[subset],
radiom_transform=aug.compose_transforms(params=cfg,
dataset=subset,
Expand Down
Loading