Skip to content

Commit

Permalink
Merge pull request #763 from roboflow/class_agnostic_nms_in_owlv2
Browse files Browse the repository at this point in the history
adding class agnostic nms
  • Loading branch information
probicheaux authored Oct 31, 2024
2 parents ada5c76 + c548155 commit f67162d
Showing 1 changed file with 56 additions and 22 deletions.
78 changes: 56 additions & 22 deletions inference/models/owlv2/owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,12 @@ def filter_tensors_by_objectness(
def get_class_preds_from_embeds(
pos_neg_embedding_dict: PosNegDictType,
image_class_embeds: torch.Tensor,
confidence: torch.Tensor,
confidence: float,
image_boxes: torch.Tensor,
class_map: torch.Tensor,
class_name: torch.Tensor,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
class_map: Dict[Tuple[str, str], int],
class_name: str,
iou_threshold: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
predicted_boxes_per_class = []
predicted_class_indices_per_class = []
predicted_scores_per_class = []
Expand All @@ -155,9 +156,9 @@ def get_class_preds_from_embeds(

if not predicted_boxes_per_class:
return (
np.empty((0, 4)),
np.empty((0,)),
np.empty((0,)),
torch.empty((0, 4)),
torch.empty((0,)),
torch.empty((0,)),
)

# concat tensors
Expand All @@ -166,12 +167,14 @@ def get_class_preds_from_embeds(
pred_scores = torch.cat(predicted_scores_per_class, dim=0).float()
positive = torch.cat(positive_arr_per_class, dim=0).float()
# nms
survival_indices = torchvision.ops.nms(to_corners(pred_boxes), pred_scores, 0.3)
survival_indices = torchvision.ops.nms(
to_corners(pred_boxes), pred_scores, iou_threshold
)
# put on numpy and filter to post-nms
pred_boxes = pred_boxes[survival_indices, :].detach().cpu().numpy()
pred_classes = pred_classes[survival_indices].detach().cpu().numpy()
pred_scores = pred_scores[survival_indices].detach().cpu().numpy()
positive = positive[survival_indices].detach().cpu().numpy()
pred_boxes = pred_boxes[survival_indices, :]
pred_classes = pred_classes[survival_indices]
pred_scores = pred_scores[survival_indices]
positive = positive[survival_indices]
is_positive = positive == 1
# return only positive elements of tensor
return pred_boxes[is_positive], pred_classes[is_positive], pred_scores[is_positive]
Expand Down Expand Up @@ -307,7 +310,9 @@ def embed_image(self, image: np.ndarray) -> Hash:

return image_hash

def get_query_embedding(self, query_spec: QuerySpecType) -> torch.Tensor:
def get_query_embedding(
self, query_spec: QuerySpecType, iou_threshold: float
) -> torch.Tensor:
# NOTE: for now we're handling each image seperately
query_embeds = []
for image_hash, query_boxes in query_spec.items():
Expand All @@ -326,7 +331,7 @@ def get_query_embedding(self, query_spec: QuerySpecType) -> torch.Tensor:
) # 3000, k
ious, indices = torch.max(iou, dim=0)
# filter for only iou > 0.4
iou_mask = ious > 0.4
iou_mask = ious > iou_threshold
indices = indices[iou_mask]
if not indices.numel() > 0:
continue
Expand All @@ -343,6 +348,7 @@ def infer_from_embed(
image_hash: Hash,
query_embeddings: Dict[str, PosNegDictType],
confidence: float,
iou_threshold: float,
) -> List[Dict]:
_, image_boxes, image_class_embeds, _, _ = self.image_embed_cache[image_hash]
class_map, class_names = make_class_map(query_embeddings)
Expand All @@ -355,11 +361,30 @@ def infer_from_embed(
image_boxes,
class_map,
class_name,
iou_threshold,
)

all_predicted_boxes.extend(boxes)
all_predicted_classes.extend(classes)
all_predicted_scores.extend(scores)
all_predicted_boxes.append(boxes)
all_predicted_classes.append(classes)
all_predicted_scores.append(scores)

all_predicted_boxes = torch.cat(all_predicted_boxes, dim=0)
all_predicted_classes = torch.cat(all_predicted_classes, dim=0)
all_predicted_scores = torch.cat(all_predicted_scores, dim=0)

# run nms on all predictions
survival_indices = torchvision.ops.nms(
to_corners(all_predicted_boxes), all_predicted_scores, iou_threshold
)
all_predicted_boxes = all_predicted_boxes[survival_indices]
all_predicted_classes = all_predicted_classes[survival_indices]
all_predicted_scores = all_predicted_scores[survival_indices]

# move tensors to numpy before returning
all_predicted_boxes = all_predicted_boxes.cpu().numpy()
all_predicted_classes = all_predicted_classes.cpu().numpy()
all_predicted_scores = all_predicted_scores.cpu().numpy()

return [
{
"class_name": class_names[int(c)],
Expand All @@ -374,10 +399,19 @@ def infer_from_embed(
)
]

def infer(self, image: Any, training_data: Dict, confidence=0.99, **kwargs):
def infer(
self,
image: Any,
training_data: Dict,
confidence=0.99,
iou_threshold=0.3,
**kwargs,
):
class_to_query_spec = self.make_class_box_query_dict(training_data)

class_embeddings_dict = self.make_class_embeddings_dict(class_to_query_spec)
class_embeddings_dict = self.make_class_embeddings_dict(
class_to_query_spec, iou_threshold
)

if not isinstance(image, list):
images = [image]
Expand All @@ -391,22 +425,22 @@ def infer(self, image: Any, training_data: Dict, confidence=0.99, **kwargs):
image_sizes.append(image.shape[:2][::-1])
image_hash = self.embed_image(image)
result = self.infer_from_embed(
image_hash, class_embeddings_dict, confidence
image_hash, class_embeddings_dict, confidence, iou_threshold
)
results.append(result)
return self.make_response(
results, image_sizes, sorted(list(class_embeddings_dict.keys()))
)

def make_class_embeddings_dict(
self, class_to_query_spec: Dict[Tuple[str, str], Dict]
self, class_to_query_spec: Dict[Tuple[str, str], Dict], iou_threshold: float
) -> Dict[str, PosNegDictType]:
class_embeddings_dict = defaultdict(
lambda: {"positive": None, "negative": None}
)
bool_to_literal = {True: "positive", False: "negative"}
for (class_name, positive), query_spec in class_to_query_spec.items():
class_embedding = self.get_query_embedding(query_spec)
class_embedding = self.get_query_embedding(query_spec, iou_threshold)
class_embeddings_dict[class_name][
bool_to_literal[positive]
] = class_embedding
Expand Down

0 comments on commit f67162d

Please sign in to comment.