Skip to content

Commit

Permalink
Let load classes for YOLO models
Browse files Browse the repository at this point in the history
Actually, when using models from example from https://github.com/aperveyev/booru_yolo/tree/main/models, it uses all classes when it founds it.
This PR lets the user set a class index or number to use that class instead of every class.
For now it prints all classes found, but only uses the one that the user entered.
  • Loading branch information
Panchovix committed Dec 2, 2024
1 parent 358d170 commit db9bbed
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
7 changes: 6 additions & 1 deletion aaaaaa/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def on_ad_model_update(model: str):
visible=True,
placeholder="Comma separated class names to detect, ex: 'person,cat'. default: COCO 80 classes",
)
elif "yolo" in model.lower():
return gr.update(
visible=True,
placeholder="Comma separated class numbers to detect or separated class names, ex: '0,1' for first 2 classes, or 'head, hip",
)
return gr.update(visible=False, placeholder="")


Expand Down Expand Up @@ -203,7 +208,7 @@ def one_ui_group(n: int, is_img2img: bool, webui_info: WebuiInfo):
w.ad_model_classes = gr.Textbox(
label="ADetailer detector classes" + suffix(n),
value="",
visible=False,
visible=True,
elem_id=eid("ad_model_classes"),
)

Expand Down
48 changes: 42 additions & 6 deletions adetailer/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from adetailer import PredictOutput
from adetailer.common import create_mask_from_bbox

import numpy as np
if TYPE_CHECKING:
import torch
from ultralytics import YOLO, YOLOWorld
Expand All @@ -23,10 +23,30 @@ def ultralytics_predict(
classes: str = "",
) -> PredictOutput[float]:
from ultralytics import YOLO

model = YOLO(model_path)
apply_classes(model, model_path, classes)
class_indices = []
if classes:
parsed = [c.strip() for c in classes.split(",") if c.strip()]
for c in parsed:
if c.isdigit():
class_indices.append(int(c))
elif c in model.names.values():
# Find the index for the class name
for idx, name in model.names.items():
if name == c:
class_indices.append(idx)
break

pred = model(image, conf=confidence, device=device)

if class_indices and len(pred[0].boxes) > 0:
cls = pred[0].boxes.cls.cpu().numpy()
mask = np.isin(cls, class_indices)

# Apply mask to boxes
pred[0].boxes.data = pred[0].boxes.data[mask]
if pred[0].masks is not None:
pred[0].masks.data = pred[0].masks.data[mask]

bboxes = pred[0].boxes.xyxy.cpu().numpy()
if bboxes.size == 0:
Expand All @@ -50,11 +70,27 @@ def ultralytics_predict(


def apply_classes(model: YOLO | YOLOWorld, model_path: str | Path, classes: str):
if not classes or "-world" not in Path(model_path).stem:
if not classes:
return

parsed = [c.strip() for c in classes.split(",") if c.strip()]
if parsed:
model.set_classes(parsed)
if not parsed:
return

try:
class_indices = []
for c in parsed:
if c.isdigit():
class_indices.append(int(c))
elif c in model.names.values():
for idx, name in model.names.items():
if name == c:
class_indices.append(idx)
break

model.classes = class_indices
except Exception as e:
print(f"Error setting classes: {e}")


def mask_to_pil(masks: torch.Tensor, shape: tuple[int, int]) -> list[Image.Image]:
Expand Down

0 comments on commit db9bbed

Please sign in to comment.