
Mutox classifier #44

Merged · 11 commits · Dec 4, 2024
43 changes: 43 additions & 0 deletions README.md
@@ -142,6 +142,49 @@ print(blaser_qe(src=src_embs, mt=mt_embs).item())  # 4.708
Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref),
[facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe).

### Classifying the toxicity of sentences with MuTox

[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox) is the first highly multilingual audio-based toxicity dataset and binary classifier. The dataset consists of 20k audio utterances for English and Spanish and 4k for the other 19 languages, and the classifier builds on the multimodal and multilingual SONAR encoders. The classifier outputs a raw score (a logit) for the evaluated sample being _"toxic"_ according to the definition adopted in the corresponding dataset; apply a sigmoid to turn that score into a probability.

```Python
from sonar.models.mutox.loader import load_mutox_model
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
)
classifier = load_mutox_model(
    "sonar_mutox",
    device=device,
    dtype=dtype,
).eval()

with torch.inference_mode():
    emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn')
    x = classifier(emb.to(device).to(dtype))  # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
    emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn')
    x = classifier(emb.to(device).to(dtype))  # a strongly negative score, i.e. non-toxic

with torch.inference_mode():
    emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn')
    x = classifier(emb.to(device).to(dtype))  # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16)
```
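
The scores in the comments above are raw logits, not probabilities: as the example notebook added in this PR notes, the classifier is trained with a binary-cross-entropy-with-logits objective, so a sigmoid turns a score into a toxicity probability. A minimal sketch, reusing the first score above:

```Python
import torch

# MuTox returns a logit; a sigmoid maps it to the probability of the sample
# being toxic. Large negative logits mean "almost certainly non-toxic".
logit = torch.tensor([[-19.7812]])   # first example score above
print(torch.sigmoid(logit).item())   # ≈ 0.0
```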

To run the MuTox pipeline from the command line, see [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox).

### Demo notebooks
See more complete demo notebooks:

246 changes: 246 additions & 0 deletions examples/mutox_example.ipynb
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates\n",
"# All rights reserved.\n",
"#\n",
"# This source code is licensed under the license found in the\n",
"# MIT_LICENSE file in the root directory of this source tree."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MUTOX toxicity classification\n",
"\n",
"Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pathlib import Path\n",
"\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" dtype = torch.float16\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" dtype = torch.float32"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Speech Scoring"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. download some demo audio segments\n",
"2. create a tsv file to feed to the speech scoring pipeline\n",
"3. load the model and build the pipeline\n",
"4. go through the batches in the pipeline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# get demo file\n",
"import urllib.request\n",
"import tempfile\n",
"\n",
"files = [\n",
" (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n",
" (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n",
"]\n",
"\n",
"tmpdir = Path(tempfile.mkdtemp())\n",
"tsv_file = (tmpdir / 'data.tsv')\n",
"with tsv_file.open('w') as tsv_file_p:\n",
" print('path', file=tsv_file_p)\n",
" for (uri, name) in files:\n",
" dl = tmpdir / name\n",
" urllib.request.urlretrieve(uri, dl)\n",
" print(dl, file=tsv_file_p)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
"from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n",
"\n",
"pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
" mutox_classifier_name =\"sonar_mutox\",\n",
" encoder_name=f\"sonar_speech_encoder_eng\",\n",
" device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n",
" data_file=tsv_file,\n",
" audio_root_dir=None,\n",
" audio_path_index=0,\n",
" target_lang=\"eng\",\n",
" batch_size=4,\n",
" pad_idx=0,\n",
" device=device,\n",
" fbank_dtype=torch.float32,\n",
" n_parallel=4\n",
"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n"
]
},
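{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example of the conversion described in the note above: apply a sigmoid to a\n",
"# raw MuTox score to obtain a toxicity probability (score taken from the output\n",
"# of the scoring cell below).\n",
"torch.sigmoid(torch.tensor(-42.40079116821289))  # ~0.0, i.e. scored as non-toxic\n"
]
},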
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n",
"/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n"
]
}
],
"source": [
"for batch in pipeline:\n",
" ex = batch['audio']\n",
" for idx, path in enumerate(ex['path']):\n",
" print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# cleanup tmp dir\n",
"import shutil\n",
"shutil.rmtree(tmpdir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text Scoring\n",
"\n",
"1. load the sonar text encoder\n",
"2. load the mutox classifier model\n",
"3. compute embedding for a sentence\n",
"4. score this embedding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n"
]
}
],
"source": [
"from sonar.models.mutox.loader import load_mutox_model\n",
"from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n",
"\n",
"t2vec_model = TextToEmbeddingModelPipeline(\n",
" encoder=\"text_sonar_basic_encoder\",\n",
" tokenizer=\"text_sonar_basic_encoder\",\n",
" device=device,\n",
")\n",
"text_column='lang_txt'\n",
"classifier = load_mutox_model(\n",
" \"sonar_mutox\",\n",
" device=device,\n",
" dtype=dtype,\n",
").eval()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with torch.inference_mode():\n",
" emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n",
" x = classifier(emb.to(device).half())\n",
"\n",
"x"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "SONAR",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
11 changes: 11 additions & 0 deletions sonar/cards/sonar_mutox.yaml
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

name: sonar_mutox
model_type: mutox_classifier
model_arch: mutox
checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt"
input_size: 1024
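
For reference, this card is what `load_mutox_model` resolves by name: the checkpoint listed above is downloaded and cached on first use, and `input_size` matches the 1024-dimensional SONAR embeddings the classifier consumes. A minimal loading sketch, mirroring the README example above:

```Python
import torch

from sonar.models.mutox.loader import load_mutox_model

# "sonar_mutox" is the card name defined above; device/dtype mirror the README usage.
classifier = load_mutox_model("sonar_mutox", device=torch.device("cpu"), dtype=torch.float32).eval()
```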
61 changes: 61 additions & 0 deletions sonar/inference_pipelines/mutox_speech.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

from typing import Union

import torch
from fairseq2.data import DataPipelineBuilder
from fairseq2.typing import Device

from sonar.inference_pipelines.speech import (
SpeechInferenceParams,
SpeechToEmbeddingPipeline,
)
from sonar.models.encoder_model import SonarEncoderModel
from sonar.models.mutox.classifier import MutoxClassifier
from sonar.models.mutox.loader import load_mutox_model
from sonar.models.sonar_speech.loader import load_sonar_speech_model

CPU_DEVICE = torch.device("cpu")


class MutoxSpeechClassifierPipeline(SpeechToEmbeddingPipeline):
    def __init__(
        self,
        mutox_classifier: Union[str, MutoxClassifier],
        encoder: Union[str, SonarEncoderModel],
        device: Device = CPU_DEVICE,
    ) -> None:
        if isinstance(encoder, str):
            model = self.load_model_from_name("sonar_mutox", encoder, device=device)
        else:
            model = encoder

        super().__init__(model)

        self.model.to(device).eval()
        self.mutox_classifier = mutox_classifier.to(device).eval()

    @classmethod
    def load_model_from_name(
        cls,
        mutox_classifier_name: str,
        encoder_name: str,
        device: Device = CPU_DEVICE,
    ) -> "SpeechToEmbeddingPipeline":
        encoder = load_sonar_speech_model(encoder_name, device=device, progress=False)
        mutox_classifier = load_mutox_model(
            mutox_classifier_name, device=device, progress=False
        )
        return cls(mutox_classifier=mutox_classifier, encoder=encoder, device=device)

    def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder:
        pipeline_builder = super().prebuild_pipeline(context)
        return pipeline_builder.map(self._run_classifier, selector="audio.data")

    @torch.inference_mode()
    def _run_classifier(self, data: dict):
        return self.mutox_classifier(data.sentence_embeddings)
5 changes: 5 additions & 0 deletions sonar/models/mutox/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.