Add Florence2SAM2 #24

Merged
14 commits merged on Aug 13, 2024
649 changes: 421 additions & 228 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
@@ -20,14 +20,15 @@ transformers = {extras = ["torch"], version = "^4.42.3", optional = true}
scipy = {version = "^1.13.1", optional = true}
gdown = "^5.1.0"
wget = "^3.2"
torch = "2.2.2"
torch = ">=2.3.1"
timm = {version = "^0.6.7", optional = true}
einops = {version = "^0.7.0", optional = true}
loca = { git = "https://github.com/landing-ai/loca.git", branch = "main", optional = true }
depth-anything-v2 = { git = "https://github.com/landing-ai/depth-anything-v2.git", branch = "main", optional = true }
controlnet-aux = {version = "^0.0.9", optional = true}
lmdeploy = {version = "^0.5.3", optional = true}
decord = {version = "^0.6.0", optional = true}
sam-2 = {git = "https://github.com/landing-ai/segment-anything-2.git", branch = "main", optional = true}


[tool.poetry.group.dev.dependencies]
@@ -42,7 +43,7 @@ mkdocs-material = "^9.5.28"
griffe-fieldz = "^0.1.2"

[tool.poetry.extras]
all = ["qreader", "transformers", "scipy", "loca", "depth-anything-v2", "timm", "einops", "controlnet-aux", "lmdeploy", "decord"]
all = ["qreader", "transformers", "scipy", "loca", "depth-anything-v2", "timm", "einops", "controlnet-aux", "lmdeploy", "decord", "sam-2"]
qr-reader = ["qreader"]
owlv2 = ["transformers", "scipy"]
florencev2 = ["transformers", "scipy", "timm", "einops"]
@@ -53,6 +54,7 @@ controlnet-aux = ["controlnet-aux"]
florencev2-qa = ["transformers", "scipy", "timm", "einops"]
clip-media-sim = ["transformers"]
ixc-25 = ["transformers", "lmdeploy", "decord"]
florence2-sam2 = ["transformers", "scipy", "timm", "einops", "sam-2"]


[build-system]
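The new sam-2 git dependency is exposed through the florence2-sam2 extra above, so the tool only works when that extra is installed. A minimal sketch, assuming the published package name is vision-agent-tools (an illustrative assumption, not stated in this PR), of the guarded-import pattern that usually accompanies an optional extra:

# Sketch only: fail with an actionable message when the optional SAM-2
# dependency from the "florence2-sam2" extra is missing.
try:
    from sam2.sam2_image_predictor import SAM2ImagePredictor  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "Florence2SAM2 requires the optional SAM-2 dependency. "
        "Install it with: pip install 'vision-agent-tools[florence2-sam2]'"
    ) from exc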
80 changes: 80 additions & 0 deletions tests/tools/test_florence2_sam2.py
@@ -0,0 +1,80 @@
import numpy as np
import pytest
from PIL import Image

from vision_agent_tools.tools.florence2_sam2 import Florence2SAM2


def test_successful_florence2_sam2_image():
"""
This test verifies that Florence2SAM2 returns a valid response when passed an image and a prompt.
"""

Review comment (Member): I think you should change CLIPMediaSim to Florence2SAM2
Reply (Member): done. thanks.

test_image = Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB")

florence2_sam2 = Florence2SAM2()

results = florence2_sam2(media=test_image, prompts=["tomato"])

# The dictionary should have only one key: 0
assert len(results) == 1
# The dictionary should have 23 instances of the tomato class
assert len(results[0]) == 23
for instance in results[0].values():
assert len(instance.bounding_box) == 4
assert np.all(
[
0 <= coord <= np.max(test_image.size[:2])
for coord in instance.bounding_box
]
)
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == test_image.size[::-1]
assert instance.label == "tomato"


def test_successful_florence2_sam2_video():
"""
This test verifies that Florence2SAM2 returns a valid response when passed a video and a prompt.
"""

Review comment (Member): Same here, please change CLIPMediaSim

tomatoes_image = np.array(
Image.open("tests/tools/data/loca/tomatoes.jpg").convert("RGB"), dtype=np.uint8
)
test_video = np.array(
[tomatoes_image, np.zeros(tomatoes_image.shape, dtype=np.uint8)]
)

florence2_sam2 = Florence2SAM2()

results = florence2_sam2(media=test_video, prompts=["tomato"])

# The dictionary should have 2 keys for the two frames in the video
assert len(results) == 2
# The first frame should have 23 instances of the tomato class
assert len(results[0]) == 23
assert len(results[1]) == 23
# First frame
for instance in results[0].values():
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == tomatoes_image.shape[:2]
assert instance.label == "tomato"

# Second frame
for instance in results[1].values():
assert isinstance(instance.mask, np.ndarray)
assert instance.mask.shape == tomatoes_image.shape[:2]
assert instance.label == "tomato"
# All masks should be empty since it's a black frame
assert np.all(instance.mask == 0)


def test_florence2_sam2_invalid_media():
"""
This test verifies that Florence2SAM2 raises a ValueError if the media is not a valid type.
"""

Review comment (Member): Same here

florence2_sam2 = Florence2SAM2()

with pytest.raises(ValueError):
florence2_sam2(media="invalid media", prompts=["tomato"])

with pytest.raises(AssertionError):
florence2_sam2(media=np.array([1, 2, 3]), prompts=["tomato"])
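Note that Florence2SAM2 autocasts with device_type="cuda", so these tests presumably need a GPU to pass. A hedged sketch of a module-level guard that could sit at the top of this test file (whether the project prefers skipping or a CPU fallback is an assumption, not decided in this PR):

import pytest
import torch

# Assumption: skip the whole module on machines without a CUDA device,
# since the current implementation hard-codes device_type="cuda".
pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="Florence2SAM2 tests require a CUDA-capable GPU",
)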
2 changes: 1 addition & 1 deletion vision_agent_tools/helpers/roberta_qa.py
@@ -2,7 +2,7 @@

from transformers import pipeline
from pydantic import BaseModel
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool

MODEL_NAME = "deepset/roberta-base-squad2"
PROCESSOR_NAME = "deepset/roberta-base-squad2"
vision_agent_tools/shared_types.py
@@ -1,10 +1,21 @@
from typing import Annotated, Literal, TypeVar

from pydantic import BaseModel
import numpy as np
import numpy.typing as npt


class BaseTool:
pass


DType = TypeVar("DType", bound=np.generic)

VideoNumpy = Annotated[npt.NDArray[DType], Literal["N", "N", "N", 3]]
Review comment (Member): @dillonalaird I think this value is already defined inside the file vision_agent_tools/types.py. We should probably merge together the shared_types.py file that you are moving out of the tools folder.
Reply (Member): I merged them. I moved it to this file.
Reply (Member): Sorry, I thought I deleted the other file but I did not. I will delete the file, but I migrated all the imports to use this file.
Reply (Member): I just pushed the changes.

SegmentationBitMask = Annotated[npt.NDArray[np.bool_], Literal["N", "N"]]


class Point(BaseModel):
# X coordinate of the point
x: float
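With the shared types consolidated here, tools validate array arguments against VideoNumpy through pydantic's validate_call, as clip_media_sim.py and florence2_sam2.py below do. A small sketch under that assumption (the helper frame_count is illustrative, not part of the PR):

import numpy as np
from pydantic import validate_call

from vision_agent_tools.shared_types import VideoNumpy


@validate_call(config={"arbitrary_types_allowed": True})
def frame_count(video: VideoNumpy) -> int:
    # A video is expected as a (frames, height, width, 3) uint8 array.
    assert video.ndim == 4, "Video should have 4 dimensions"
    return video.shape[0]


frame_count(np.zeros((2, 480, 640, 3), dtype=np.uint8))  # -> 2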
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/clip_media_sim.py
@@ -8,8 +8,7 @@
from pydantic import validate_call
from transformers import CLIPModel, CLIPProcessor

from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.types import VideoNumpy
from vision_agent_tools.shared_types import BaseTool, VideoNumpy


_HF_MODEL = "openai/clip-vit-large-patch14"
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/depth_anything_v2.py
@@ -3,15 +3,14 @@
# Run this line before loading torch
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import cv2
import numpy as np
import os.path as osp
import torch

from PIL import Image
from .utils import download, CHECKPOINT_DIR
from typing import Union, Any
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool
from depth_anything_v2.dpt import DepthAnythingV2 as DepthAnythingV2Model
from pydantic import BaseModel

144 changes: 144 additions & 0 deletions vision_agent_tools/tools/florence2_sam2.py
@@ -0,0 +1,144 @@
from dataclasses import dataclass
from typing_extensions import Annotated

import torch
import numpy as np
from PIL import Image
from pydantic import validate_call

from vision_agent_tools.shared_types import BaseTool, VideoNumpy, SegmentationBitMask
from vision_agent_tools.tools.florencev2 import Florencev2, PromptTask

from sam2.sam2_video_predictor import SAM2VideoPredictor
from sam2.sam2_image_predictor import SAM2ImagePredictor


_HF_MODEL = "facebook/sam2-hiera-large"


@dataclass
class ImageBboxAndMaskLabel:
label: str
bounding_box: list[
Annotated[float, "x_min"],
Annotated[float, "y_min"],
Annotated[float, "x_max"],
Annotated[float, "y_max"],
]
mask: SegmentationBitMask | None


@dataclass
class MaskLabel:
label: str
mask: SegmentationBitMask


class Florence2SAM2(BaseTool):
def __init__(self):
self.florence2 = Florencev2()
self.video_predictor = SAM2VideoPredictor.from_pretrained(_HF_MODEL)
self.image_predictor = SAM2ImagePredictor(self.video_predictor)

@torch.inference_mode()
def get_bbox_and_mask(
self, image: Image.Image, prompts: list[str], return_mask: bool = True
) -> dict[int, ImageBboxAndMaskLabel]:
objs = {}
self.image_predictor.set_image(np.array(image, dtype=np.uint8))
annotation_id = 0
for prompt in prompts:
with torch.autocast(device_type="cuda", dtype=torch.float16):
Review comment (Member): @dillonalaird I think this would be better to define a self.device value that stores the device type instead of hard coding it.
Reply (Member): thanks! done

bboxes = self.florence2(
image, PromptTask.CAPTION_TO_PHRASE_GROUNDING, prompt
)[PromptTask.CAPTION_TO_PHRASE_GROUNDING]["bboxes"]
if return_mask:
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
Review comment (Member): Same here
masks, _, _ = self.image_predictor.predict(
point_coords=None,
point_labels=None,
box=bboxes,
multimask_output=False,
)
for i in range(len(bboxes)):
objs[annotation_id] = ImageBboxAndMaskLabel(
bounding_box=bboxes[i],
mask=(
masks[i, 0, :, :] if len(masks.shape) == 4 else masks[i, :, :]
)
if return_mask
else None,
label=prompt,
)
annotation_id += 1
return objs

@torch.inference_mode()
def handle_image(
self, image: Image.Image, prompts: list[str]
) -> dict[int, dict[int, ImageBboxAndMaskLabel]]:
self.image_predictor.reset_predictor()
objs = self.get_bbox_and_mask(image.convert("RGB"), prompts)
return {0: objs}

@torch.inference_mode()
def handle_video(
self, video: VideoNumpy, prompts: list[str]
) -> dict[int, dict[int, MaskLabel]]:
self.image_predictor.reset_predictor()
objs = self.get_bbox_and_mask(
Image.fromarray(video[0]).convert("RGB"), prompts, return_mask=False
)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
Review comment (Member): Same here, change to self.device
inference_state = self.video_predictor.init_state(video=video)
for annotation_id in objs:
_, _, out_mask_logits = self.video_predictor.add_new_points_or_box(
inference_state=inference_state,
frame_idx=0,
obj_id=annotation_id,
box=objs[annotation_id].bounding_box,
)

annotation_id_to_label = {}
for annotation_id in objs:
annotation_id_to_label[annotation_id] = objs[annotation_id].label

video_segments = {}
for (
out_frame_idx,
out_obj_ids,
out_mask_logits,
) in self.video_predictor.propagate_in_video(inference_state):
video_segments[out_frame_idx] = {
out_obj_id: MaskLabel(
mask=(out_mask_logits[i][0] > 0.0).cpu().numpy(),
label=annotation_id_to_label[out_obj_id],
)
for i, out_obj_id in enumerate(out_obj_ids)
}
self.video_predictor.reset_state(inference_state)
return video_segments

@validate_call(config={"arbitrary_types_allowed": True})
@torch.inference_mode()
def __call__(
self, media: Image.Image | VideoNumpy, prompts: list[str]
Review comment (CamiloInx, Aug 13, 2024): @dillonalaird @camiloaz should we handle the input value as either image or video as separate values as we do here? I think we should be consistent and change one of these two to have the same format.
Reply (Member): yeah, agree. will do that.
Reply (Member): done
) -> dict[int, dict[int, ImageBboxAndMaskLabel | MaskLabel]]:
"""Returns a dictionary where the first key is the frame index then an annotation
ID, then a dictionary of the mask, label and possibly bbox (for images) for each
annotation ID. For example:
{
0:
{
0: {"mask": np.ndarray, "label": "car"},
1: {"mask", np.ndarray, "label": "person"}
},
1: ...
}
"""
if isinstance(media, Image.Image):
return self.handle_image(media, prompts)
elif isinstance(media, np.ndarray):
assert media.ndim == 4, "Video should have 4 dimensions"
return self.handle_video(media, prompts)
# No need to raise an error here, the validate_call decorator will take care of it
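The review comments above repeatedly ask for a self.device attribute instead of the hard-coded "cuda", and the author replied that it was done; the merged version is not shown in this excerpt, so the following is only a sketch of what that refactor typically looks like:

import torch


class Florence2SAM2DeviceSketch:
    """Sketch only: turning the hard-coded autocast device into an attribute."""

    def __init__(self) -> None:
        # Decide once; torch.autocast accepts "cuda" or "cpu" as device_type.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def _autocast_example(self) -> None:
        # In get_bbox_and_mask and handle_video the literal "cuda" would be
        # replaced by the attribute set in __init__:
        with torch.autocast(device_type=self.device, dtype=torch.bfloat16):
            pass  # model calls go here

Keeping the device in one place keeps the image and video paths consistent and leaves room for a (slow) CPU fallback.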
2 changes: 1 addition & 1 deletion vision_agent_tools/tools/florencev2.py
@@ -4,7 +4,7 @@
from enum import Enum
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool

MODEL_NAME = "microsoft/Florence-2-large"
PROCESSOR_NAME = "microsoft/Florence-2-large"
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/florencev2_qa.py
@@ -1,10 +1,9 @@
from PIL import Image
from typing import Dict
import torch

from vision_agent_tools.tools.florencev2 import Florencev2, PromptTask
from vision_agent_tools.helpers.roberta_qa import RobertaQA
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool


class FlorenceQA(BaseTool):
3 changes: 1 addition & 2 deletions vision_agent_tools/tools/internlm_xcomposer2.py
@@ -1,7 +1,6 @@
import torch
from PIL import Image
from vision_agent_tools.types import VideoNumpy
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool, VideoNumpy
from pydantic import Field, validate_call
from typing import Annotated, Optional

2 changes: 1 addition & 1 deletion vision_agent_tools/tools/nsfw_classification.py
@@ -3,7 +3,7 @@
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, ViTImageProcessor
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool

CHECKPOINT = "Falconsai/nsfw_image_detection"

4 changes: 2 additions & 2 deletions vision_agent_tools/tools/nshot_counting.py
@@ -9,11 +9,11 @@
from PIL import Image
from loca.loca import LOCA
from .utils import download, CHECKPOINT_DIR
from typing import Union, Optional, Any
from typing import Optional, Any
from torch import nn
from torchvision import transforms as T
from pydantic import BaseModel
from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool


class CountingDetection(BaseModel):
2 changes: 1 addition & 1 deletion vision_agent_tools/tools/owlv2.py
@@ -5,7 +5,7 @@
from pydantic import BaseModel
from transformers import Owlv2ForObjectDetection, Owlv2Processor

from vision_agent_tools.tools.shared_types import BaseTool
from vision_agent_tools.shared_types import BaseTool

MODEL_NAME = "google/owlv2-large-patch14-ensemble"
PROCESSOR_NAME = "google/owlv2-large-patch14-ensemble"
2 changes: 1 addition & 1 deletion vision_agent_tools/tools/qr_reader.py
@@ -4,7 +4,7 @@

from qreader import QReader

from vision_agent_tools.tools.shared_types import BaseTool, Polygon, Point, BoundingBox
from vision_agent_tools.shared_types import BaseTool, Polygon, Point, BoundingBox


class QRCodeDetection(BaseModel):