-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
553 feature refactor segmentationdataset class #554
Changes from 7 commits
8cd6e09
6030424
e90df4d
87e72ec
d64e98e
c902521
8077bd0
ba19e3a
1403896
0ef8ed1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -158,19 +158,21 @@ def segmentation(param, | |
|
||
""" | ||
sample = {"image": None, "mask": None, 'metadata': None} | ||
start_seg = time.time() | ||
# start_seg = time.time() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to keep this? (same for the end time) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will clean it out |
||
print_log = True if logging.level == 20 else False # 20 is INFO | ||
model.eval() # switch to evaluate mode | ||
|
||
# initialize test time augmentation | ||
transforms = tta.aliases.d4_transform() | ||
# transforms = tta.aliases.d4_transform() | ||
transforms = tta.Compose([]) | ||
tf_len = len(transforms) | ||
h_padded, w_padded = input_image.height + chunk_size, input_image.width + chunk_size | ||
patch_list = generate_patch_list(w_padded, h_padded, chunk_size, use_hanning) | ||
|
||
fp = np.memmap(tp_mem, dtype='float16', mode='w+', shape=(tf_len, h_padded, w_padded, num_classes)) | ||
img_gen = gen_img_samples(src=input_image, patch_list=patch_list, chunk_size=chunk_size) | ||
single_class_mode = False if num_classes > 1 else True | ||
start_time = time.time() | ||
for sub_image, h_idxs, w_idxs, hann_win in tqdm( | ||
img_gen, position=0, leave=True, desc='Inferring on patches', | ||
total=len(patch_list) | ||
|
@@ -231,8 +233,10 @@ def segmentation(param, | |
|
||
pred_heatmap[row:row + chunk_size, col:col + chunk_size, :] = arr1.astype(heatmap_dtype) | ||
|
||
end_seg = time.time() - start_seg | ||
logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60)) | ||
# end_seg = time.time() - start_seg | ||
end_time = time.time() - start_time | ||
# logging.info('Segmentation operation completed in {:.0f}m {:.0f}s'.format(end_seg // 60, end_seg % 60)) | ||
logging.info('Segmentation Completed in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60)) | ||
|
||
if debug: | ||
logging.debug(f'Bin count of final output: {np.unique(pred_heatmap, return_counts=True)}') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tests/data/tiles/tiled_image_1.tif;tests/data/tiles/tiled_label_1.tif | ||
tests/data/tiles/tiled_image_2.tif;tests/data/tiles/tiled_label_2.tif |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,53 @@ | ||
from typing import List | ||
from tempfile import NamedTemporaryFile | ||
from typing import List | ||
|
||
import pytest | ||
import rasterio | ||
from rasterio.io import DatasetReader | ||
from rasterio.crs import CRS | ||
from torchgeo.datasets.utils import extract_archive | ||
from torchgeo.datasets.utils import BoundingBox | ||
from osgeo import ogr | ||
from _pytest.fixtures import SubRequest | ||
import torch | ||
|
||
from dataset.create_dataset import DRDataset, GDLVectorDataset | ||
|
||
from _pytest.fixtures import SubRequest | ||
from osgeo import ogr | ||
from rasterio.crs import CRS | ||
from rasterio.io import DatasetReader | ||
from torchgeo.datasets.utils import BoundingBox, extract_archive | ||
|
||
from dataset.create_dataset import (DRDataset, GDLVectorDataset, | ||
SegmentationDataset) | ||
|
||
|
||
class TestSegmentationDataset: | ||
@pytest.fixture | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||
def data(self): | ||
dataset_list_path = 'tests/data/tiles/tiles.csv' | ||
num_bands = 3 | ||
dataset = SegmentationDataset(dataset_list_path, num_bands) | ||
return dataset | ||
|
||
def test_len(self, data): | ||
expected_length = 2 | ||
assert len(data) == expected_length | ||
|
||
def test_getitem(self, data): | ||
sample = data[0] | ||
assert "image" in sample | ||
assert 'mask' in sample | ||
assert 'metadata' in sample | ||
assert 'list_path' in sample | ||
|
||
def test_load_data(self, data): | ||
# Test that _load_data returns the expected number of assets | ||
assets = data._load_data() | ||
assert len(assets) == len(data) | ||
|
||
def test_load_image(self, data): | ||
# Test that _load_image returns an image and metadata | ||
image, metadata = data._load_image(0) | ||
assert image is not None | ||
assert metadata is not None | ||
|
||
def test_load_label(self, data): | ||
# Test that _load_label returns a label | ||
label = data._load_label(0) | ||
assert label is not None | ||
|
||
class TestDRDataset: | ||
@pytest.fixture(params=["tests/data/massachusetts_buildings_kaggle/22978945_15_uint8_clipped.tif", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is clean, thanks!