Skip to content

Commit

Permalink
remove usage of skimage in morphology
Browse files Browse the repository at this point in the history
  • Loading branch information
Matic Lubej committed Aug 30, 2023
1 parent a1cb239 commit 7b9d82a
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions eolearn/geometry/morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@

import itertools as it
from enum import Enum
from typing import Callable
from functools import partial

import cv2
import numpy as np
import skimage.filters.rank
import skimage.morphology

from eolearn.core import EOPatch, EOTask, MapFeatureTask
from eolearn.core.types import FeaturesSpecification, SingleFeatureSpec
Expand Down Expand Up @@ -44,7 +43,7 @@ def __init__(
parsed_mask_feature = parse_renamed_feature(mask_feature, allowed_feature_types=lambda fty: fty.is_array())

self.mask_type, self.mask_name, self.new_mask_name = parsed_mask_feature
self.disk = skimage.morphology.disk(disk_radius)
self.disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (disk_radius, disk_radius))
self.erode_labels = erode_labels
self.no_data_label = no_data_label

Expand All @@ -57,7 +56,7 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
erode_labels = set(erode_labels) - {self.no_data_label}
other_labels = set(all_labels) - set(erode_labels) - {self.no_data_label}

eroded_masks = [skimage.morphology.binary_erosion(feature_array == label, self.disk) for label in erode_labels]
eroded_masks = [cv2.erode((feature_array == label).astype(np.uint8), self.disk) for label in erode_labels]
other_masks = [feature_array == label for label in other_labels]

merged_mask = np.logical_or.reduce(eroded_masks + other_masks, axis=0)
Expand All @@ -75,20 +74,18 @@ class MorphologicalOperations(Enum):
CLOSING = "closing"
DILATION = "dilation"
EROSION = "erosion"
MEDIAN = "median"

@classmethod
def get_operation(cls, morph_type: MorphologicalOperations) -> Callable:
def get_operation(cls, morph_type: MorphologicalOperations) -> int:
"""Maps morphological operation type to function
:param morph_type: Morphological operation type
"""
return {
cls.OPENING: skimage.morphology.opening,
cls.CLOSING: skimage.morphology.closing,
cls.DILATION: skimage.morphology.dilation,
cls.EROSION: skimage.morphology.erosion,
cls.MEDIAN: skimage.filters.rank.median,
cls.OPENING: cv2.MORPH_OPEN,
cls.CLOSING: cv2.MORPH_CLOSE,
cls.DILATION: cv2.MORPH_DILATE,
cls.EROSION: cv2.MORPH_ERODE,
}[morph_type]


Expand All @@ -103,15 +100,7 @@ def get_disk(radius: int) -> np.ndarray:
:param radius: Radius of disk
:return: The structuring element where elements of the neighborhood are 1 and 0 otherwise.
"""
return skimage.morphology.disk(radius)

@staticmethod
def get_diamond(radius: int) -> np.ndarray:
"""
:param radius: Radius of diamond
:return: The structuring element where elements of the neighborhood are 1 and 0 otherwise.
"""
return skimage.morphology.diamond(radius)
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius, radius))

@staticmethod
def get_rectangle(width: int, height: int) -> np.ndarray:
Expand All @@ -120,15 +109,15 @@ def get_rectangle(width: int, height: int) -> np.ndarray:
:param height: Height of rectangle
:return: A structuring element consisting only of ones, i.e. every pixel belongs to the neighborhood.
"""
return skimage.morphology.rectangle(width, height)
return cv2.getStructuringElement(cv2.MORPH_RECT, (height, width))

@staticmethod
def get_square(width: int) -> np.ndarray:
"""
:param width: Size of square
:return: A structuring element consisting only of ones, i.e. every pixel belongs to the neighborhood.
"""
return skimage.morphology.square(width)
return cv2.getStructuringElement(cv2.MORPH_RECT, (width, width))


class MorphologicalFilterTask(MapFeatureTask):
Expand All @@ -139,7 +128,7 @@ def __init__(
input_features: FeaturesSpecification,
output_features: FeaturesSpecification | None = None,
*,
morph_operation: MorphologicalOperations | Callable,
morph_operation: MorphologicalOperations,
struct_elem: np.ndarray | None = None,
):
"""
Expand All @@ -153,21 +142,19 @@ def __init__(
output_features = input_features
super().__init__(input_features, output_features)

if isinstance(morph_operation, MorphologicalOperations):
self.morph_operation = MorphologicalOperations.get_operation(morph_operation)
else:
self.morph_operation = morph_operation
self.morph_operation = MorphologicalOperations.get_operation(morph_operation)
self.struct_elem = struct_elem

def map_method(self, feature: np.ndarray) -> np.ndarray:
"""Applies the morphological operation to a raster feature."""
feature = feature.copy()
morph_func = partial(cv2.morphologyEx, kernel=self.struct_elem, op=self.morph_operation)
if feature.ndim == 3:
for channel in range(feature.shape[2]):
feature[..., channel] = self.morph_operation(feature[..., channel], self.struct_elem)
feature[..., channel] = morph_func(feature[..., channel])
elif feature.ndim == 4:
for time, channel in it.product(range(feature.shape[0]), range(feature.shape[3])):
feature[time, ..., channel] = self.morph_operation(feature[time, ..., channel], self.struct_elem)
feature[time, ..., channel] = morph_func(feature[time, ..., channel])
else:
raise ValueError(f"Invalid number of dimensions: {feature.ndim}")

Expand Down

0 comments on commit 7b9d82a

Please sign in to comment.