Add Florence2SAM2 #24
Changes from 13 commits
File: tests/tools/test_florence2_sam2.py (new file)
@@ -0,0 +1,80 @@
import numpy as np
import pytest
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2


def test_successful_florence2_sam2_image():
    """
    This test verifies that CLIPMediaSim returns a valid response when passed a target_text
    """

Review comment: I think you should change CLIPMediaSim to Florence2SAM2
Reply: done. thanks.

    test_image = Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB")

    florence2_sam2 = Florence2SAM2()

    results = florence2_sam2(media=test_image, prompts=["tomato"])

    # The dictionary should have only one key: 0
    assert len(results) == 1
    # The dictionary should have 23 instances of the tomato class
    assert len(results[0]) == 23
    for instance in results[0].values():
        assert len(instance.bounding_box) == 4
        assert np.all(
            [
                0 <= coord <= np.max(test_image.size[:2])
                for coord in instance.bounding_box
            ]
        )
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == test_image.size[::-1]
        assert instance.label == "tomato"


def test_successful_florence2_sam2_video():
    """
    This test verifies that CLIPMediaSim returns a valid response when passed a target_text
    """

Review comment: Same here, please change CLIPMediaSim to Florence2SAM2

    tomatoes_image = np.array(
        Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB"), dtype=np.uint8
    )
    test_video = np.array(
        [tomatoes_image, np.zeros(tomatoes_image.shape, dtype=np.uint8)]
    )

    florence2_sam2 = Florence2SAM2()

    results = florence2_sam2(media=test_video, prompts=["tomato"])

    # The dictionary should have 2 keys for the two frames in the video
    assert len(results) == 2
    # The first frame should have 23 instances of the tomato class
    assert len(results[0]) == 23
    assert len(results[1]) == 23
    # First frame
    for instance in results[0].values():
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == tomatoes_image.shape[:2]
        assert instance.label == "tomato"

    # Second frame
    for instance in results[1].values():
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == tomatoes_image.shape[:2]
        assert instance.label == "tomato"
        # All masks should be empty since it's a black frame
        assert np.all(instance.mask == 0)


def test_florence2_sam2_invalid_media():
    """
    This test verifies that CLIPMediaSim raises a ValueError if the media is not a valid type.
    """

Review comment: Same here

    florence2_sam2 = Florence2SAM2()

    with pytest.raises(ValueError):
        florence2_sam2(media="invalid media", prompts=["tomato"])

    with pytest.raises(AssertionError):
        florence2_sam2(media=np.array([1, 2, 3]), prompts=["tomato"])
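An aside on the size[::-1] and shape[:2] comparisons in these tests: PIL and numpy use opposite axis orders, which is why the mask shape is checked against the reversed image size. A minimal standalone illustration (not part of the diff):

import numpy as np
from PIL import Image

img = Image.new("RGB", (640, 480))      # PIL size is (width, height)
arr = np.asarray(img)                   # numpy shape is (height, width, 3)
assert arr.shape[:2] == img.size[::-1]  # hence size[::-1] in the tests above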
File: vision_agent_tools/shared_types.py
@@ -1,10 +1,21 @@
from typing import Annotated, Literal, TypeVar

from pydantic import BaseModel
import numpy as np
import numpy.typing as npt


class BaseTool:
    pass


DType = TypeVar("DType", bound=np.generic)

VideoNumpy = Annotated[npt.NDArray[DType], Literal["N", "N", "N", 3]]

Review comment: @dillonalaird I think this value is already defined inside the file …
Reply: I merged them. I moved it to this file.
Reply: sorry, I thought I deleted the other file but I did not. I will delete the file, but I migrated all the imports to use this file.
Reply: I just pushed the changes

SegmentationBitMask = Annotated[npt.NDArray[np.bool_], Literal["N", "N"]]


class Point(BaseModel):
    # X coordinate of the point
    x: float
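For readers unfamiliar with shape-annotated numpy aliases, a minimal sketch of how VideoNumpy is meant to be used. The (frames, height, width, 3) layout is inferred from the video-handling code further down; the Annotated alias itself enforces nothing at runtime:

import numpy as np
from vision_agent_tools.shared_types import VideoNumpy

# A two-frame 480x640 RGB "video" following the documented
# (frames, height, width, 3) layout.
video: VideoNumpy = np.zeros((2, 480, 640, 3), dtype=np.uint8)

# The Literal shape hints are documentation only; explicit checks
# (e.g. media.ndim == 4 in Florence2SAM2.__call__) still happen in code.
assert video.ndim == 4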
File: vision_agent_tools/tools/florence2_sam2.py (new file)
@@ -0,0 +1,144 @@
from dataclasses import dataclass
from typing_extensions import Annotated

import torch
import numpy as np
from PIL import Image
from pydantic import validate_call

from vision_agent_tools.shared_types import BaseTool, VideoNumpy, SegmentationBitMask
from vision_agent_tools.tools.florencev2 import Florencev2, PromptTask

from sam2.sam2_video_predictor import SAM2VideoPredictor
from sam2.sam2_image_predictor import SAM2ImagePredictor


_HF_MODEL = "facebook/sam2-hiera-large"


@dataclass
class ImageBboxAndMaskLabel:
    label: str
    bounding_box: list[
        Annotated[float, "x_min"],
        Annotated[float, "y_min"],
        Annotated[float, "x_max"],
        Annotated[float, "y_max"],
    ]
    mask: SegmentationBitMask | None


@dataclass
class MaskLabel:
    label: str
    mask: SegmentationBitMask


class Florence2SAM2(BaseTool):
    def __init__(self):
        self.florence2 = Florencev2()
        self.video_predictor = SAM2VideoPredictor.from_pretrained(_HF_MODEL)
        self.image_predictor = SAM2ImagePredictor(self.video_predictor)

    @torch.inference_mode()
    def get_bbox_and_mask(
        self, image: Image.Image, prompts: list[str], return_mask: bool = True
    ) -> dict[int, ImageBboxAndMaskLabel]:
        objs = {}
        self.image_predictor.set_image(np.array(image, dtype=np.uint8))
        annotation_id = 0
        for prompt in prompts:
            with torch.autocast(device_type="cuda", dtype=torch.float16):

Review comment: @dillonalaird I think this would be better to define a …
Reply: thanks! done

                bboxes = self.florence2(
                    image, PromptTask.CAPTION_TO_PHRASE_GROUNDING, prompt
                )[PromptTask.CAPTION_TO_PHRASE_GROUNDING]["bboxes"]
            if return_mask:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):

Review comment: Same here

                    masks, _, _ = self.image_predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=bboxes,
                        multimask_output=False,
                    )
            for i in range(len(bboxes)):
                objs[annotation_id] = ImageBboxAndMaskLabel(
                    bounding_box=bboxes[i],
                    mask=(
                        masks[i, 0, :, :] if len(masks.shape) == 4 else masks[i, :, :]
                    )
                    if return_mask
                    else None,
                    label=prompt,
                )
                annotation_id += 1
        return objs

    @torch.inference_mode()
    def handle_image(
        self, image: Image.Image, prompts: list[str]
    ) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
        self.image_predictor.reset_predictor()
        objs = self.get_bbox_and_mask(image.convert("RGB"), prompts)
        return {0: objs}

    @torch.inference_mode()
    def handle_video(
        self, video: VideoNumpy, prompts: list[str]
    ) -> dict[int, dict[int, MaskLabel]]:
        self.image_predictor.reset_predictor()
        objs = self.get_bbox_and_mask(
            Image.fromarray(video[0]).convert("RGB"), prompts, return_mask=False
        )
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):

Review comment: Same here, change to …

            inference_state = self.video_predictor.init_state(video=video)
            for annotation_id in objs:
                _, _, out_mask_logits = self.video_predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=annotation_id,
                    box=objs[annotation_id].bounding_box,
                )

            annotation_id_to_label = {}
            for annotation_id in objs:
                annotation_id_to_label[annotation_id] = objs[annotation_id].label

            video_segments = {}
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in self.video_predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: MaskLabel(
                        mask=(out_mask_logits[i][0] > 0.0).cpu().numpy(),
                        label=annotation_id_to_label[out_obj_id],
                    )
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
            self.video_predictor.reset_state(inference_state)
        return video_segments

    @validate_call(config={"arbitrary_types_allowed": True})
    @torch.inference_mode()
    def __call__(
        self, media: Image.Image | VideoNumpy, prompts: list[str]
    ) -> dict[int, dict[int, ImageBboxAndMaskLabel | MaskLabel]]:

Review comment (on the media parameter): @dillonalaird @camiloaz should we handle the input value as either image or video as separate values as we do here? I think we should be consistent and change one of these two to have the same format.
Reply: yeah, agree. will do that.
Reply: done

        """Returns a dictionary where the first key is the frame index, the second key
        is an annotation ID, and the value holds the mask, label and possibly bbox
        (for images) for each annotation ID. For example:
        {
            0:
                {
                    0: {"mask": np.ndarray, "label": "car"},
                    1: {"mask": np.ndarray, "label": "person"}
                },
            1: ...
        }
        """
        if isinstance(media, Image.Image):
            return self.handle_image(media, prompts)
        elif isinstance(media, np.ndarray):
            assert media.ndim == 4, "Video should have 4 dimensions"
            return self.handle_video(media, prompts)
        # No need to raise an error here, the validate_call decorator will take care of it
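To make the documented return structure concrete, a minimal usage sketch grounded in the tests above (it assumes a CUDA-capable machine, since the predictors autocast to cuda, and that the model weights are downloadable; the image path is the one used in the tests):

import numpy as np
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2

florence2_sam2 = Florence2SAM2()

# Image input: a single pseudo-frame (key 0) holding every detected instance,
# each with a bounding box, mask and label.
image = Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB")
instances = florence2_sam2(media=image, prompts=["tomato"])[0]
for annotation_id, instance in instances.items():
    print(annotation_id, instance.label, instance.bounding_box, instance.mask.shape)

# Video input: one key per frame; instances carry masks and labels but no boxes.
frame = np.asarray(image, dtype=np.uint8)
video = np.stack([frame, np.zeros_like(frame)])  # (frames, H, W, 3)
segments = florence2_sam2(media=video, prompts=["tomato"])
for frame_idx, frame_instances in segments.items():
    print(frame_idx, len(frame_instances))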