Skip to content

Commit

Permalink
add morphological filter regression tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
zigaLuksic committed Aug 30, 2023
1 parent bd6019a commit a1cb239
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
4 changes: 2 additions & 2 deletions eolearn/core/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ..constants import FeatureType
from ..eodata import EOPatch
from ..types import Feature
from ..types import FeaturesSpecification
from ..utils.parsing import FeatureParser

DEFAULT_BBOX = BBox((0, 0, 100, 100), crs=CRS("EPSG:32633"))
Expand All @@ -46,7 +46,7 @@ def __post_init__(self) -> None:


def generate_eopatch(
features: list[Feature] | None = None,
features: FeaturesSpecification | None = None,
bbox: BBox = DEFAULT_BBOX,
timestamps: list[dt.datetime] | None = None,
seed: int = 42,
Expand Down
55 changes: 43 additions & 12 deletions tests/geometry/test_morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
import pytest
from numpy.testing import assert_array_equal

from sentinelhub import CRS, BBox

from eolearn.core import EOPatch, FeatureType
from eolearn.core.utils.testing import PatchGeneratorConfig, generate_eopatch
from eolearn.geometry import ErosionTask, MorphologicalFilterTask, MorphologicalOperations, MorphologicalStructFactory

CLASSES = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
Expand All @@ -21,6 +20,12 @@
# ruff: noqa: NPY002


@pytest.fixture(name="patch")
def patch_fixture() -> EOPatch:
config = PatchGeneratorConfig(max_integer_value=10, raster_shape=(50, 100), depth_range=(3, 4))
return generate_eopatch([MASK_FEATURE, MASK_TIMELESS_FEATURE], config=config)


@pytest.mark.parametrize("invalid_input", [None, 0, "a"])
def test_erosion_value_error(invalid_input):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -49,19 +54,45 @@ def test_erosion_partial(test_eopatch):
assert_array_equal(np.unique(mask_after, return_counts=True)[1], [1145, 7601, 1069, 87, 198])


@pytest.mark.parametrize("morph_operation", MorphologicalOperations)
@pytest.mark.parametrize(
"struct_element", [None, MorphologicalStructFactory.get_disk(5), MorphologicalStructFactory.get_rectangle(5, 6)]
("morph_operation", "struct_element", "mask_counts", "mask_timeless_counts"),
[
(
MorphologicalOperations.DILATION,
None,
[1, 30, 161, 669, 1690, 3557, 6973, 12247, 19462, 30210],
[7, 29, 112, 304, 639, 1336, 2465, 4012, 6096],
),
(MorphologicalOperations.EROSION, MorphologicalStructFactory.get_disk(5), [74925, 72, 3], [14989, 11]),
(MorphologicalOperations.OPENING, MorphologicalStructFactory.get_disk(5), [73137, 1800, 63], [14720, 280]),
(MorphologicalOperations.CLOSING, MorphologicalStructFactory.get_disk(5), [1157, 73843], [501, 14499]),
(
MorphologicalOperations.MEDIAN,
MorphologicalStructFactory.get_rectangle(5, 6),
[16, 562, 6907, 24516, 28864, 12690, 1403, 42],
[71, 1280, 4733, 5924, 2592, 382, 18],
),
(
MorphologicalOperations.OPENING,
MorphologicalStructFactory.get_rectangle(5, 6),
[47486, 24132, 2565, 497, 169, 96, 35, 20],
[9929, 4446, 494, 53, 54, 16, 8],
),
(
MorphologicalOperations.DILATION,
MorphologicalStructFactory.get_rectangle(5, 6),
[2, 20, 184, 3888, 70906],
[3, 32, 748, 14217],
),
],
)
def test_morphological_filter(morph_operation, struct_element):
eopatch = EOPatch(bbox=BBox((0, 0, 1, 1), CRS(3857)), timestamps=["2015-7-7"] * 10)
eopatch[MASK_FEATURE] = np.random.randint(20, size=(10, 100, 100, 3), dtype=np.uint8)
eopatch[MASK_TIMELESS_FEATURE] = np.random.randint(20, 50, size=(100, 100, 5), dtype=np.uint8)

def test_morphological_filter(patch, morph_operation, struct_element, mask_counts, mask_timeless_counts):
task = MorphologicalFilterTask(
[MASK_FEATURE, MASK_TIMELESS_FEATURE], morph_operation=morph_operation, struct_elem=struct_element
)
task.execute(eopatch)
task.execute(patch)

assert eopatch[MASK_FEATURE].shape == (10, 100, 100, 3)
assert eopatch[MASK_TIMELESS_FEATURE].shape == (100, 100, 5)
assert patch[MASK_FEATURE].shape == (5, 50, 100, 3)
assert patch[MASK_TIMELESS_FEATURE].shape == (50, 100, 3)
assert_array_equal(np.unique(patch[MASK_FEATURE], return_counts=True)[1], mask_counts)
assert_array_equal(np.unique(patch[MASK_TIMELESS_FEATURE], return_counts=True)[1], mask_timeless_counts)

0 comments on commit a1cb239

Please sign in to comment.