Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: owlv2 extra bounding box filtering #85

Merged
merged 5 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions tests/models/test_owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,27 @@ def test_successful_image_detection():
assert pred.label == "a photo of a cat"


def test_successful_removing_extra_bbox():
    """Redundant-box filtering keeps only "egg" detections and caps the count."""
    image_name = "eggs-food-easter-food-drink-44c10e-1024.jpg"
    image = Image.open(f"tests/shared_data/images/{image_name}")

    model = Owlv2()
    results = model(prompts=["egg"], image=image)

    detections = results[0]
    assert len(detections) > 0
    # Every surviving box must carry the prompted label.
    assert all(detection.label == "egg" for detection in detections)
    # The unfiltered model emits more boxes than this on the fixture image;
    # the extra-bbox filter must keep the count at or below this ceiling.
    assert len(detections) <= 42


def test_successful_video_detection():
test_video = "test_video_5_frames.mp4"
file_path = f"tests/shared_data/videos/{test_video}"
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
63 changes: 53 additions & 10 deletions vision_agent_tools/models/owlv2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from typing import List, Tuple, Union
from pydantic import BaseModel, Field
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from transformers.utils import TensorType
from transformers.image_transforms import center_to_corners_format
from transformers.models.owlv2.image_processing_owlv2 import box_iou
from vision_agent_tools.shared_types import BaseMLModel, Device, VideoNumpy, BboxLabel
from transformers.utils import TensorType

from vision_agent_tools.models.utils import filter_redundant_boxes
from vision_agent_tools.shared_types import BaseMLModel, BboxLabel, Device, VideoNumpy


class OWLV2Config(BaseModel):
Expand Down Expand Up @@ -54,6 +56,37 @@ class Owlv2(BaseMLModel):
and bounding boxes for detected objects with confidence exceeding a threshold.
"""

from typing import Dict, List

def _filter_bboxes(self, bboxlabels: List[BboxLabel]) -> List[BboxLabel]:
    """Drop redundant detections that fully contain several smaller same-label boxes.

    Delegates the geometric decision to ``filter_redundant_boxes`` and then maps
    its surviving ``(bbox, label)`` pairs back onto the original ``BboxLabel``
    objects, preserving the input order and duplicate multiplicity (a pair that
    survives k times keeps the first k matching objects).

    Parameters:
        bboxlabels (List[BboxLabel]): Detections to be filtered.

    Returns:
        List[BboxLabel]: Filtered detections, in their original order.
    """
    filtered = filter_redundant_boxes(
        {
            "bboxes": [bl.bbox for bl in bboxlabels],
            "labels": [bl.label for bl in bboxlabels],
        }
    )

    # Multiset of surviving (bbox, label) pairs. tuple(bbox) makes the
    # (possibly list-valued) bbox hashable; tuple equality matches the
    # element-wise list equality the previous O(n^2) membership scan used.
    remaining: dict = {}
    for bbox, label in zip(filtered["bboxes"], filtered["labels"]):
        key = (tuple(bbox), label)
        remaining[key] = remaining.get(key, 0) + 1

    # Single O(n) pass over the originals, consuming one multiset slot per
    # match so duplicates are handled correctly.
    output_bboxlabels: List[BboxLabel] = []
    for bl in bboxlabels:
        key = (tuple(bl.bbox), bl.label)
        if remaining.get(key, 0) > 0:
            output_bboxlabels.append(bl)
            remaining[key] -= 1

    return output_bboxlabels

def __run_inference(
self, image, texts, confidence, nms_threshold
) -> list[BboxLabel]:
Expand Down Expand Up @@ -91,7 +124,9 @@ def __run_inference(
BboxLabel(label=texts[i][label.item()], score=score.item(), bbox=box)
)

return inferences
filtered_inferences = self._filter_bboxes(inferences)

return filtered_inferences

def __init__(self, model_config: Optional[OWLV2Config] = None):
"""
Expand Down Expand Up @@ -247,10 +282,18 @@ def post_process_object_detection_with_nms(
boxes = boxes * scale_fct[:, None, :]

results = []
for s, l, b in zip(scores, labels, boxes):
score = s[s > threshold]
label = l[s > threshold]
box = b[s > threshold]
results.append({"scores": score, "labels": label, "boxes": box})
for score_array, label_array, box_array in zip(scores, labels, boxes):
high_score_mask = score_array > threshold
filtered_scores = score_array[high_score_mask]
filtered_labels = label_array[high_score_mask]
filtered_boxes = box_array[high_score_mask]

results.append(
{
"scores": filtered_scores,
"labels": filtered_labels,
"boxes": filtered_boxes,
}
)

return results