Add Florence2SAM2 #24

Merged: 14 commits merged on Aug 13, 2024
49 changes: 49 additions & 0 deletions docs/florence2-sam2.md
@@ -0,0 +1,49 @@
# Florence2SAM2

This tool uses FlorenceV2 together with the SAM-2 model to perform text-prompted instance segmentation on image or video inputs.

```python
from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2
from decord import VideoReader
from decord import cpu


# Path to your video
video_path = "path/to/your/video.mp4"

# Load the video
vr = VideoReader(video_path, ctx=cpu(0))

# Subsample frames
frame_idxs = range(0, len(vr) - 1, 20)
frames = vr.get_batch(frame_idxs).asnumpy()

# Create the Florence2SAM2 instance
florence2_sam2 = Florence2SAM2()

# Segment all instances of the prompt "ball" across all video frames
results = florence2_sam2(video=frames, prompts=["ball"])

# The result is a dictionary keyed first by frame index and then by annotation
# ID, mapping to an object with the mask, the label and, for image inputs, a
# bounding box (MaskLabel for videos, ImageBboxAndMaskLabel for images).
# For this video example:
# {
#     0:
#     {
#         0: MaskLabel({"mask": np.ndarray, "label": "ball"}),
#         1: MaskLabel({"mask": np.ndarray, "label": "ball"})
#     },
#     1: ...
# }

print("Instance segmentation complete!")

```
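As a quick reference, here is a minimal sketch (assuming the return structure described in the comment above) of iterating the nested result dictionary:

```python
# Sketch only: walk the {frame_index: {annotation_id: MaskLabel}} result.
for frame_idx, annotations in results.items():
    labels = [instance.label for instance in annotations.values()]
    print(f"Frame {frame_idx}: {len(annotations)} instances -> {labels}")
```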

You can also run the tool against a single image and additionally get bounding boxes, as follows:

```python
results = florence2_sam2(image=image, prompts=["ball"])
```
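
For instance, a short sketch (assuming `image` is a PIL image you have already loaded) that reads the bounding boxes returned for the single image frame, which is stored under key `0`:

```python
# Sketch only: image results are keyed by frame index 0, then by annotation ID.
for annotation_id, instance in results[0].items():
    x_min, y_min, x_max, y_max = instance.bounding_box
    print(f"{instance.label} #{annotation_id}: ({x_min}, {y_min}, {x_max}, {y_max})")
```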

::: vision_agent_tools.tools.florence2_sam2
649 changes: 421 additions & 228 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
@@ -20,14 +20,15 @@ transformers = {extras = ["torch"], version = "^4.42.3", optional = true}
scipy = {version = "^1.13.1", optional = true}
gdown = "^5.1.0"
wget = "^3.2"
torch = "2.2.2"
torch = ">=2.3.1"
timm = {version = "^0.6.7", optional = true}
einops = {version = "^0.7.0", optional = true}
loca = { git = "https://github.com/landing-ai/loca.git", branch = "main", optional = true }
depth-anything-v2 = { git = "https://github.com/landing-ai/depth-anything-v2.git", branch = "main", optional = true }
controlnet-aux = {version = "^0.0.9", optional = true}
lmdeploy = {version = "^0.5.3", optional = true}
decord = {version = "^0.6.0", optional = true}
sam-2 = {git = "https://github.com/landing-ai/segment-anything-2.git", branch = "main", optional = true}


[tool.poetry.group.dev.dependencies]
@@ -42,7 +43,7 @@ mkdocs-material = "^9.5.28"
griffe-fieldz = "^0.1.2"

[tool.poetry.extras]
all = ["qreader", "transformers", "scipy", "loca", "depth-anything-v2", "timm", "einops", "controlnet-aux", "lmdeploy", "decord"]
all = ["qreader", "transformers", "scipy", "loca", "depth-anything-v2", "timm", "einops", "controlnet-aux", "lmdeploy", "decord", "sam-2"]
qr-reader = ["qreader"]
owlv2 = ["transformers", "scipy"]
florencev2 = ["transformers", "scipy", "timm", "einops"]
@@ -53,6 +54,7 @@ controlnet-aux = ["controlnet-aux"]
florencev2-qa = ["transformers", "scipy", "timm", "einops"]
clip-media-sim = ["transformers"]
ixc-25 = ["transformers", "lmdeploy", "decord"]
florence2-sam2 = ["transformers", "scipy", "timm", "einops", "sam-2"]


[build-system]
83 changes: 83 additions & 0 deletions tests/tools/test_florence2_sam2.py
@@ -0,0 +1,83 @@
import numpy as np
import pytest
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2


def test_successful_florence2_sam2_image():
"""
This test verifies that Florence2SAM2 returns a valid response when passed an image
"""
test_image = Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB")

florence2_sam2 = Florence2SAM2()

results = florence2_sam2(image=test_image, prompts=["tomato"])

# The dictionary should have only one key: 0
assert len(results) == 1
# The dictionary should have 23 instances of the tomato class
assert len(results[0]) == 23
for instance in results[0].values():
assert len(instance.bounding_box) == 4
assert np.all(
[
0 <= coord <= np.max(test_image.size[:2])
for coord in instance.bounding_box
]
)
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == test_image.size[::-1]
assert instance.label == "tomato"


def test_successful_florence2_sam2_video():
"""
This test verifies that Florence2SAM2 returns a valid response when passed a video
"""
tomatoes_image = np.array(
Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB"), dtype=np.uint8
)
test_video = np.array(
[tomatoes_image, np.zeros(tomatoes_image.shape, dtype=np.uint8)]
)

florence2_sam2 = Florence2SAM2()

results = florence2_sam2(video=test_video, prompts=["tomato"])

# The dictionary should have 2 keys for the two frames in the video
assert len(results) == 2
# The first frame should have 23 instances of the tomato class
assert len(results[0]) == 23
assert len(results[1]) == 23
# First frame
for instance in results[0].values():
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == tomatoes_image.shape[:2]
assert instance.label == "tomato"

# Second frame
for instance in results[1].values():
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == tomatoes_image.shape[:2]
assert instance.label == "tomato"
# All masks should be empty since it's a black frame
assert np.all(instance.mask == 0)


def test_florence2_sam2_invalid_media():
"""
This test verifies that Florence2SAM2 raises an error if the media is not valid.
"""
florence2_sam2 = Florence2SAM2()

with pytest.raises(ValueError):
florence2_sam2(image="invalid media", prompts=["tomato"])

with pytest.raises(ValueError):
florence2_sam2(video="invalid media", prompts=["tomato"])

with pytest.raises(AssertionError):
florence2_sam2(video=np.array([1, 2, 3]), prompts=["tomato"])
2 changes: 1 addition & 1 deletion vision_agent_tools/helpers/roberta_qa.py
@@ -2,7 +2,7 @@

from transformers import pipeline
from pydantic import BaseModel
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool

MODEL_NAME = "deepset/roberta-base-squad2"
PROCESSOR_NAME = "deepset/roberta-base-squad2"
vision_agent_tools/shared_types.py
@@ -1,10 +1,21 @@
from typing import Annotated, Literal, TypeVar

from pydantic import BaseModel
import numpy as np
import numpy.typing as npt


class BaseTool:
pass


DType = TypeVar("DType", bound=np.generic)

VideoNumpy = Annotated[npt.NDArray[DType], Literal["N", "N", "N", 3]]
Review comment (Member): @dillonalaird I think this value is already defined inside the file vision_agent_tools/types.py. We should probably merge the shared_types.py file that you are moving out of the tools folder with it.

Reply (Member): I merged them. I moved it to this file.

Reply (Member): Sorry, I thought I deleted the other file but I did not. I will delete the file, but I migrated all the imports to use this file.

Reply (Member): I just pushed the changes.

SegmentationBitMask = Annotated[npt.NDArray[np.bool_], Literal["N", "N"]]


class Point(BaseModel):
# X coordinate of the point
x: float
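For context, a minimal sketch (not part of this PR; the function name is hypothetical) of how these annotated aliases are meant to document array shapes in signatures; the `Literal` shape hints are informational rather than enforced at runtime:

```python
import numpy as np

from vision_agent_tools.shared_types import SegmentationBitMask, VideoNumpy


def first_frame_placeholder_mask(video: VideoNumpy) -> SegmentationBitMask:
    # VideoNumpy documents an (N, H, W, 3) array; here we return an all-False
    # boolean mask matching the height and width of the first frame.
    return np.zeros(video.shape[1:3], dtype=np.bool_)
```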
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/clip_media_sim.py
@@ -8,8 +8,7 @@
from pydantic import validate_call
from transformers import CLIPModel, CLIPProcessor

from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.types import VideoNumpy
from vision_agent_tools.shared_types import BaseTool, VideoNumpy


_HF_MODEL = "openai/clip-vit-large-patch14"
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/depth_anything_v2.py
@@ -3,15 +3,14 @@
# Run this line before loading torch
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import cv2
import numpy as np
import os.path as osp
import torch

from PIL import Image
from .utils import download, CHECKPOINT_DIR
from typing import Union, Any
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool
from depth_anything_v2.dpt import DepthAnythingV2 as DepthAnythingV2Model
from pydantic import BaseModel

170 changes: 170 additions & 0 deletions vision_agent_tools/tools/florence2_sam2.py
@@ -0,0 +1,170 @@
from dataclasses import dataclass
from typing_extensions import Annotated

import torch
import numpy as np
from PIL import Image
from pydantic import validate_call

from vision_agent_tools.shared_types import BaseTool, VideoNumpy, SegmentationBitMask
from vision_agent_tools.tools.florencev2 import Florencev2, PromptTask

from sam2.sam2_video_predictor import SAM2VideoPredictor
from sam2.sam2_image_predictor import SAM2ImagePredictor


_HF_MODEL = "facebook/sam2-hiera-large"


@dataclass
class ImageBboxAndMaskLabel:
label: str
bounding_box: list[
Annotated[float, "x_min"],
Annotated[float, "y_min"],
Annotated[float, "x_max"],
Annotated[float, "y_max"],
]
mask: SegmentationBitMask | None


@dataclass
class MaskLabel:
label: str
mask: SegmentationBitMask


class Florence2SAM2(BaseTool):
"""
A class that receives a video or an image plus a list of text prompts and
returns the instance segmentation for the text prompts in each frame.
"""

def __init__(self, device: str | None = None):
"""
Initializes the Florence2SAM2 object with a pre-trained Florencev2 model
and a SAM2 model.
"""
self.device = (
device
if device in ["cuda", "mps", "cpu"]
else "cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)
self.florence2 = Florencev2()
self.video_predictor = SAM2VideoPredictor.from_pretrained(_HF_MODEL)
self.image_predictor = SAM2ImagePredictor(self.video_predictor)

@torch.inference_mode()
def get_bbox_and_mask(
self, image: Image.Image, prompts: list[str], return_mask: bool = True
) -> dict[int, ImageBboxAndMaskLabel]:
objs = {}
self.image_predictor.set_image(np.array(image, dtype=np.uint8))
annotation_id = 0
for prompt in prompts:
with torch.autocast(device_type=self.device, dtype=torch.float16):
bboxes = self.florence2(
image, PromptTask.CAPTION_TO_PHRASE_GROUNDING, prompt
)[PromptTask.CAPTION_TO_PHRASE_GROUNDING]["bboxes"]
if return_mask:
with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
masks, _, _ = self.image_predictor.predict(
point_coords=None,
point_labels=None,
box=bboxes,
multimask_output=False,
)
for i in range(len(bboxes)):
objs[annotation_id] = ImageBboxAndMaskLabel(
bounding_box=bboxes[i],
mask=(
masks[i, 0, :, :] if len(masks.shape) == 4 else masks[i, :, :]
)
if return_mask
else None,
label=prompt,
)
annotation_id += 1
return objs

@torch.inference_mode()
def handle_image(
self, image: Image.Image, prompts: list[str]
) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
self.image_predictor.reset_predictor()
objs = self.get_bbox_and_mask(image.convert("RGB"), prompts)
return {0: objs}

@torch.inference_mode()
def handle_video(
self, video: VideoNumpy, prompts: list[str]
) -> dict[int, dict[int, MaskLabel]]:
self.image_predictor.reset_predictor()
objs = self.get_bbox_and_mask(
Image.fromarray(video[0]).convert("RGB"), prompts, return_mask=False
)
with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
inference_state = self.video_predictor.init_state(video=video)
for annotation_id in objs:
_, _, out_mask_logits = self.video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=annotation_id,
box=objs[annotation_id].bounding_box,
)

annotation_id_to_label = {}
for annotation_id in objs:
annotation_id_to_label[annotation_id] = objs[annotation_id].label

video_segments = {}
for (
out_frame_idx,
out_obj_ids,
out_mask_logits,
) in self.video_predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: MaskLabel(
mask=(out_mask_logits[i][0] > 0.0).cpu().numpy(),
label=annotation_id_to_label[out_obj_id],
)
for i, out_obj_id in enumerate(out_obj_ids)
}
self.video_predictor.reset_state(inference_state)
return video_segments

@validate_call(config={"arbitrary_types_allowed": True})
@torch.inference_mode()
def __call__(
self,
prompts: list[str],
image: Image.Image | None = None,
video: VideoNumpy | None = None,
) -> dict[int, dict[int, ImageBboxAndMaskLabel | MaskLabel]]:
"""Returns a dictionary where the first key is the frame index then an annotation
ID, then an object with the mask, label and possibly bbox (for images) for each
annotation ID. For example:
{
0:
{
0: ImageBboxMaskLabel({"mask": np.ndarray, "label": "car"}),
1: ImageBboxMaskLabel({"mask", np.ndarray, "label": "person"})
},
1: ...
}
"""
if image is None and video is None:
raise ValueError("Either 'image' or 'video' must be provided.")
if image is not None and video is not None:
raise ValueError("Only one of 'image' or 'video' can be provided.")

if image is not None:
return self.handle_image(image, prompts)
elif video is not None:
assert video.ndim == 4, "Video should have 4 dimensions"
return self.handle_video(video, prompts)
# No need to raise an error here; the validate_call decorator will take care of it