Skip to content

Commit

Permalink
add improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
CamiloInx committed Nov 13, 2024
1 parent 8893c5a commit cd3938f
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
8 changes: 6 additions & 2 deletions vision_agent_tools/models/florence2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Device,
Florence2ResponseType,
BaseMLModel,
ODResponse,
ODWithScoreResponse,
Florence2OCRResponse,
Florence2TextResponse,
Florence2OpenVocabularyResponse,
Expand Down Expand Up @@ -360,7 +360,11 @@ def _serialize(
| PromptTask.REGION_PROPOSAL
):
detections.append(
ODResponse(bboxes=detection["bboxes"], labels=detection["labels"])
ODWithScoreResponse(
bboxes=detection["bboxes"],
labels=detection["labels"],
scores=[1.0] * len(detection["labels"]),
)
)
case PromptTask.OCR_WITH_REGION:
detections.append(
Expand Down
4 changes: 2 additions & 2 deletions vision_agent_tools/models/sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ class Sam2(BaseMLModel):
"""It receives images, a prompt and returns the instance segmentation for the
text prompt in each frame."""

def __init__(self, model_config: Sam2Config | None = Sam2Config()):
self.model_config = model_config
def __init__(self, model_config: Sam2Config | None = None):
self.model_config = model_config or Sam2Config()
self.image_model = SAM2ImagePredictor.from_pretrained(
self.model_config.hf_model
)
Expand Down
14 changes: 13 additions & 1 deletion vision_agent_tools/tools/text_to_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class TextToObjectDetectionRequest(BaseModel):
le=1.0,
description="The IoU threshold value used to apply a dummy agnostic Non-Maximum Suppression (NMS).",
)
chunk_length_frames: int | None = Field(
default=None,
ge=1,
le=30,
description="The number of frames for each chunk of video to analyze. The last chunk may have fewer frames.",
)
confidence: float | None = Field(
default=None,
ge=0.0,
Expand Down Expand Up @@ -62,7 +68,7 @@ class TextToObjectDetection(BaseTool):
def __init__(
self,
model: TextToObjectDetectionModel = TextToObjectDetectionModel.OWLV2,
model_config: OWLV2Config | None = None,
model_config: OWLV2Config | Florence2Config | None = None,
):
self.model_name = model

Expand All @@ -83,6 +89,7 @@ def __call__(
video: VideoNumpy | None = None,
*,
nms_threshold: float = 0.3,
chunk_length_frames: int | None = None,
confidence: float | None = None,
) -> list[dict[str, Any]]:
"""Run object detection on the image based on text prompts.
Expand All @@ -96,6 +103,9 @@ def __call__(
A numpy array containing the different images, representing the video.
nms_threshold:
The IoU threshold value used to apply a dummy agnostic Non-Maximum Suppression (NMS).
chunk_length_frames:
The number of frames for each chunk of video to analyze.
The last chunk may have fewer frames.
confidence:
Confidence threshold for model predictions.
Expand All @@ -108,6 +118,7 @@ def __call__(
images=images,
video=video,
nms_threshold=nms_threshold,
chunk_length_frames=chunk_length_frames,
confidence=confidence,
)

Expand Down Expand Up @@ -137,6 +148,7 @@ def __call__(
images=images,
video=video,
nms_threshold=nms_threshold,
chunk_length_frames=chunk_length_frames,
)

def to(self, device: Device):
Expand Down
12 changes: 7 additions & 5 deletions vision_agent_tools/tools/text_to_object_detection_sam2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@


class Text2ODSAM2Config(BaseModel):
sam2_config: Sam2Config | None = Sam2Config()
text2od_config: Florence2Config | OWLV2Config | None = OWLV2Config()
sam2_config: Sam2Config | None = None
text2od_config: Florence2Config | OWLV2Config | None = None


class Text2ODSam2Request(BaseModel):
Expand Down Expand Up @@ -88,14 +88,14 @@ class Text2ODSAM2(BaseMLModel):
def __init__(
self,
model: TextToObjectDetectionModel = TextToObjectDetectionModel.OWLV2,
model_config: Text2ODSAM2Config | None = Text2ODSAM2Config(),
model_config: Text2ODSAM2Config | None = None,
):
"""
Initializes the Text2ODSAM2 object with a pre-trained text2od model
and a SAM2 model.
"""
self._model = model
self._model_config = model_config
self._model_config = model_config or Text2ODSAM2Config()
self._text2od = TextToObjectDetection(
model=model, model_config=self._model_config.text2od_config
)
Expand Down Expand Up @@ -170,7 +170,9 @@ def __call__(

text2od_payload_response = self._text2od(**text2od_payload)
od_response = [
ODWithScoreResponse(**item) if len(item.get("labels")) > 0 else None
ODWithScoreResponse(**item)
if item is not None and len(item.get("labels")) > 0
else None
for item in text2od_payload_response
]
if images is not None:
Expand Down

0 comments on commit cd3938f

Please sign in to comment.