From 9f77580864635f0d06741a53e44a3689f505ff16 Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Thu, 4 Jan 2024 09:24:59 -0500 Subject: [PATCH] fix types when using cv2 (#211) * fix types when using cv2 fixes #210 * sort imports --- wsinfer/patchlib/patch.py | 16 +++++++++++----- wsinfer/patchlib/segment.py | 4 ++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/wsinfer/patchlib/patch.py b/wsinfer/patchlib/patch.py index 3de742b..dcacee2 100644 --- a/wsinfer/patchlib/patch.py +++ b/wsinfer/patchlib/patch.py @@ -2,7 +2,12 @@ import itertools import logging +import sys +from contextlib import contextmanager +from typing import TYPE_CHECKING +from typing import Iterator from typing import Sequence +from typing import cast as type_cast import cv2 as cv import numpy as np @@ -11,9 +16,6 @@ from shapely import Point from shapely import Polygon from shapely import STRtree -from contextlib import contextmanager -import sys -from typing import Iterator logger = logging.getLogger(__name__) @@ -52,7 +54,7 @@ def get_multipolygon_from_binary_arr( """ # Find contours and hierarchy contours: Sequence[npt.NDArray] - hierarchy: npt.NDArray[np.int_] + hierarchy: npt.NDArray contours, hierarchy = cv.findContours(arr, cv.RETR_CCOMP, cv.CHAIN_APPROX_SIMPLE) hierarchy = hierarchy.squeeze(0) @@ -64,7 +66,7 @@ def get_multipolygon_from_binary_arr( f" by {scale}" ) # Reshape to broadcast with contour coordinates. - scale_arr: npt.NDArray[np.float_] = np.array(scale).reshape(1, 1, 2) + scale_arr: npt.NDArray = np.array(scale).reshape(1, 1, 2) contours = tuple(c * scale_arr for c in contours_unscaled) del scale_arr @@ -114,6 +116,10 @@ def merge_polygons(polygon: MultiPolygon, idx: int, add: bool) -> MultiPolygon: # Call the function with an initial empty polygon and start from contour 0 polygon = merge_polygons(MultiPolygon(), 0, True) + if TYPE_CHECKING: + hierarchy = type_cast(npt.NDArray[np.int_], hierarchy) + contours_unscaled = type_cast(Sequence[npt.NDArray[np.int_]], contours_unscaled) + # Add back the axis in hierarchy because we squeezed it before. return polygon, contours_unscaled, hierarchy[np.newaxis] diff --git a/wsinfer/patchlib/segment.py b/wsinfer/patchlib/segment.py index 1542d5f..a6963fd 100644 --- a/wsinfer/patchlib/segment.py +++ b/wsinfer/patchlib/segment.py @@ -11,7 +11,7 @@ def segment_tissue( - im_arr: npt.NDArray[np.uint8], + im_arr: npt.NDArray, median_filter_size: int = 7, binary_threshold: int = 7, closing_kernel_size: int = 6, @@ -69,7 +69,7 @@ def segment_tissue( # Convert to boolean dtype. This helps with static type analysis because at this # point, im_arr is a uint8 array. - im_arr_binary = im_arr > 0 + im_arr_binary: npt.NDArray[np.bool_] = im_arr > 0 # type: ignore # Closing. This removes small holes. It might not be entirely necessary because # we have hole removal below.