Skip to content

Commit

Permalink
Fix cv2 (#1688)
Browse files Browse the repository at this point in the history
* fix cv2

Signed-off-by: tangy5 <[email protected]>

* fix cv2

Signed-off-by: tangy5 <[email protected]>

* fix cv2

Signed-off-by: tangy5 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* optinal import and fallback

Signed-off-by: tangy5 <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: tangy5 <[email protected]>
Signed-off-by: tangy5 <[email protected]>
Co-authored-by: tangy5 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 8, 2024
1 parent ea2d7c2 commit 4dd1813
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 63 deletions.
102 changes: 63 additions & 39 deletions monailabel/transform/post.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import logging
from typing import Dict, Hashable, Mapping, Optional, Sequence, Union

import cv2
import nibabel as nib
import numpy as np
import skimage.measure as measure
Expand All @@ -27,18 +26,16 @@
generate_spatial_bounding_box,
get_extreme_points,
)
from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep
from monai.utils import InterpolateMode, convert_to_numpy, ensure_tuple_rep, optional_import
from shapely.geometry import Point, Polygon
from skimage.measure import approximate_polygon, find_contours
from torchvision.utils import make_grid, save_image

from monailabel.utils.others.label_colors import get_color

logger = logging.getLogger(__name__)


# TODO:: Move to MONAI ??


class LargestCCd(MapTransform):
def __init__(self, keys: KeysCollection, has_channel: bool = True):
super().__init__(keys)
Expand Down Expand Up @@ -183,7 +180,6 @@ def __init__(
colormap=None,
):
super().__init__(keys)

self.min_positive = min_positive
self.min_poly_area = min_poly_area
self.max_poly_area = max_poly_area
Expand All @@ -208,9 +204,7 @@ def __call__(self, data):
min_poly_area = d.get("min_poly_area", self.min_poly_area)
max_poly_area = d.get("max_poly_area", self.max_poly_area)
color_map = d.get(self.key_label_colors) if self.colormap is None else self.colormap

foreground_points = d.get(self.key_foreground_points, []) if self.key_foreground_points else []
foreground_points = [Point(pt[0], pt[1]) for pt in foreground_points] # polygons in (x, y) format
foreground_points = [Point(pt) for pt in d.get(self.key_foreground_points, [])]

elements = []
label_names = set()
Expand All @@ -220,43 +214,73 @@ def __call__(self, data):
continue

labels = [label for label in np.unique(p).tolist() if label > 0]
logger.debug(f"Total Unique Masks (excluding background): {labels}")
for label_idx in labels:
p = convert_to_numpy(d[key]) if isinstance(d[key], torch.Tensor) else d[key]
p = np.where(p == label_idx, 1, 0).astype(np.uint8)
p = np.moveaxis(p, 0, 1) # for cv2
p = np.moveaxis(p, 0, 1)

if label_idx == 0:
continue
label_name = self.labels.get(label_idx, label_idx)
label_names.add(label_name)

