Showing 8 changed files with 124 additions and 3 deletions.
@@ -0,0 +1,24 @@
# NSFW (Not Safe for Work) classification

This example demonstrates using the Not Safe for Work classification tool.

```python
from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification

# Path to your test image (replace this with your own!)
test_image = "path/to/your/image.jpg"

# Load the image
image = Image.open(test_image)

# Initialize the NSFW model
nsfw_classification = NSFWClassification()

# Run the inference
results = nsfw_classification(image)

# Print the predicted label
print(results.label)
```

::: vision_agent_tools.tools.nsfw_classification
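The returned object also carries a confidence score. As a small illustrative follow-up to the example above (not part of the original docs page), here is a minimal sketch of using it; the `nsfw` label name and the 0.9 threshold are assumptions about the underlying `Falconsai/nsfw_image_detection` checkpoint:

```python
# Inspect the confidence score alongside the label
print(f"{results.label} (confidence: {results.score:.2f})")

# Example policy check; the "nsfw" label and 0.9 threshold are assumptions
if results.label == "nsfw" and results.score > 0.9:
    print("Image flagged as NSFW with high confidence")
```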
@@ -0,0 +1,12 @@
```python
from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification


def test_successful_nsfw_classification():
    test_image = "safework.jpg"
    image = Image.open(f"tests/tools/data/nsfw_classification/{test_image}")

    nsfw_classifier = NSFWClassification()
    results = nsfw_classifier(image)

    assert results.label == "normal"
```
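A possible extension of this test, sketched here as a parametrized pytest case; the second fixture file is hypothetical and left commented out, since only `safework.jpg` is known to exist in the test data:

```python
import pytest
from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification


@pytest.mark.parametrize(
    "file_name, expected_label",
    [
        ("safework.jpg", "normal"),
        # ("unsafe_example.jpg", "nsfw"),  # hypothetical fixture, add if available
    ],
)
def test_nsfw_classification_labels(file_name, expected_label):
    image = Image.open(f"tests/tools/data/nsfw_classification/{file_name}")

    nsfw_classifier = NSFWClassification()
    results = nsfw_classifier(image)

    assert results.label == expected_label
    # The score should be a valid probability
    assert 0.0 <= results.score <= 1.0
```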
@@ -0,0 +1,72 @@
```python
import torch

from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForImageClassification, ViTImageProcessor
from vision_agent_tools.tools.shared_types import BaseTool

CHECKPOINT = "Falconsai/nsfw_image_detection"


class NSFWInferenceData(BaseModel):
    """
    Represents an inference result from the NSFWClassification model.

    Attributes:
        label (str): The predicted label for the image.
        score (float): The confidence score associated with the prediction (between 0 and 1).
    """

    label: str
    score: float


class NSFWClassification(BaseTool):
    """
    The primary intended use of this model is the classification of
    [NSFW (Not Safe for Work)](https://huggingface.co/Falconsai/nsfw_image_detection) images.
    """

    def __init__(self):
        """
        Initializes the NSFW (Not Safe for Work) classification tool.
        """
        self._model = AutoModelForImageClassification.from_pretrained(CHECKPOINT)
        self._processor = ViTImageProcessor.from_pretrained(CHECKPOINT)

        # Prefer CUDA, then Apple Silicon (MPS), and fall back to CPU
        self.device = (
            "cuda"
            if torch.cuda.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )
        self._model.to(self.device)

    def __call__(
        self,
        image: Image.Image,
    ) -> NSFWInferenceData:
        """
        Performs NSFW inference on an image using the NSFWClassification model.

        Args:
            image (Image.Image): The input image to classify.

        Returns:
            NSFWInferenceData: The inference result from the NSFWClassification model.
                label (str): The predicted label for the image.
                score (float): The confidence score for the predicted label.
        """
        with torch.no_grad():
            inputs = self._processor(
                images=image,
                return_tensors="pt",
            ).to(self.device)
            outputs = self._model(**inputs)
            logits = outputs.logits
            scores = logits.softmax(dim=1).tolist()[0]
            predicted_label = logits.argmax(-1).item()
            text = self._model.config.id2label[predicted_label]
            return NSFWInferenceData(label=text, score=scores[predicted_label])
```
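For context, a minimal usage sketch of the class above that classifies every image in a folder; the directory path, file pattern, and 0.9 confidence threshold are assumptions for illustration:

```python
from pathlib import Path

from PIL import Image

from vision_agent_tools.tools.nsfw_classification import NSFWClassification

# Hypothetical input directory; replace with your own
IMAGE_DIR = Path("path/to/images")

# Load the model once and reuse it for every image
classifier = NSFWClassification()

flagged = []
for image_path in sorted(IMAGE_DIR.glob("*.jpg")):
    image = Image.open(image_path).convert("RGB")
    result = classifier(image)
    # Flag anything not labeled "normal" with high confidence (assumed threshold)
    if result.label != "normal" and result.score > 0.9:
        flagged.append((image_path.name, result.score))

for name, score in flagged:
    print(f"{name}: flagged (score={score:.2f})")
```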