Skip to content

Commit

Permalink
Update ultralytics models (take two) (#613)
Browse files Browse the repository at this point in the history
Co-authored-by: Helio Machado <[email protected]>
  • Loading branch information
dreadatour and 0x2b3bfa0 authored Nov 22, 2024
1 parent 894c13d commit 224f8a6
Show file tree
Hide file tree
Showing 19 changed files with 830 additions and 183 deletions.
24 changes: 9 additions & 15 deletions examples/computer_vision/openimage-detect.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,11 @@
import json

from PIL import Image
from pydantic import BaseModel

from datachain import C, DataChain, File
from datachain import C, DataChain, File, model
from datachain.sql.functions import path


class BBox(BaseModel):
x_min: int
x_max: int
y_min: int
y_max: int


def openimage_detect(args):
if len(args) != 2:
raise ValueError("Group jpg-json mismatch")
Expand All @@ -30,11 +22,13 @@ def openimage_detect(args):
detections = json.load(stream_json).get("detections", [])

for i, detect in enumerate(detections):
bbox = BBox(
x_min=int(detect["XMin"] * img.width),
x_max=int(detect["XMax"] * img.width),
y_min=int(detect["YMin"] * img.height),
y_max=int(detect["YMax"] * img.height),
bbox = model.BBox.from_list(
[
detect["XMin"] * img.width,
detect["XMax"] * img.width,
detect["YMin"] * img.height,
detect["YMax"] * img.height,
]
)

fstream = File(
Expand All @@ -56,7 +50,7 @@ def openimage_detect(args):
openimage_detect,
partition_by=path.file_stem(C("file.path")),
params=["file"],
output={"file": File, "bbox": BBox},
output={"file": File, "bbox": model.BBox},
)
.show()
)
22 changes: 22 additions & 0 deletions examples/computer_vision/ultralytics-bbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from io import BytesIO

from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain.model.ultralytics import YoloBBoxes


def process_bboxes(yolo: YOLO, file: File) -> YoloBBoxes:
    """Run the given YOLO model on one file and return its boxes.

    The file content is read fully into memory, opened as a Pillow image,
    and passed to ``yolo``. Whatever the model returns is converted with
    ``YoloBBoxes.from_results``.
    """
    raw = file.read()
    image = Image.open(BytesIO(raw))
    predictions = yolo(image)
    return YoloBBoxes.from_results(predictions)


# Build the pipeline step by step: list the demo bucket, keep JPEG paths,
# and cap the run at 20 files.
storage = DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
sample = storage.filter(C("file.path").glob("*.jpg")).limit(20)

# The model is passed as a zero-argument factory rather than an instance;
# presumably DataChain calls it to supply the ``yolo`` argument of
# ``process_bboxes`` (confirm against the DataChain ``setup`` docs).
with_model = sample.setup(yolo=lambda: YOLO("yolo11n.pt"))

# Store each file's detection output under the ``boxes`` signal and print it.
with_model.map(boxes=process_bboxes).show()
22 changes: 22 additions & 0 deletions examples/computer_vision/ultralytics-pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from io import BytesIO

from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain.model.ultralytics import YoloPoses


def process_poses(yolo: YOLO, file: File) -> YoloPoses:
    """Run the given YOLO model on one file and return its poses.

    The file content is read fully into memory, opened as a Pillow image,
    and passed to ``yolo``. Whatever the model returns is converted with
    ``YoloPoses.from_results``.
    """
    raw = file.read()
    image = Image.open(BytesIO(raw))
    predictions = yolo(image)
    return YoloPoses.from_results(predictions)


# Build the pipeline step by step: list the demo bucket, keep JPEG paths,
# and cap the run at 20 files.
storage = DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
sample = storage.filter(C("file.path").glob("*.jpg")).limit(20)

# The model is passed as a zero-argument factory rather than an instance;
# presumably DataChain calls it to supply the ``yolo`` argument of
# ``process_poses`` (confirm against the DataChain ``setup`` docs).
with_model = sample.setup(yolo=lambda: YOLO("yolo11n-pose.pt"))

# Store each file's pose output under the ``poses`` signal and print it.
with_model.map(poses=process_poses).show()
22 changes: 22 additions & 0 deletions examples/computer_vision/ultralytics-segment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from io import BytesIO

from PIL import Image
from ultralytics import YOLO

from datachain import C, DataChain, File
from datachain.model.ultralytics import YoloSegments


def process_segments(yolo: YOLO, file: File) -> YoloSegments:
    """Run the given YOLO model on one file and return its segments.

    The file content is read fully into memory, opened as a Pillow image,
    and passed to ``yolo``. Whatever the model returns is converted with
    ``YoloSegments.from_results``.
    """
    raw = file.read()
    image = Image.open(BytesIO(raw))
    predictions = yolo(image)
    return YoloSegments.from_results(predictions)


# Build the pipeline step by step: list the demo bucket, keep JPEG paths,
# and cap the run at 20 files.
storage = DataChain.from_storage("gs://datachain-demo/openimages-v6-test-jsonpairs/")
sample = storage.filter(C("file.path").glob("*.jpg")).limit(20)

# The model is passed as a zero-argument factory rather than an instance;
# presumably DataChain calls it to supply the ``yolo`` argument of
# ``process_segments`` (confirm against the DataChain ``setup`` docs).
with_model = sample.setup(yolo=lambda: YOLO("yolo11n-seg.pt"))

# Store each file's segmentation output under the ``segments`` signal and
# print it.
with_model.map(segments=process_segments).show()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ examples = [
"unstructured[pdf,embed-huggingface]<0.16.0",
"pdfplumber==0.11.4",
"huggingface_hub[hf_transfer]",
"onnx==1.16.1"
"onnx==1.16.1",
"ultralytics==8.3.29"
]

[project.urls]
Expand Down
3 changes: 1 addition & 2 deletions src/datachain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datachain.lib import func, models
from datachain.lib import func
from datachain.lib.data_model import DataModel, DataType, is_chain_type
from datachain.lib.dc import C, Column, DataChain, Sys
from datachain.lib.file import (
Expand Down Expand Up @@ -38,6 +38,5 @@
"func",
"is_chain_type",
"metrics",
"models",
"param",
]
5 changes: 0 additions & 5 deletions src/datachain/lib/models/__init__.py

This file was deleted.

45 changes: 0 additions & 45 deletions src/datachain/lib/models/bbox.py

This file was deleted.

37 changes: 0 additions & 37 deletions src/datachain/lib/models/pose.py

This file was deleted.

39 changes: 0 additions & 39 deletions src/datachain/lib/models/yolo.py

This file was deleted.

6 changes: 6 additions & 0 deletions src/datachain/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import ultralytics
from .bbox import BBox, OBBox
from .pose import Pose, Pose3D
from .segment import Segment

__all__ = ["BBox", "OBBox", "Pose", "Pose3D", "Segment", "ultralytics"]
102 changes: 102 additions & 0 deletions src/datachain/model/bbox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from pydantic import Field

from datachain.lib.data_model import DataModel


class BBox(DataModel):
    """
    Bounding box stored as four integer coordinates.

    Attributes:
        title (str): Label attached to the box; defaults to an empty string.
        coords (list[int]): Coordinates in ``[x1, y1, x2, y2]`` order (the
            order produced by ``from_dict``); defaults to None.
            The box is described by two points:
            - (x1, y1): The top-left corner of the box.
            - (x2, y2): The bottom-right corner of the box.
    """

    title: str = Field(default="")
    coords: list[int] = Field(default=None)

    @staticmethod
    def from_list(coords: list[float], title: str = "") -> "BBox":
        """
        Build a ``BBox`` from a flat list of four numeric coordinates.

        Each value is passed through the built-in ``round`` before being
        stored, so floats become integers.

        Raises:
            AssertionError: If ``coords`` does not have exactly four items or
                contains a value that is neither ``int`` nor ``float``.
        """
        # NOTE: validation relies on ``assert``; these checks do not run
        # when Python is started with -O.
        assert len(coords) == 4, "Bounding box must be a list of 4 coordinates."
        assert all(isinstance(c, (int, float)) for c in coords), (
            "Bounding box coordinates must be floats or integers."
        )
        return BBox(title=title, coords=list(map(round, coords)))

    @staticmethod
    def from_dict(coords: dict[str, float], title: str = "") -> "BBox":
        """
        Build a ``BBox`` from a dict with exactly the keys x1, y1, x2, y2.

        The values are handed to ``from_list`` in ``x1, y1, x2, y2`` order.

        Raises:
            AssertionError: If ``coords`` is not a dict or its key set is not
                exactly ``{"x1", "y1", "x2", "y2"}``.
        """
        ordered_keys = ("x1", "y1", "x2", "y2")
        assert isinstance(coords, dict) and set(coords) == set(ordered_keys), (
            "Bounding box must be a dictionary with keys 'x1', 'y1', 'x2' and 'y2'."
        )
        return BBox.from_list([coords[k] for k in ordered_keys], title=title)


class OBBox(DataModel):
    """
    Oriented bounding box stored as eight integer coordinates.

    Attributes:
        title (str): Label attached to the box; defaults to an empty string.
        coords (list[int]): Coordinates in
            ``[x1, y1, x2, y2, x3, y3, x4, y4]`` order (the order produced
            by ``from_dict``); defaults to None.
            The box is described by four points:
            - (x1, y1): The first corner of the box.
            - (x2, y2): The second corner of the box.
            - (x3, y3): The third corner of the box.
            - (x4, y4): The fourth corner of the box.
    """

    title: str = Field(default="")
    coords: list[int] = Field(default=None)

    @staticmethod
    def from_list(coords: list[float], title: str = "") -> "OBBox":
        """
        Build an ``OBBox`` from a flat list of eight numeric coordinates.

        Each value is passed through the built-in ``round`` before being
        stored, so floats become integers.

        Raises:
            AssertionError: If ``coords`` does not have exactly eight items
                or contains a value that is neither ``int`` nor ``float``.
        """
        # NOTE: validation relies on ``assert``; these checks do not run
        # when Python is started with -O.
        assert len(coords) == 8, (
            "Oriented bounding box must be a list of 8 coordinates."
        )
        assert all(isinstance(c, (int, float)) for c in coords), (
            "Oriented bounding box coordinates must be floats or integers."
        )
        return OBBox(title=title, coords=list(map(round, coords)))

    @staticmethod
    def from_dict(coords: dict[str, float], title: str = "") -> "OBBox":
        """
        Build an ``OBBox`` from a dict with exactly the keys x1..x4, y1..y4.

        The values are handed to ``from_list`` in
        ``x1, y1, x2, y2, x3, y3, x4, y4`` order.

        Raises:
            AssertionError: If ``coords`` is not a dict or its key set is not
                exactly the eight expected corner keys.
        """
        ordered_keys = ("x1", "y1", "x2", "y2", "x3", "y3", "x4", "y4")
        assert isinstance(coords, dict) and set(coords) == set(ordered_keys), (
            "Oriented bounding box must be a dictionary with coordinates."
        )
        return OBBox.from_list([coords[k] for k in ordered_keys], title=title)
Loading

0 comments on commit 224f8a6

Please sign in to comment.