add nsfw classification model (#14)
CamiloInx authored Jul 17, 2024
1 parent 1c28e60 commit 273f349
Showing 8 changed files with 124 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -38,3 +38,8 @@ jobs:
          make install-depth-anything-v2
          make test-depth-anything-v2
      - name: nsfw-classification tests
        run: |
          make install-nsfw-classification
          make test-nsfw-classification
8 changes: 8 additions & 0 deletions Makefile
@@ -31,6 +31,10 @@ install-florencev2:
	$(POETRY) install -E florencev2 --no-interaction
	pip install timm flash-attn

install-nsfw-classification:
	# Install nsfw_classification dependencies only
	$(POETRY) install -E nsfw-classification --no-interaction

test:
	# Run all unit tests (experimental due to possible dependency conflicts)
	$(POETRY) run pytest tests
@@ -62,3 +66,7 @@ test-depth-anything-v2:
test-florencev2:
	# Run florencev2 unit tests
	$(POETRY) run pytest tests/tools/test_florencev2.py

test-nsfw-classification:
	# Run nsfw_classification unit tests
	$(POETRY) run pytest tests/tools/test_nsfw_classification.py
24 changes: 24 additions & 0 deletions docs/nsfw_classification.md
@@ -0,0 +1,24 @@
# NSFW (Not Safe for Work) classification

This example demonstrates how to use the NSFW (Not Safe for Work) classification tool.


```python
from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification

# (replace this path with your own!)
test_image = "path/to/your/image.jpg"

# Load the image
image = Image.open(test_image)
# Initialize the NSFW model.
nsfw_classification = NSFWClassification()

# Run the inference
results = nsfw_classification(image)

# Let's print the predicted label
print(results.label)
```

::: vision_agent_tools.tools.nsfw_classification
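
The result also carries a confidence score alongside the label. Below is a minimal sketch of using both fields, assuming the `NSFWClassification` API added in this commit; the 0.9 threshold and the flagging policy are illustrative assumptions, not part of the commit:

```python
from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification

# Hypothetical input path; replace with your own image.
image = Image.open("path/to/your/image.jpg")

nsfw_classification = NSFWClassification()
results = nsfw_classification(image)

# results.label is the predicted class and results.score its confidence (0-1).
# Treating a high-confidence non-"normal" prediction as unsafe is an assumed policy.
if results.label != "normal" and results.score > 0.9:
    print(f"Flagged as potentially unsafe ({results.label}, score={results.score:.2f})")
else:
    print(f"{results.label} (score={results.score:.2f})")
```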
5 changes: 2 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ owlv2 = ["transformers", "scipy"]
florencev2 = ["transformers", "scipy"]
loca-model = ["loca"]
depth-anything-v2-model = ["depth-anything-v2"]
nsfw-classification = ["transformers", "scipy"]


[build-system]
Binary file added tests/tools/data/nsfw_classification/safework.jpg
12 changes: 12 additions & 0 deletions tests/tools/test_nsfw_classification.py
@@ -0,0 +1,12 @@
from PIL import Image
from vision_agent_tools.tools.nsfw_classification import NSFWClassification


def test_successful_nsfw_classification():
test_image = "safework.jpg"
image = Image.open(f"tests/tools/data/nsfw_classification/{test_image}")

nsfw_classifier = NSFWClassification()
results = nsfw_classifier(image)

assert results.label == "normal"
72 changes: 72 additions & 0 deletions vision_agent_tools/tools/nsfw_classification.py
@@ -0,0 +1,72 @@
import torch

from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, ViTImageProcessor
from vision_agent_tools.tools.shared_types import BaseTool

CHECKPOINT = "Falconsai/nsfw_image_detection"


class NSFWInferenceData(BaseModel):
    """
    Represents an inference result from the NSFWClassification model.
    Attributes:
        label (str): The predicted label for the image.
        score (float): The confidence score associated with the prediction (between 0 and 1).
    """

    label: str
    score: float


class NSFWClassification(BaseTool):
    """
    The primary intended use of this model is for the classification of
    [NSFW (Not Safe for Work)](https://huggingface.co/Falconsai/nsfw_image_detection) images.
    """

    def __init__(self):
        """
        Initializes the NSFW (Not Safe for Work) classification tool.
        """
        self._model = AutoModelForImageClassification.from_pretrained(CHECKPOINT)
        self._processor = ViTImageProcessor.from_pretrained(CHECKPOINT)

        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self._model.to(self.device)

    def __call__(
        self,
        image: Image.Image,
    ) -> NSFWInferenceData:
        """
        Performs NSFW classification on an image using the NSFWClassification model.
        Args:
            image (Image.Image): The input image for NSFW classification.
        Returns:
            NSFWInferenceData: The inference result from the NSFWClassification model.
                label (str): The predicted content label for the image.
                score (float): The confidence score associated with the prediction (between 0 and 1).
        """
        with torch.no_grad():
            inputs = self._processor(
                images=image,
                return_tensors="pt",
            ).to(self.device)
            outputs = self._model(**inputs)
            logits = outputs.logits
            scores = logits.softmax(dim=1).tolist()[0]
            predicted_label = logits.argmax(-1).item()
            text = self._model.config.id2label[predicted_label]
        return NSFWInferenceData(label=text, score=scores[predicted_label])
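
Because the checkpoint and processor are loaded once in `__init__`, a single `NSFWClassification` instance can be reused across many images. A hedged sketch of such batch usage, assuming the module above; the `image_dir` path and the `*.jpg` glob are hypothetical:

```python
from pathlib import Path

from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification

# Hypothetical directory of images to screen; replace with your own.
image_dir = Path("path/to/images")

classifier = NSFWClassification()  # loads the Falconsai checkpoint once

for path in sorted(image_dir.glob("*.jpg")):
    # convert("RGB") is a precaution for grayscale/RGBA inputs.
    result = classifier(Image.open(path).convert("RGB"))
    print(f"{path.name}: {result.label} ({result.score:.2f})")
```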