polygons = []
contours, _ = cv2.findContours(p, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
if len(contour) < 3:
continue

contour = np.squeeze(contour)
area = cv2.contourArea(contour)
if area < min_poly_area: # Ignore poly with lesser area
continue
if 0 < max_poly_area < area: # Ignore very large poly (e.g. in case of nuclei)
continue

contour[:, 0] += location[0] # X
contour[:, 1] += location[1] # Y

coords = contour.astype(int).tolist()
if foreground_points:
for pt in foreground_points:
if Polygon(coords).contains(pt):
polygons.append(coords)
break
else:
polygons.append(coords)

if len(polygons):
logger.debug(f"+++++ {label_idx} => Total Polygons Found: {len(polygons)}")
elements.append({"label": label_name, "contours": polygons})
cv2, has_cv2 = optional_import("cv2")
if has_cv2:
polygons = []
contours, _ = cv2.findContours(p, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours:
if len(contour) < 3:
continue

contour = np.squeeze(contour)
area = cv2.contourArea(contour)
if area < min_poly_area:
continue
if 0 < max_poly_area < area:
continue

contour[:, 0] += location[0]
contour[:, 1] += location[1]

coords = contour.astype(int).tolist()
if foreground_points:
for pt in foreground_points:
if Polygon(coords).contains(pt):
polygons.append(coords)
break
else:
polygons.append(coords)

if len(polygons):
logger.debug(f"+++++ {label_idx} => Total Polygons Found: {len(polygons)}")
elements.append({"label": label_name, "contours": polygons})
else:
contours = find_contours(p, 0.5)
contours = [np.round(contour).astype(int) for contour in contours]
for contour in contours:
if not np.array_equal(contour[0], contour[-1]):
contour = np.append(contour, [contour[0]], axis=0)

simplified_contour = approximate_polygon(contour, tolerance=0.5)
if len(simplified_contour) < 4:
continue

simplified_contour = np.flip(simplified_contour, axis=1)
simplified_contour += location
simplified_contour = simplified_contour.astype(int)

polygon = Polygon(simplified_contour)
if (
polygon.is_valid
and polygon.area >= min_poly_area
and (max_poly_area <= 0 or polygon.area <= max_poly_area)
):
formatted_contour = [simplified_contour.tolist()]
if foreground_points:
if any(polygon.contains(point) for point in foreground_points):
elements.append({"label": label_name, "contours": formatted_contour})
else:
elements.append({"label": label_name, "contours": formatted_contour})

if elements:
if d.get(self.result) is None:
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ pydicom==2.4.4
pydicom-seg==0.4.1
pynetdicom==2.0.2
pynrrd==1.0.0
opencv-python-headless==4.9.0.80
numpymaxflow==0.0.6
girder-client==3.2.3
ninja==1.11.1.1
Expand Down
18 changes: 14 additions & 4 deletions sample-apps/pathology/lib/trainers/hovernet_nuclei.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import pathlib
from typing import Dict, Optional

import cv2
import numpy as np
from lib.hovernet import PatchExtractor
from lib.utils import split_dataset
from monai.utils import optional_import
from PIL import Image
from scipy.ndimage import label
from tqdm import tqdm

from monailabel.interfaces.datastore import Datastore
Expand All @@ -36,7 +37,11 @@ def __init__(self, path: str, conf: Dict[str, str], const: Optional[BundleConsta
self.step_size = (164, 164)
self.extract_type = "mirror"

def _fetch_datalist(self, request, datastore: Datastore):
def remove_file(path):
if os.path.exists(path):
os.remove(path)

def _fetch_datalist(self, request, datastore):
cache_dir = os.path.join(self.bundle_path, "cache", "train_ds")
remove_file(cache_dir)

Expand Down Expand Up @@ -71,13 +76,18 @@ def _fetch_datalist(self, request, datastore: Datastore):
img = np.array(Image.open(d["image"]).convert("RGB"))
ann_type = np.array(Image.open(d["label"]))

numLabels, ann_inst, _, _ = cv2.connectedComponentsWithStats(ann_type, 4, cv2.CV_32S)
cv2, has_cv2 = optional_import("cv2")
if has_cv2:
numLabels, ann_inst, _, _ = cv2.connectedComponentsWithStats(ann_type, 4, cv2.CV_32S)
else:
ann_inst, numLabels = label(ann_type)

ann = np.dstack([ann_inst, ann_type])

img = np.concatenate([img, ann], axis=-1)
sub_patches = xtractor.extract(img, self.extract_type)

pbar_format = "Extracting : |{bar}| {n_fmt}/{total_fmt}[{elapsed}<{remaining},{rate_fmt}]"
pbar_format = "Extracting: |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
pbar = tqdm(total=len(sub_patches), leave=False, bar_format=pbar_format, ascii=True, position=1)

for idx, patch in enumerate(sub_patches):
Expand Down
98 changes: 81 additions & 17 deletions sample-apps/pathology/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
from io import BytesIO
from math import ceil

import cv2
import numpy as np
import openslide
import scipy
from PIL import Image
from monai.utils import optional_import
from PIL import Image, ImageDraw
from scipy.ndimage import center_of_mass, find_objects, label
from tqdm import tqdm

from monailabel.datastore.dsa import DSADatastore
Expand Down Expand Up @@ -273,6 +274,7 @@ def split_consep_dataset(
crop_size=256,
):
dataset_json = []

# logger.debug(f"Process Image: {d['image']} => Label: {d['label']}")

images_dir = output_dir
Expand Down Expand Up @@ -477,7 +479,21 @@ def split_nuclei_dataset(
mask = Image.open(d["label"])
mask_np = np.array(mask)

numLabels, instances, stats, centroids = cv2.connectedComponentsWithStats(mask_np, 4, cv2.CV_32S)
cv2, has_cv2 = optional_import("cv2")
if has_cv2:
numLabels, instances, stats, centroids = cv2.connectedComponentsWithStats(mask_np, 4, cv2.CV_32S)
else:
numLabels, instances = label(mask_np)
stats = []
centroids = center_of_mass(mask_np, instances, range(numLabels))

objects = find_objects(instances)
for i, slice_tuple in enumerate(objects):
if slice_tuple is not None:
dx, dy = slice_tuple
area = (dx.stop - dx.start) * (dy.stop - dy.start)
stats.append([dy.start, dx.start, dy.stop - dy.start, dx.stop - dx.start, area])

logger.info("-------------------------------------------------------------------------------")
logger.info(f"Image/Label ========> {d['image']} =====> {d['label']}")
logger.info(f"Total Labels: {numLabels}")
Expand All @@ -486,11 +502,11 @@ def split_nuclei_dataset(
logger.info(f"Total Centroids: {len(centroids)}")
logger.info(f"Total Classes in Mask: {np.unique(mask_np)}")

for nuclei_id, (x, y) in enumerate(centroids):
for nuclei_id, centroid in enumerate(centroids):
if nuclei_id == 0:
continue

x, y = (int(x), int(y))
x, y = int(centroid[1]), int(centroid[0])

this_instance = np.where(instances == nuclei_id, mask_np, 0)
class_id = int(np.max(this_instance))
Expand Down Expand Up @@ -556,9 +572,36 @@ def _group_item(groups, d, output_dir):
return groups, item_id


def calculate_bounding_rect(points):
points = np.array(points, dtype=int)
x_min, y_min = np.min(points, axis=0)
x_max, y_max = np.max(points, axis=0)
w = x_max - x_min + 1
h = y_max - y_min + 1
return int(x_min), int(y_min), int(w), int(h)


def fill_poly(image_size, polygons, color, mode="L"):
if mode.upper() == "RGB":
img = Image.new("RGB", image_size, (0, 0, 0))
else:
img = Image.new("L", image_size, 0)

draw = ImageDraw.Draw(img)
for polygon in polygons:
draw.polygon([tuple(p) for p in polygon], fill=color)
return np.array(img)


def _to_roi(points, max_region, polygons, annotation_id):
logger.info(f"Total Points: {len(points)}")
x, y, w, h = cv2.boundingRect(np.array(points))

cv2, has_cv2 = optional_import("cv2")
if has_cv2:
x, y, w, h = cv2.boundingRect(np.array(points))
else:
x, y, w, h = calculate_bounding_rect(points)

logger.info(f"ID: {annotation_id} => Groups: {polygons.keys()}; Location: ({x}, {y}); Size: {w} x {h}")

if w > max_region[0]:
Expand All @@ -584,25 +627,46 @@ def _to_dataset(item_id, x, y, w, h, img, tile_size, polygons, groups, output_di
logger.debug(f"Image NP: {image_np.shape}; sum: {np.sum(image_np)}")
tiled_images = _region_to_tiles(name, w, h, image_np, tile_size, output_dir, "Image")

label_np = np.zeros((h, w), dtype=np.uint8) # Transposed
for group, contours in polygons.items():
color = groups.get(group, 1)
contours = [np.array([[p[0] - x, p[1] - y] for p in contour]) for contour in contours]
cv2, has_cv2 = optional_import("cv2")
if has_cv2:
label_np = np.zeros((h, w), dtype=np.uint8) # Transposed
for group, contours in polygons.items():
color = groups.get(group, 1)
contours = [np.array([[p[0] - x, p[1] - y] for p in contour]) for contour in contours]

cv2.fillPoly(label_np, pts=contours, color=color)
logger.info(f"{group} => p: {len(contours)}; c: {color}; unique: {np.unique(label_np, return_counts=True)}")
if debug:
regions_dir = os.path.join(output_dir, "regions")
label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
os.makedirs(os.path.dirname(label_path), exist_ok=True)
cv2.imwrite(label_path, label_np)
else:
label_img = Image.new("L", (w, h), 0)
draw = ImageDraw.Draw(label_img)

cv2.fillPoly(label_np, pts=contours, color=color)
logger.info(f"{group} => p: {len(contours)}; c: {color}; unique: {np.unique(label_np, return_counts=True)}")
for group, contours in polygons.items():
color = groups.get(group, 1)
pil_contours = [tuple((p[0] - x, p[1] - y) for p in contour) for contour in contours]

if debug:
regions_dir = os.path.join(output_dir, "regions")
label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
os.makedirs(os.path.dirname(label_path), exist_ok=True)
cv2.imwrite(label_path, label_np)
for contour in pil_contours:
draw.polygon(contour, outline=color, fill=color)

if debug:
regions_dir = os.path.join(output_dir, "regions")
label_path = os.path.realpath(os.path.join(regions_dir, "labels", group, f"{name}.png"))
os.makedirs(os.path.dirname(label_path), exist_ok=True)
label_img.save(label_path)

label_np = np.array(label_img)

tiled_labels = _region_to_tiles(
name, w, h, label_np, tile_size, os.path.join(output_dir, "labels", "final"), "Label"
)

for k in tiled_images:
dataset_json.append({"image": tiled_images[k], "label": tiled_labels[k]})

return dataset_json


Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ install_requires =
pydicom-seg==0.4.1
pynetdicom==2.0.2
pynrrd==1.0.0
opencv-python-headless==4.9.0.80
numpymaxflow==0.0.6
girder-client==3.2.3
ninja==1.11.1.1
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/transform/test_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
{
"pred": np.array([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, 0, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]]),
},
[[[1, 2], [2, 1], [3, 2], [2, 3]], [[1, 1], [1, 3], [3, 3], [3, 1]]],
[[[3, 4], [1, 4], [0, 3], [0, 1], [1, 0], [3, 0], [4, 1], [4, 3], [3, 4]]],
]

DUMPIMAGEPREDICTION2DD_DATA = [
Expand Down

0 comments on commit 4dd1813

Please sign in to comment.