Add Florence2SAM2 #24

Merged

Changes from all commits (14 commits):
1359d2b added sam2 (dillonalaird)
248a797 added sam2 predictor (dillonalaird)
d07105c added sam to dependencies (dillonalaird)
6ef04a2 fixed sam2 naming, added as optional (dillonalaird)
1ab1b16 updated sam2 with latest changes from sam2 repo (dillonalaird)
ab8976e Merge branch 'main' into feat/add-sam2 (camiloaz)
e759c79 typing and improvements (camiloaz)
4a1678b tests and fixes (camiloaz)
be4ac30 remove viz code (camiloaz)
f718957 rename class (camiloaz)
4d9613b remove comment (camiloaz)
6df24de use context manager (camiloaz)
8ea859e better dependencies (camiloaz)
d5f3a16 address review comments and docs (camiloaz)
@@ -0,0 +1,49 @@

# Florence2Sam2

This tool combines FlorenceV2 and the SAM-2 model to perform text-prompted instance segmentation on image or video inputs.

```python
from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2
from decord import VideoReader
from decord import cpu

# Path to your video
video_path = "path/to/your/video.mp4"

# Load the video
vr = VideoReader(video_path, ctx=cpu(0))

# Subsample frames
frame_idxs = range(0, len(vr) - 1, 20)
frames = vr.get_batch(frame_idxs).asnumpy()

# Create the Florence2SAM2 instance
florence2_sam2 = Florence2SAM2()

# Segment all instances of the prompt "ball" across all video frames
results = florence2_sam2(video=frames, prompts=["ball"])

# Returns a dictionary where the first key is the frame index, the second key is an
# annotation ID, and the value is an object with the mask, label, and possibly a
# bounding box (for images) for each annotation ID. For example:
# {
#     0:
#     {
#         0: ImageBboxAndMaskLabel({"mask": np.ndarray, "label": "car"}),
#         1: ImageBboxAndMaskLabel({"mask": np.ndarray, "label": "person"})
#     },
#     1: ...
# }

print("Instance segmentation complete!")
```
You can also run the tool on a single image, which additionally returns bounding boxes:
```python
results = florence2_sam2(image=image, prompts=["ball"])
```
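For a fuller image example, here is a minimal sketch (the image path is hypothetical). For images, each instance additionally carries a `bounding_box` in `[x_min, y_min, x_max, y_max]` format:

```python
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2

# Hypothetical local image path
image = Image.open("path/to/your/image.jpg").convert("RGB")

florence2_sam2 = Florence2SAM2()
results = florence2_sam2(image=image, prompts=["ball"])

# Image results live under frame key 0
for annotation_id, instance in results[0].items():
    print(annotation_id, instance.label, instance.bounding_box, instance.mask.shape)
```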
::: vision_agent_tools.tools.florence2_sam2
@@ -0,0 +1,83 @@

import numpy as np
import pytest
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2


def test_successful_florence2_sam2_image():
    """
    This test verifies that Florence2SAM2 returns a valid response when passed an image.
    """
    test_image = Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB")

    florence2_sam2 = Florence2SAM2()

    results = florence2_sam2(image=test_image, prompts=["tomato"])

    # The dictionary should have only one key: 0
    assert len(results) == 1
    # The dictionary should have 23 instances of the tomato class
    assert len(results[0]) == 23
    for instance in results[0].values():
        assert len(instance.bounding_box) == 4
        assert np.all(
            [
                0 <= coord <= np.max(test_image.size[:2])
                for coord in instance.bounding_box
            ]
        )
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == test_image.size[::-1]
        assert instance.label == "tomato"


def test_successful_florence2_sam2_video():
    """
    This test verifies that Florence2SAM2 returns a valid response when passed a video.
    """
    tomatoes_image = np.array(
        Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB"), dtype=np.uint8
    )
    test_video = np.array(
        [tomatoes_image, np.zeros(tomatoes_image.shape, dtype=np.uint8)]
    )

    florence2_sam2 = Florence2SAM2()

    results = florence2_sam2(video=test_video, prompts=["tomato"])

    # The dictionary should have 2 keys for the two frames in the video
    assert len(results) == 2
    # Both frames should have 23 instances of the tomato class
    assert len(results[0]) == 23
    assert len(results[1]) == 23
    # First frame
    for instance in results[0].values():
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == tomatoes_image.shape[:2]
        assert instance.label == "tomato"

    # Second frame
    for instance in results[1].values():
        assert isinstance(instance.mask, np.ndarray)
        assert instance.mask.shape == tomatoes_image.shape[:2]
        assert instance.label == "tomato"
        # All masks should be empty since it's a black frame
        assert np.all(instance.mask == 0)


def test_florence2_sam2_invalid_media():
    """
    This test verifies that Florence2SAM2 raises an error if the media is not valid.
    """
    florence2_sam2 = Florence2SAM2()

    with pytest.raises(ValueError):
        florence2_sam2(image="invalid media", prompts=["tomato"])

    with pytest.raises(ValueError):
        florence2_sam2(video="invalid media", prompts=["tomato"])

    with pytest.raises(AssertionError):
        florence2_sam2(video=np.array([1, 2, 3]), prompts=["tomato"])
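A note on these expectations: `Florence2SAM2.__call__` is wrapped in pydantic's `validate_call`, and pydantic's `ValidationError` subclasses `ValueError`, so `pytest.raises(ValueError)` catches the rejected string inputs above; the `np.array([1, 2, 3])` case instead trips the explicit `assert video.ndim == 4` inside `__call__`. A minimal sketch of the mechanism (the function `f` is hypothetical):

```python
from pydantic import validate_call


@validate_call
def f(x: int) -> int:
    return x + 1


try:
    f("not an int")  # argument fails annotation validation
except ValueError as exc:  # pydantic.ValidationError subclasses ValueError
    print(type(exc).__name__)  # -> ValidationError
```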
vision_agent_tools/tools/shared_types.py → vision_agent_tools/shared_types.py (11 additions, 0 deletions)
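Since `shared_types.py` moves from the `tools` package to the package root, imports change accordingly. A before/after sketch (only the names used in this diff are shown; anything else the module exports is not assumed here):

```python
# Before this PR (inside the tools package):
# from vision_agent_tools.tools.shared_types import BaseTool

# After this PR (package root), as used by the new tool below:
from vision_agent_tools.shared_types import BaseTool, VideoNumpy, SegmentationBitMask
```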
@@ -0,0 +1,170 @@

from dataclasses import dataclass
from typing_extensions import Annotated

import torch
import numpy as np
from PIL import Image
from pydantic import validate_call

from vision_agent_tools.shared_types import BaseTool, VideoNumpy, SegmentationBitMask
from vision_agent_tools.tools.florencev2 import Florencev2, PromptTask

from sam2.sam2_video_predictor import SAM2VideoPredictor
from sam2.sam2_image_predictor import SAM2ImagePredictor


_HF_MODEL = "facebook/sam2-hiera-large"


@dataclass
class ImageBboxAndMaskLabel:
    label: str
    bounding_box: list[
        Annotated[float, "x_min"],
        Annotated[float, "y_min"],
        Annotated[float, "x_max"],
        Annotated[float, "y_max"],
    ]
    mask: SegmentationBitMask | None


@dataclass
class MaskLabel:
    label: str
    mask: SegmentationBitMask


class Florence2SAM2(BaseTool):
    """
    A class that receives a video or an image plus a list of text prompts and
    returns the instance segmentation for the text prompts in each frame.
    """

    def __init__(self, device: str | None = None):
        """
        Initializes the Florence2SAM2 object with a pre-trained Florencev2 model
        and a SAM2 model.
        """
        self.device = (
            device
            if device in ["cuda", "mps", "cpu"]
            else "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self.florence2 = Florencev2()
        self.video_predictor = SAM2VideoPredictor.from_pretrained(_HF_MODEL)
        self.image_predictor = SAM2ImagePredictor(self.video_predictor)

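    # Two-stage pipeline: Florence2 phrase grounding proposes a bounding box for
    # each detected instance of a prompt, then SAM2's image predictor converts
    # those boxes into segmentation masks (skipped when return_mask is False).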
    @torch.inference_mode()
    def get_bbox_and_mask(
        self, image: Image.Image, prompts: list[str], return_mask: bool = True
    ) -> dict[int, ImageBboxAndMaskLabel]:
        objs = {}
        self.image_predictor.set_image(np.array(image, dtype=np.uint8))
        annotation_id = 0
        for prompt in prompts:
            with torch.autocast(device_type=self.device, dtype=torch.float16):
                bboxes = self.florence2(
                    image, PromptTask.CAPTION_TO_PHRASE_GROUNDING, prompt
                )[PromptTask.CAPTION_TO_PHRASE_GROUNDING]["bboxes"]
            if return_mask:
                with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
                    masks, _, _ = self.image_predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=bboxes,
                        multimask_output=False,
                    )
            for i in range(len(bboxes)):
                objs[annotation_id] = ImageBboxAndMaskLabel(
                    bounding_box=bboxes[i],
                    mask=(
                        masks[i, 0, :, :] if len(masks.shape) == 4 else masks[i, :, :]
                    )
                    if return_mask
                    else None,
                    label=prompt,
                )
                annotation_id += 1
        return objs

    @torch.inference_mode()
    def handle_image(
        self, image: Image.Image, prompts: list[str]
    ) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
        self.image_predictor.reset_predictor()
        objs = self.get_bbox_and_mask(image.convert("RGB"), prompts)
        return {0: objs}

    @torch.inference_mode()
    def handle_video(
        self, video: VideoNumpy, prompts: list[str]
    ) -> dict[int, dict[int, MaskLabel]]:
        self.image_predictor.reset_predictor()
        objs = self.get_bbox_and_mask(
            Image.fromarray(video[0]).convert("RGB"), prompts, return_mask=False
        )
        with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
            inference_state = self.video_predictor.init_state(video=video)
            for annotation_id in objs:
                _, _, out_mask_logits = self.video_predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=0,
                    obj_id=annotation_id,
                    box=objs[annotation_id].bounding_box,
                )

            annotation_id_to_label = {}
            for annotation_id in objs:
                annotation_id_to_label[annotation_id] = objs[annotation_id].label

            video_segments = {}
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in self.video_predictor.propagate_in_video(inference_state):
                video_segments[out_frame_idx] = {
                    out_obj_id: MaskLabel(
                        mask=(out_mask_logits[i][0] > 0.0).cpu().numpy(),
                        label=annotation_id_to_label[out_obj_id],
                    )
                    for i, out_obj_id in enumerate(out_obj_ids)
                }
            self.video_predictor.reset_state(inference_state)
        return video_segments

    @validate_call(config={"arbitrary_types_allowed": True})
    @torch.inference_mode()
    def __call__(
        self,
        prompts: list[str],
        image: Image.Image | None = None,
        video: VideoNumpy | None = None,
    ) -> dict[int, dict[int, ImageBboxAndMaskLabel | MaskLabel]]:
        """Returns a dictionary where the first key is the frame index, the second key
        is an annotation ID, and the value is an object with the mask, label, and
        possibly a bounding box (for images) for each annotation ID. For example:
            {
                0:
                {
                    0: ImageBboxAndMaskLabel({"mask": np.ndarray, "label": "car"}),
                    1: ImageBboxAndMaskLabel({"mask": np.ndarray, "label": "person"})
                },
                1: ...
            }
        """
        if image is None and video is None:
            raise ValueError("Either 'image' or 'video' must be provided.")
        if image is not None and video is not None:
            raise ValueError("Only one of 'image' or 'video' can be provided.")

        if image is not None:
            return self.handle_image(image, prompts)
        elif video is not None:
            assert video.ndim == 4, "Video should have 4 dimensions"
            return self.handle_video(video, prompts)
        # No need to raise an error here; the validate_call decorator will take care of it
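To illustrate consuming the nested result dictionary, a minimal sketch (assuming `results` comes from a video call such as `florence2_sam2(video=frames, prompts=["ball"])`, as in the docs above):

```python
import numpy as np

# results: dict[frame_index, dict[annotation_id, MaskLabel]]
# Count instances per frame and build one combined boolean mask per frame.
combined_masks: dict[int, np.ndarray] = {}
for frame_idx, annotations in results.items():
    print(f"frame {frame_idx}: {len(annotations)} instance(s)")
    frame_mask = None
    for annotation_id, instance in annotations.items():
        mask = instance.mask.astype(bool)
        frame_mask = mask if frame_mask is None else (frame_mask | mask)
    if frame_mask is not None:
        combined_masks[frame_idx] = frame_mask
```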
@dillonalaird I think this value is already defined inside the file vision_agent_tools/types.py. We should probably merge it together with the shared_types.py file that you are moving out of the tools folder.
I merged them. I moved it to this file.
Sorry, I thought I deleted the other file but I did not. I will delete the file; I already migrated all the imports to use this file.
I just pushed the changes