From eff7989d7fc877c4ea26ff79497edb2bf4e86fe6 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Fri, 25 Oct 2024 19:24:15 -0600 Subject: [PATCH 01/11] Main code transfer of Mutox classifier from SeamlessM4T --- examples/mutox_example.ipynb | 246 +++++++++++++++++++++++++++++++++++ mutox/__init__.py | 5 + mutox/builder.py | 91 +++++++++++++ mutox/classifier.py | 36 +++++ mutox/cli/README.md | 102 +++++++++++++++ mutox/cli/mutox_speech.py | 140 ++++++++++++++++++++ mutox/cli/mutox_text.py | 98 ++++++++++++++ mutox/loader.py | 46 +++++++ mutox/speech_pipeline.py | 61 +++++++++ 9 files changed, 825 insertions(+) create mode 100644 examples/mutox_example.ipynb create mode 100644 mutox/__init__.py create mode 100644 mutox/builder.py create mode 100644 mutox/classifier.py create mode 100644 mutox/cli/README.md create mode 100644 mutox/cli/mutox_speech.py create mode 100644 mutox/cli/mutox_text.py create mode 100644 mutox/loader.py create mode 100644 mutox/speech_pipeline.py diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb new file mode 100644 index 0000000..a820259 --- /dev/null +++ b/examples/mutox_example.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates\n", + "# All rights reserved.\n", + "#\n", + "# This source code is licensed under the license found in the\n", + "# MIT_LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MUTOX toxicity classification\n", + "\n", + "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from pathlib import Path\n", + "\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda:0\")\n", + " dtype = torch.float16\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + " dtype = torch.float32" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Speech Scoring" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. download some demo audio segments\n", + "2. create a tsv file to feed to the speech scoring pipeline\n", + "3. load the model and build the pipeline\n", + "4. go through the batches in the pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# get demo file\n", + "import urllib.request\n", + "import tempfile\n", + "\n", + "files = [\n", + " (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n", + " (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n", + "]\n", + "\n", + "tmpdir = Path(tempfile.mkdtemp())\n", + "tsv_file = (tmpdir / 'data.tsv')\n", + "with tsv_file.open('w') as tsv_file_p:\n", + " print('path', file=tsv_file_p)\n", + " for (uri, name) in files:\n", + " dl = tmpdir / name\n", + " urllib.request.urlretrieve(uri, dl)\n", + " print(dl, file=tsv_file_p)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from sonar.inference_pipelines.speech import SpeechInferenceParams\n", + "from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n", + "\n", + "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", + " mutox_classifier_name =\"mutox\",\n", + " encoder_name=f\"sonar_speech_encoder_eng\",\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n", + " data_file=tsv_file,\n", + " audio_root_dir=None,\n", + " audio_path_index=0,\n", + " target_lang=\"eng\",\n", + " batch_size=4,\n", + " pad_idx=0,\n", + " device=device,\n", + " fbank_dtype=torch.float32,\n", + " n_parallel=4\n", + "))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n", + "/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n" + ] + } + ], + "source": [ + "for batch in pipeline:\n", + " ex = batch['audio']\n", + " for idx, path in enumerate(ex['path']):\n", + " print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# cleanup tmp dir\n", + "import shutil\n", + "shutil.rmtree(tmpdir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Text Scoring\n", + "\n", + "1. load the sonar text encoder\n", + "2. load the mutox classifier model\n", + "3. compute embedding for a sentence\n", + "4. score this embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n" + ] + } + ], + "source": [ + "from seamless_communication.toxicity.mutox.loader import load_mutox_model\n", + "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", + "\n", + "t2vec_model = TextToEmbeddingModelPipeline(\n", + " encoder=\"text_sonar_basic_encoder\",\n", + " tokenizer=\"text_sonar_basic_encoder\",\n", + " device=device,\n", + ")\n", + "text_column='lang_txt'\n", + "classifier = load_mutox_model(\n", + " \"mutox\",\n", + " device=device,\n", + " dtype=dtype,\n", + ").eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.inference_mode():\n", + " emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n", + " x = classifier(emb.to(device).half())\n", + "\n", + "x" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sc_fr2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/mutox/__init__.py b/mutox/__init__.py new file mode 100644 index 0000000..b7f1604 --- /dev/null +++ b/mutox/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/mutox/builder.py b/mutox/builder.py new file mode 100644 index 0000000..470dba7 --- /dev/null +++ b/mutox/builder.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import typing as tp +from mutox.classifier import ( + MutoxClassifier, + MutoxConfig, +) +import torch +from torch import nn +from fairseq2.typing import DataType, Device + + +class MutoxClassifierBuilder: + """ + Builder module for MutoxClassifier model + """ + + config: MutoxConfig + device: tp.Optional[Device] + dtype: tp.Optional[DataType] + + def __init__( + self, + config: MutoxConfig, + *, + device: tp.Optional[Device] = None, + dtype: tp.Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + self.config = config + self.device, self.dtype = device, dtype + + def build_model(self) -> MutoxClassifier: + model_h1 = nn.Sequential( + nn.Dropout(0.01), + nn.Linear(self.config.input_size, 512), + ) + + model_h2 = nn.Sequential( + nn.ReLU(), + nn.Linear(512, 128), + ) + + model_h3 = nn.Sequential( + nn.ReLU(), + nn.Linear(128, 1), + ) + + model_all = nn.Sequential( + model_h1, + model_h2, + model_h3, + ) + + return MutoxClassifier(model_all,).to( + device=self.device, + dtype=self.dtype, + ) + + +def create_mutox_model( + config: MutoxConfig, + device: tp.Optional[Device] = None, + dtype: tp.Optional[DataType] = None, +) -> MutoxClassifier: + """Create a Mutox Classifier model. + + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + + return MutoxClassifierBuilder( + config, + device=device, + dtype=dtype, + ).build_model() diff --git a/mutox/classifier.py b/mutox/classifier.py new file mode 100644 index 0000000..25e6e6b --- /dev/null +++ b/mutox/classifier.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +import torch +from torch import nn +from fairseq2.typing import DataType, Device + +from fairseq2.models.utils.arch_registry import ArchitectureRegistry +from typing import Optional + + +class MutoxClassifier(nn.Module): + def __init__( + self, + model_all, + ): + super().__init__() + self.model_all = model_all + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.model_all(inputs) + + +@dataclass +class MutoxConfig: + """Holds the configuration of a Mutox Classifier model.""" + + # size of the input embedding supported by this model + input_size: int + + +mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier") diff --git a/mutox/cli/README.md b/mutox/cli/README.md new file mode 100644 index 0000000..08dd7e6 --- /dev/null +++ b/mutox/cli/README.md @@ -0,0 +1,102 @@ +# MuTox: MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector + +MuTox, the first highly multilingual audio-based dataset with toxicity labels. +The dataset consists of 20k audio utterances for English and Spanish, and 4k for +the other 19 languages. To showcase the quality of this dataset, we train the +MuTox audio-based toxicity classifier, which allows zero-shot toxicity detection +across a broad range of languages. This classifier outperforms existing +text-based trainable classifiers by more than 1% AUC, while increasing the +language coverage from 8 to 100+ languages. When compared to a wordlist-based +classifier that covers a similar number of languages, MuTox improves precision +and recall by ∼2.5 times. + +## License + +The mutox code and model are licensed under the MIT license (see MIT_LICENSE +file at the root of seamless_communication). The mutox model depends on SONAR +encoders, most are under the MIT license but a few are under CC-BY-NC license. +See the [SONAR repository](https://github.com/facebookresearch/SONAR) for +details. + +## Dataset Languages. + +- English, +- Spanish, +- Arabic, +- Bengali, +- Mandarin Chinese, +- Dutch, +- French, +- German, +- Hindi, +- Indonesian, +- Italian, +- Japanese, +- Korean, +- Portuguese, +- Russian, +- Swahili, +- Tagalog, +- Thai, +- Turkish, +- Urdu, +- Vietnamese + +## Classifier details. + +We use multi-modal and multilingual +[SONAR](https://github.com/facebookresearch/SONAR) encoders from (Duquenne et +al., 2023). For the classifier, we use variable input sizes for the 3 +feedforward layers (1024, 512, and 128). + +The predictions of the classifier can be interpreted as logits (i.e. after feeding them to a sigmoid transform they become probabilities). +The 0 value can be used as a threshold, as it corresponds to the 50% predicted toxicity probability. + +## Classifier Quick Start + +This introduces the MuTox speech toxicity model, this relies on computing the +sonar embedding and then classifying it through the MuTox model. The +`cli/mutox/mutox.py` provides an example of reading a TSV, computing the SONAR +embedding and running the classifier on the results: + +```bash +python -m seamless_communication.cli.toxicity.mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv +``` + +You can also work with text: + +```bash +python -m seamless_communication.cli.toxicity.mutox.mutox_text --lang fra_Latn sentences.txt +``` + +You can also check the mutox example notebook in this directory. + +## Dataset + +The dataset is available in this [tsv file](https://dl.fbaipublicfiles.com/seamless/datasets/mutox.tsv). The dataset is licensed under the MIT license (see MIT_LICENSE +file at the root of seamless_communication). + +The columns of the dataset are: +- `id`: a string id of the segment; +- `lang`: 3-letter language code; +- `partition`: one of `train`, `dev`, or `devtest`; +- `public_url_segment`: a string formatted as `url:start:end`, where start and end are indicated in milliseconds; +- `audio_file_transcript`: text transctiption of the segment; +- `contains_toxicity`, `toxicity_types`, `perlocutionary_effects`: annotation results as strings (see the paper for their explanation); +- `label`: an integer label, equal to 1 if `contains_toxicity` equals `Yes` and 0 otherwise; +- `etox_result`: toxic word (or multiple words, separated by `|`) detected by the Etox matcher; +- `detoxify_score`: toxicity probabilities predicted by the Detoxify system (float numbers between 0 and 1); +- `mutox_speech_score`, `mutox_text_score`, `mutox_zero_shot_speech_score`, `mutox_zero_shot_text_score`: MuTox predictions as float numbers with any value (they can be interpreted as logits, i.e. probabilities before a sigmoid transformation). + +## Citation + +```bitex +@misc{costajussà2023mutox, + title={MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector}, + author={ Marta R. Costa-jussà, Mariano Coria Meglioli, Pierre Andrews, David Dale, Prangthip Hansanti, Elahe Kalbassi, Alex Mourachko, Christophe Ropers, Carleigh Wood}, + year={2023}, + eprint={}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/mutox/cli/mutox_speech.py b/mutox/cli/mutox_speech.py new file mode 100644 index 0000000..945e533 --- /dev/null +++ b/mutox/cli/mutox_speech.py @@ -0,0 +1,140 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import argparse + +import torch +from tqdm import tqdm +from pathlib import Path + +from sonar.inference_pipelines.speech import ( + SpeechInferenceParams, +) +from mutox.speech_pipeline import ( + MutoxSpeechClassifierPipeline, +) + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", +) + +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Mutox speech will compute a toxicity score for each speech segment it is provided." + ) + parser.add_argument( + "data_file", + type=Path, + help="Path to the input TSV manifest that list the audio files.", + ) + parser.add_argument( + "output_file", + type=Path, + help="Path to a TSV file where to save the results.", + ) + parser.add_argument( + "--lang", + type=str, + help="Language, language of the speech being passed as input, three letter code", + required=True, + ) + parser.add_argument( + "--audio_root_dir", + type=str, + help="Root directory for the audio filenames in the data file.", + ) + parser.add_argument( + "--audio_path_index", + type=int, + help="Index of the column where the audiofile is listed in the input tsv.", + default="audio", + ) + parser.add_argument( + "--batch_size", + type=int, + help="Inference batch size.", + default=4, + ) + parser.add_argument( + "--n_parallel", + type=int, + help="Number of data loading in parallel.", + default=4, + ) + parser.add_argument( + "--device", + type=str, + help="name of the device to use with torch.", + required=False, + ) + args, _unknown = parser.parse_known_args() + + if args.device is not None: + device = torch.device(args.device) + dtype = torch.float32 + if device.type == "cuda": + dtype = torch.float16 + elif torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + logger.info("using cuda:0, %s", dtype) + else: + device = torch.device("cpu") + dtype = torch.float32 + logger.info("no gpu, using cpu") + + logger.info("loading models.") + + pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name( + mutox_classifier_name="mutox", + encoder_name=f"sonar_speech_encoder_{args.lang}", + device=device, + ) + + pipeline = pipeline_builder.build_pipeline( + SpeechInferenceParams( + data_file=args.data_file, + audio_root_dir=args.audio_root_dir, + audio_path_index=args.audio_path_index, + target_lang=args.lang, + batch_size=args.batch_size, + pad_idx=0, + device=device, + fbank_dtype=torch.float32, + n_parallel=args.n_parallel, + ) + ) + + logger.info("processing.") + + with open(args.output_file, "w", encoding="utf-8") as outf: + print( + "input_audio_path", + "score", + sep="\t", + file=outf, + ) + for example in tqdm(pipeline): + ex = example["audio"] + for idx, path in enumerate(ex["path"]): + print( + str(path), + ex["data"][idx].item(), + sep="\t", + file=outf, + ) + + logger.info(f"Done, outputs are in {args.output_file}.") + + +if __name__ == "__main__": + main() diff --git a/mutox/cli/mutox_text.py b/mutox/cli/mutox_text.py new file mode 100644 index 0000000..209ea83 --- /dev/null +++ b/mutox/cli/mutox_text.py @@ -0,0 +1,98 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import argparse +import sys + +import torch +from mutox.loader import load_mutox_model +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline + +import logging + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", +) + +CPU_DEVICE = torch.device("cpu") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Mutox Text will compute a toxicity score for each sentence it is passed." + ) + + parser.add_argument( + "lang", + type=str, + help="Language of the input text, nllb format with script.", + ) + parser.add_argument( + "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin + ) + parser.add_argument( + "output", nargs="?", type=argparse.FileType("w"), default=sys.stdout + ) + parser.add_argument( + "--batch_size", + type=int, + help="Inference batch size.", + default=4, + ) + parser.add_argument( + "--device", + type=str, + help="name of the device to use with torch.", + required=False, + ) + args, _unknown = parser.parse_known_args() + + if args.device is not None: + device = torch.device(args.device) + dtype = torch.float32 + if device.type == "cuda": + dtype = torch.float16 + elif torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model( + "mutox", + device=device, + dtype=dtype, + ).eval() + + def write_result(batch): + emb = t2vec_model.predict(batch, source_lang=args.lang) + scores = classifier(emb.half()) + for s, t in zip(scores, batch): + print(t, s.item(), sep="\t", file=args.output) + + with torch.inference_mode(): + print("text", "score", sep="\t", file=args.output) + batch = [] + for line in args.input: + batch.append(line.rstrip()) + if len(batch) >= args.batch_size: + write_result(batch) + batch = [] + + if len(batch): + write_result(batch) + + +if __name__ == "__main__": + main() diff --git a/mutox/loader.py b/mutox/loader.py new file mode 100644 index 0000000..d7c59a8 --- /dev/null +++ b/mutox/loader.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + + +from fairseq2.assets import asset_store, download_manager +from fairseq2.models.utils import ConfigLoader, ModelLoader +from mutox.builder import create_mutox_model +from mutox.classifier import ( + MutoxClassifier, + MutoxConfig, + mutox_archs, +) + +import typing as tp + + +@mutox_archs.decorator("mutox") +def _base_mutox() -> MutoxConfig: + return MutoxConfig( + input_size=1024, + ) + + +def convert_mutox_checkpoint( + checkpoint: tp.Mapping[str, tp.Any], config: MutoxConfig +) -> tp.Mapping[str, tp.Any]: + new_dict = {} + for key in checkpoint: + if key.startswith("model_all."): + new_dict[key] = checkpoint[key] + return {"model": new_dict} + + +load_mutox_config = ConfigLoader[MutoxConfig](asset_store, mutox_archs) + + +load_mutox_model = ModelLoader[MutoxClassifier, MutoxConfig]( + asset_store, + download_manager, + load_mutox_config, + create_mutox_model, + convert_mutox_checkpoint, +) diff --git a/mutox/speech_pipeline.py b/mutox/speech_pipeline.py new file mode 100644 index 0000000..22b634f --- /dev/null +++ b/mutox/speech_pipeline.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import torch +from mutox.classifier import MutoxClassifier +from mutox.loader import load_mutox_model +from sonar.models.sonar_speech.loader import load_sonar_speech_model + +from sonar.inference_pipelines.speech import ( + SpeechToEmbeddingPipeline, + SpeechInferenceParams, +) + +from fairseq2.data import ( + DataPipelineBuilder, +) + +from typing import Union + +from mutox.mutox.classifier import MutoxClassifier +from sonar.models.encoder_model import SonarEncoderModel +from fairseq2.typing import Device + + +CPU_DEVICE = torch.device("cpu") + + +class MutoxSpeechClassifierPipeline(SpeechToEmbeddingPipeline): + def __init__( + self, + mutox_classifier: Union[str, MutoxClassifier], + encoder: Union[str, SonarEncoderModel], + device: Device = CPU_DEVICE, + ) -> None: + super().__init__(encoder) + self.model.to(device).eval() + self.mutox_classifier = mutox_classifier.to(device).eval() + + @classmethod + def load_model_from_name( + cls, + mutox_classifier_name: str, + encoder_name: str, + device: Device = CPU_DEVICE, + ) -> "SpeechToEmbeddingPipeline": + encoder = load_sonar_speech_model(encoder_name, device=device, progress=False) + mutox_classifier = load_mutox_model( + mutox_classifier_name, device=device, progress=False + ) + return cls(mutox_classifier=mutox_classifier, encoder=encoder, device=device) + + def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder: + pipeline_builder = super().prebuild_pipeline(context) + return pipeline_builder.map(self._run_classifier, selector="audio.data") + + @torch.inference_mode() + def _run_classifier(self, data: dict): + return self.mutox_classifier(data.sentence_embeddings) From 822673ae18c3d70bc20634772242afe839dacd89 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Fri, 25 Oct 2024 20:38:00 -0600 Subject: [PATCH 02/11] Minor changes, import inside mutox broken --- examples/mutox_example.ipynb | 11 ++--------- mutox/builder.py | 9 +++++---- mutox/classifier.py | 3 ++- mutox/cli/README.md | 6 +++--- mutox/loader.py | 8 ++++---- mutox/speech_pipeline.py | 21 +++++++++------------ pyproject.toml | 2 +- 7 files changed, 26 insertions(+), 34 deletions(-) diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index a820259..3b98dff 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -88,7 +88,7 @@ "outputs": [], "source": [ "from sonar.inference_pipelines.speech import SpeechInferenceParams\n", - "from seamless_communication.toxicity.mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n", + "from mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n", "\n", "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", " mutox_classifier_name =\"mutox\",\n", @@ -174,7 +174,7 @@ } ], "source": [ - "from seamless_communication.toxicity.mutox.loader import load_mutox_model\n", + "from mutox.loader import load_mutox_model\n", "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", "\n", "t2vec_model = TextToEmbeddingModelPipeline(\n", @@ -213,13 +213,6 @@ "\n", "x" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/mutox/builder.py b/mutox/builder.py index 470dba7..32f6cff 100644 --- a/mutox/builder.py +++ b/mutox/builder.py @@ -5,14 +5,15 @@ # MIT_LICENSE file in the root directory of this source tree. import typing as tp -from mutox.classifier import ( - MutoxClassifier, - MutoxConfig, -) import torch from torch import nn from fairseq2.typing import DataType, Device +from .classifier import ( + MutoxClassifier, + MutoxConfig, +) + class MutoxClassifierBuilder: """ diff --git a/mutox/classifier.py b/mutox/classifier.py index 25e6e6b..5c67cca 100644 --- a/mutox/classifier.py +++ b/mutox/classifier.py @@ -7,9 +7,10 @@ from dataclasses import dataclass import torch from torch import nn -from fairseq2.typing import DataType, Device +from fairseq2.typing import DataType, Device from fairseq2.models.utils.arch_registry import ArchitectureRegistry + from typing import Optional diff --git a/mutox/cli/README.md b/mutox/cli/README.md index 08dd7e6..5900131 100644 --- a/mutox/cli/README.md +++ b/mutox/cli/README.md @@ -15,7 +15,7 @@ and recall by ∼2.5 times. The mutox code and model are licensed under the MIT license (see MIT_LICENSE file at the root of seamless_communication). The mutox model depends on SONAR encoders, most are under the MIT license but a few are under CC-BY-NC license. -See the [SONAR repository](https://github.com/facebookresearch/SONAR) for +See [SONAR](../../sonar/) for details. ## Dataset Languages. @@ -60,13 +60,13 @@ sonar embedding and then classifying it through the MuTox model. The embedding and running the classifier on the results: ```bash -python -m seamless_communication.cli.toxicity.mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv +python -m mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv ``` You can also work with text: ```bash -python -m seamless_communication.cli.toxicity.mutox.mutox_text --lang fra_Latn sentences.txt +python -m mutox.mutox_text --lang fra_Latn sentences.txt ``` You can also check the mutox example notebook in this directory. diff --git a/mutox/loader.py b/mutox/loader.py index d7c59a8..013930b 100644 --- a/mutox/loader.py +++ b/mutox/loader.py @@ -4,18 +4,18 @@ # This source code is licensed under the license found in the # MIT_LICENSE file in the root directory of this source tree. +import typing as tp from fairseq2.assets import asset_store, download_manager from fairseq2.models.utils import ConfigLoader, ModelLoader -from mutox.builder import create_mutox_model -from mutox.classifier import ( + +from .builder import create_mutox_model +from .classifier import ( MutoxClassifier, MutoxConfig, mutox_archs, ) -import typing as tp - @mutox_archs.decorator("mutox") def _base_mutox() -> MutoxConfig: diff --git a/mutox/speech_pipeline.py b/mutox/speech_pipeline.py index 22b634f..ad86ee4 100644 --- a/mutox/speech_pipeline.py +++ b/mutox/speech_pipeline.py @@ -4,26 +4,23 @@ # This source code is licensed under the license found in the # MIT_LICENSE file in the root directory of this source tree. +from typing import Union import torch -from mutox.classifier import MutoxClassifier -from mutox.loader import load_mutox_model -from sonar.models.sonar_speech.loader import load_sonar_speech_model - -from sonar.inference_pipelines.speech import ( - SpeechToEmbeddingPipeline, - SpeechInferenceParams, -) +from fairseq2.typing import Device from fairseq2.data import ( DataPipelineBuilder, ) -from typing import Union - -from mutox.mutox.classifier import MutoxClassifier +from sonar.models.sonar_speech.loader import load_sonar_speech_model from sonar.models.encoder_model import SonarEncoderModel -from fairseq2.typing import Device +from sonar.inference_pipelines.speech import ( + SpeechToEmbeddingPipeline, + SpeechInferenceParams, +) +from .classifier import MutoxClassifier +from .loader import load_mutox_model CPU_DEVICE = torch.device("cpu") diff --git a/pyproject.toml b/pyproject.toml index 2078947..cd50bcf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dev = [ ] -hg = [ +hf = [ "transformers>=4.44.0", "datasets>=2.20.0", "evaluate>=0.4.2", From 35c91757f444c83ee92b5c0aa3488586c32cac63 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Mon, 28 Oct 2024 19:02:35 -0600 Subject: [PATCH 03/11] Corrections from PR #44 comments --- README.md | 27 ++++ examples/mutox_example.ipynb | 8 +- mutox/cli/README.md | 102 ------------- mutox/cli/mutox_speech.py | 140 ------------------ mutox/cli/mutox_text.py | 98 ------------ pyproject.toml | 2 +- .../inference_pipelines/mutox_speech.py | 4 +- {mutox => sonar/models/mutox}/__init__.py | 0 {mutox => sonar/models/mutox}/builder.py | 0 {mutox => sonar/models/mutox}/classifier.py | 0 {mutox => sonar/models/mutox}/loader.py | 0 tests/unit_tests/test_mutox.py | 0 12 files changed, 34 insertions(+), 347 deletions(-) delete mode 100644 mutox/cli/README.md delete mode 100644 mutox/cli/mutox_speech.py delete mode 100644 mutox/cli/mutox_text.py rename mutox/speech_pipeline.py => sonar/inference_pipelines/mutox_speech.py (94%) rename {mutox => sonar/models/mutox}/__init__.py (100%) rename {mutox => sonar/models/mutox}/builder.py (100%) rename {mutox => sonar/models/mutox}/classifier.py (100%) rename {mutox => sonar/models/mutox}/loader.py (100%) create mode 100644 tests/unit_tests/test_mutox.py diff --git a/README.md b/README.md index 17e0fbe..c063155 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,33 @@ print(blaser_qe(src=src_embs, mt=mt_embs).item()) # 4.708 Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref), [facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe). +### Classifying the toxicity of sentences with MuTox + +[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. + +```Python +from sonar.models.mutox.loader import load_mutox_model +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline + +t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, +) +text_column='lang_txt' +classifier = load_mutox_model( + "mutox", + device=device, + dtype=dtype, +).eval() + +with torch.inference_mode(): + emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn') + x = classifier(emb.to(device).half()) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) +``` + +For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). + ### Demo notebooks See more complete demo notebooks : diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index 3b98dff..291f127 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -88,7 +88,7 @@ "outputs": [], "source": [ "from sonar.inference_pipelines.speech import SpeechInferenceParams\n", - "from mutox.speech_pipeline import MutoxSpeechClassifierPipeline\n", + "from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n", "\n", "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", " mutox_classifier_name =\"mutox\",\n", @@ -174,7 +174,7 @@ } ], "source": [ - "from mutox.loader import load_mutox_model\n", + "from sonar.models.mutox.loader import load_mutox_model\n", "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", "\n", "t2vec_model = TextToEmbeddingModelPipeline(\n", @@ -217,7 +217,7 @@ ], "metadata": { "kernelspec": { - "display_name": "sc_fr2", + "display_name": "SONAR", "language": "python", "name": "python3" }, @@ -231,7 +231,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/mutox/cli/README.md b/mutox/cli/README.md deleted file mode 100644 index 5900131..0000000 --- a/mutox/cli/README.md +++ /dev/null @@ -1,102 +0,0 @@ -# MuTox: MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector - -MuTox, the first highly multilingual audio-based dataset with toxicity labels. -The dataset consists of 20k audio utterances for English and Spanish, and 4k for -the other 19 languages. To showcase the quality of this dataset, we train the -MuTox audio-based toxicity classifier, which allows zero-shot toxicity detection -across a broad range of languages. This classifier outperforms existing -text-based trainable classifiers by more than 1% AUC, while increasing the -language coverage from 8 to 100+ languages. When compared to a wordlist-based -classifier that covers a similar number of languages, MuTox improves precision -and recall by ∼2.5 times. - -## License - -The mutox code and model are licensed under the MIT license (see MIT_LICENSE -file at the root of seamless_communication). The mutox model depends on SONAR -encoders, most are under the MIT license but a few are under CC-BY-NC license. -See [SONAR](../../sonar/) for -details. - -## Dataset Languages. - -- English, -- Spanish, -- Arabic, -- Bengali, -- Mandarin Chinese, -- Dutch, -- French, -- German, -- Hindi, -- Indonesian, -- Italian, -- Japanese, -- Korean, -- Portuguese, -- Russian, -- Swahili, -- Tagalog, -- Thai, -- Turkish, -- Urdu, -- Vietnamese - -## Classifier details. - -We use multi-modal and multilingual -[SONAR](https://github.com/facebookresearch/SONAR) encoders from (Duquenne et -al., 2023). For the classifier, we use variable input sizes for the 3 -feedforward layers (1024, 512, and 128). - -The predictions of the classifier can be interpreted as logits (i.e. after feeding them to a sigmoid transform they become probabilities). -The 0 value can be used as a threshold, as it corresponds to the 50% predicted toxicity probability. - -## Classifier Quick Start - -This introduces the MuTox speech toxicity model, this relies on computing the -sonar embedding and then classifying it through the MuTox model. The -`cli/mutox/mutox.py` provides an example of reading a TSV, computing the SONAR -embedding and running the classifier on the results: - -```bash -python -m mutox.mutox_speech --lang fra --audio_column ref_tgt_audio /checkpoint/bokai/seamless/toxity_mitigation/exps_v5/joined_etox/fleurs/s2t/en-xx/fra.tsv /tmp/tesmortt.tsv -``` - -You can also work with text: - -```bash -python -m mutox.mutox_text --lang fra_Latn sentences.txt -``` - -You can also check the mutox example notebook in this directory. - -## Dataset - -The dataset is available in this [tsv file](https://dl.fbaipublicfiles.com/seamless/datasets/mutox.tsv). The dataset is licensed under the MIT license (see MIT_LICENSE -file at the root of seamless_communication). - -The columns of the dataset are: -- `id`: a string id of the segment; -- `lang`: 3-letter language code; -- `partition`: one of `train`, `dev`, or `devtest`; -- `public_url_segment`: a string formatted as `url:start:end`, where start and end are indicated in milliseconds; -- `audio_file_transcript`: text transctiption of the segment; -- `contains_toxicity`, `toxicity_types`, `perlocutionary_effects`: annotation results as strings (see the paper for their explanation); -- `label`: an integer label, equal to 1 if `contains_toxicity` equals `Yes` and 0 otherwise; -- `etox_result`: toxic word (or multiple words, separated by `|`) detected by the Etox matcher; -- `detoxify_score`: toxicity probabilities predicted by the Detoxify system (float numbers between 0 and 1); -- `mutox_speech_score`, `mutox_text_score`, `mutox_zero_shot_speech_score`, `mutox_zero_shot_text_score`: MuTox predictions as float numbers with any value (they can be interpreted as logits, i.e. probabilities before a sigmoid transformation). - -## Citation - -```bitex -@misc{costajussà2023mutox, - title={MuTox: Universal MUltilingual Audio-based TOXicity Dataset and Zero-shot Detector}, - author={ Marta R. Costa-jussà, Mariano Coria Meglioli, Pierre Andrews, David Dale, Prangthip Hansanti, Elahe Kalbassi, Alex Mourachko, Christophe Ropers, Carleigh Wood}, - year={2023}, - eprint={}, - archivePrefix={arXiv}, - primaryClass={cs.CL} -} -``` diff --git a/mutox/cli/mutox_speech.py b/mutox/cli/mutox_speech.py deleted file mode 100644 index 945e533..0000000 --- a/mutox/cli/mutox_speech.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the license found in the -# MIT_LICENSE file in the root directory of this source tree. - -import argparse - -import torch -from tqdm import tqdm -from pathlib import Path - -from sonar.inference_pipelines.speech import ( - SpeechInferenceParams, -) -from mutox.speech_pipeline import ( - MutoxSpeechClassifierPipeline, -) - -import logging - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", -) - -logger = logging.getLogger(__name__) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Mutox speech will compute a toxicity score for each speech segment it is provided." - ) - parser.add_argument( - "data_file", - type=Path, - help="Path to the input TSV manifest that list the audio files.", - ) - parser.add_argument( - "output_file", - type=Path, - help="Path to a TSV file where to save the results.", - ) - parser.add_argument( - "--lang", - type=str, - help="Language, language of the speech being passed as input, three letter code", - required=True, - ) - parser.add_argument( - "--audio_root_dir", - type=str, - help="Root directory for the audio filenames in the data file.", - ) - parser.add_argument( - "--audio_path_index", - type=int, - help="Index of the column where the audiofile is listed in the input tsv.", - default="audio", - ) - parser.add_argument( - "--batch_size", - type=int, - help="Inference batch size.", - default=4, - ) - parser.add_argument( - "--n_parallel", - type=int, - help="Number of data loading in parallel.", - default=4, - ) - parser.add_argument( - "--device", - type=str, - help="name of the device to use with torch.", - required=False, - ) - args, _unknown = parser.parse_known_args() - - if args.device is not None: - device = torch.device(args.device) - dtype = torch.float32 - if device.type == "cuda": - dtype = torch.float16 - elif torch.cuda.is_available(): - device = torch.device("cuda:0") - dtype = torch.float16 - logger.info("using cuda:0, %s", dtype) - else: - device = torch.device("cpu") - dtype = torch.float32 - logger.info("no gpu, using cpu") - - logger.info("loading models.") - - pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name( - mutox_classifier_name="mutox", - encoder_name=f"sonar_speech_encoder_{args.lang}", - device=device, - ) - - pipeline = pipeline_builder.build_pipeline( - SpeechInferenceParams( - data_file=args.data_file, - audio_root_dir=args.audio_root_dir, - audio_path_index=args.audio_path_index, - target_lang=args.lang, - batch_size=args.batch_size, - pad_idx=0, - device=device, - fbank_dtype=torch.float32, - n_parallel=args.n_parallel, - ) - ) - - logger.info("processing.") - - with open(args.output_file, "w", encoding="utf-8") as outf: - print( - "input_audio_path", - "score", - sep="\t", - file=outf, - ) - for example in tqdm(pipeline): - ex = example["audio"] - for idx, path in enumerate(ex["path"]): - print( - str(path), - ex["data"][idx].item(), - sep="\t", - file=outf, - ) - - logger.info(f"Done, outputs are in {args.output_file}.") - - -if __name__ == "__main__": - main() diff --git a/mutox/cli/mutox_text.py b/mutox/cli/mutox_text.py deleted file mode 100644 index 209ea83..0000000 --- a/mutox/cli/mutox_text.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the license found in the -# MIT_LICENSE file in the root directory of this source tree. - -import argparse -import sys - -import torch -from mutox.loader import load_mutox_model -from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline - -import logging - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s -- %(name)s: %(message)s", -) - -CPU_DEVICE = torch.device("cpu") - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Mutox Text will compute a toxicity score for each sentence it is passed." - ) - - parser.add_argument( - "lang", - type=str, - help="Language of the input text, nllb format with script.", - ) - parser.add_argument( - "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin - ) - parser.add_argument( - "output", nargs="?", type=argparse.FileType("w"), default=sys.stdout - ) - parser.add_argument( - "--batch_size", - type=int, - help="Inference batch size.", - default=4, - ) - parser.add_argument( - "--device", - type=str, - help="name of the device to use with torch.", - required=False, - ) - args, _unknown = parser.parse_known_args() - - if args.device is not None: - device = torch.device(args.device) - dtype = torch.float32 - if device.type == "cuda": - dtype = torch.float16 - elif torch.cuda.is_available(): - device = torch.device("cuda:0") - dtype = torch.float16 - else: - device = torch.device("cpu") - dtype = torch.float32 - - t2vec_model = TextToEmbeddingModelPipeline( - encoder="text_sonar_basic_encoder", - tokenizer="text_sonar_basic_encoder", - device=device, - ) - - classifier = load_mutox_model( - "mutox", - device=device, - dtype=dtype, - ).eval() - - def write_result(batch): - emb = t2vec_model.predict(batch, source_lang=args.lang) - scores = classifier(emb.half()) - for s, t in zip(scores, batch): - print(t, s.item(), sep="\t", file=args.output) - - with torch.inference_mode(): - print("text", "score", sep="\t", file=args.output) - batch = [] - for line in args.input: - batch.append(line.rstrip()) - if len(batch) >= args.batch_size: - write_result(batch) - batch = [] - - if len(batch): - write_result(batch) - - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index cd50bcf..2078947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dev = [ ] -hf = [ +hg = [ "transformers>=4.44.0", "datasets>=2.20.0", "evaluate>=0.4.2", diff --git a/mutox/speech_pipeline.py b/sonar/inference_pipelines/mutox_speech.py similarity index 94% rename from mutox/speech_pipeline.py rename to sonar/inference_pipelines/mutox_speech.py index ad86ee4..369e56f 100644 --- a/mutox/speech_pipeline.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -19,8 +19,8 @@ SpeechInferenceParams, ) -from .classifier import MutoxClassifier -from .loader import load_mutox_model +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import load_mutox_model CPU_DEVICE = torch.device("cpu") diff --git a/mutox/__init__.py b/sonar/models/mutox/__init__.py similarity index 100% rename from mutox/__init__.py rename to sonar/models/mutox/__init__.py diff --git a/mutox/builder.py b/sonar/models/mutox/builder.py similarity index 100% rename from mutox/builder.py rename to sonar/models/mutox/builder.py diff --git a/mutox/classifier.py b/sonar/models/mutox/classifier.py similarity index 100% rename from mutox/classifier.py rename to sonar/models/mutox/classifier.py diff --git a/mutox/loader.py b/sonar/models/mutox/loader.py similarity index 100% rename from mutox/loader.py rename to sonar/models/mutox/loader.py diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py new file mode 100644 index 0000000..e69de29 From 60f181621dd0c2a4c2012c287f1a483d2b51eee8 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Mon, 28 Oct 2024 19:55:11 -0600 Subject: [PATCH 04/11] Added unit tests for mutox builder, classifier --- tests/unit_tests/test_mutox.py | 92 ++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index e69de29..7c2b53a 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torch import nn +from unittest.mock import Mock + +from sonar.models.mutox.builder import MutoxConfig, MutoxClassifierBuilder, create_mutox_model +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import ( + convert_mutox_checkpoint, +) + +# Builder tests + +@pytest.mark.parametrize("input_size", [256, 512, 1024]) +@pytest.mark.parametrize("device", [torch.device("cpu")]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_mutox_classifier_builder(input_size, device, dtype): + """Test MutoxClassifierBuilder initializes a model with correct configuration and dtype.""" + config = MutoxConfig(input_size=input_size) + builder = MutoxClassifierBuilder(config, device=device, dtype=dtype) + model = builder.build_model() + + # Check if model layers are correctly initialized with shapes + assert isinstance(model, nn.Module), "Model should be an instance of nn.Module" + assert all( + isinstance(layer, nn.Sequential) for layer in model.model_all.children() + ), "All layers should be instances of nn.Sequential" + + test_input = torch.zeros((5, input_size), device=device, dtype=dtype) + result = model(test_input) + assert result.shape == (5, 1), f"Expected output shape (5, 1), got {result.shape}" + + +@pytest.mark.parametrize("input_size", [256, 512]) +def test_create_mutox_model(input_size): + """Test create_mutox_model function to confirm it creates a model with the specified config.""" + config = MutoxConfig(input_size=input_size) + model = create_mutox_model(config, device=torch.device("cpu")) + + # Check if the created model has the expected structure and behavior + test_input = torch.zeros((3, input_size)) + result = model(test_input) + assert result.shape == (3, 1), f"Expected output shape (3, 1), got {result.shape}" + assert isinstance(model, nn.Module), "Model should be an instance of nn.Module" + + +# Classifier tests + +def test_mutox_classifier_forward(): + """Test that MutoxClassifier forward pass returns expected output shape.""" + test_model= nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 1), + ) + model = MutoxClassifier(test_model) + + test_input = torch.randn(3, 10) + output = model(test_input) + assert output.shape == (3, 1), f"Expected output shape (3, 1), but instead got {output.shape}" + + +def test_mutox_config(): + """Test that MutoxConfig stores the configuration for a model.""" + config = MutoxConfig(input_size=512) + assert config.input_size == 512, f"Config input_size should be 512, but got {config.input_size}" + + +# Loader tests + +def test_convert_mutox_checkpoint(): + """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" + # Create a mock checkpoint with both 'model_all.' prefixed keys and other keys + checkpoint = { + "model_all.layer1.weight": torch.tensor([1.0]), + "model_all.layer1.bias": torch.tensor([0.5]), + "non_model_key": torch.tensor([3.0]) + } + config = MutoxConfig(input_size=1024) + converted = convert_mutox_checkpoint(checkpoint, config) + + # Verify only 'model_all.' keys are retained in the converted dictionary + assert "model" in converted, "Converted checkpoint should contain a 'model' key" + assert "model_all.layer1.weight" in converted["model"], "Expected 'model_all.layer1.weight'" + assert "model_all.layer1.bias" in converted["model"], "Expected 'model_all.layer1.bias'" + assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'" From 55a342d0aaacb755c7075f4e1d4b4ca93995f65a Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Fri, 8 Nov 2024 18:07:58 -0600 Subject: [PATCH 05/11] Resolved comments PR#44: Added MutoxConfig opt layer, style changes, repo decoupling, other --- README.md | 22 ++++++++++++-- examples/mutox_example.ipynb | 15 +++++++--- sonar/cards/sonar_mutox.yaml | 11 +++++++ sonar/inference_pipelines/mutox_speech.py | 21 ++++++++------ sonar/models/mutox/__init__.py | 2 +- sonar/models/mutox/builder.py | 24 ++++++++-------- sonar/models/mutox/classifier.py | 12 ++++---- sonar/models/mutox/loader.py | 8 ++---- tests/unit_tests/test_mutox.py | 35 ++++++++++++++++------- 9 files changed, 100 insertions(+), 50 deletions(-) create mode 100644 sonar/cards/sonar_mutox.yaml diff --git a/README.md b/README.md index c063155..300a5f6 100644 --- a/README.md +++ b/README.md @@ -144,11 +144,19 @@ Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggi ### Classifying the toxicity of sentences with MuTox -[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. +[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a probability of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset. ```Python from sonar.models.mutox.loader import load_mutox_model from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline +import torch + +if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 +else: + device = torch.device("cpu") + dtype = torch.float32 t2vec_model = TextToEmbeddingModelPipeline( encoder="text_sonar_basic_encoder", @@ -157,14 +165,22 @@ t2vec_model = TextToEmbeddingModelPipeline( ) text_column='lang_txt' classifier = load_mutox_model( - "mutox", + "sonar_mutox", device=device, dtype=dtype, ).eval() with torch.inference_mode(): emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn') - x = classifier(emb.to(device).half()) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='fra_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16) ``` For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index 291f127..f511a68 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -91,7 +91,7 @@ "from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n", "\n", "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", - " mutox_classifier_name =\"mutox\",\n", + " mutox_classifier_name =\"sonar_mutox\",\n", " encoder_name=f\"sonar_speech_encoder_eng\",\n", " device=device,\n", ")" @@ -116,6 +116,13 @@ "))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -162,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -184,7 +191,7 @@ ")\n", "text_column='lang_txt'\n", "classifier = load_mutox_model(\n", - " \"mutox\",\n", + " \"sonar_mutox\",\n", " device=device,\n", " dtype=dtype,\n", ").eval()" diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml new file mode 100644 index 0000000..2f26626 --- /dev/null +++ b/sonar/cards/sonar_mutox.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +name: sonar_mutox +model_type: mutox_classifier +model_arch: mutox +checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" +input_size: 1024 \ No newline at end of file diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index 369e56f..c7ea634 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -5,22 +5,19 @@ # MIT_LICENSE file in the root directory of this source tree. from typing import Union -import torch +import torch +from fairseq2.data import DataPipelineBuilder from fairseq2.typing import Device -from fairseq2.data import ( - DataPipelineBuilder, -) -from sonar.models.sonar_speech.loader import load_sonar_speech_model -from sonar.models.encoder_model import SonarEncoderModel from sonar.inference_pipelines.speech import ( - SpeechToEmbeddingPipeline, SpeechInferenceParams, + SpeechToEmbeddingPipeline, ) - +from sonar.models.encoder_model import SonarEncoderModel from sonar.models.mutox.classifier import MutoxClassifier from sonar.models.mutox.loader import load_mutox_model +from sonar.models.sonar_speech.loader import load_sonar_speech_model CPU_DEVICE = torch.device("cpu") @@ -32,7 +29,13 @@ def __init__( encoder: Union[str, SonarEncoderModel], device: Device = CPU_DEVICE, ) -> None: - super().__init__(encoder) + if isinstance(encoder, str): + model = self.load_model_from_name("sonar_mutox", encoder, device=device) + else: + model = encoder + + super().__init__(model) + self.model.to(device).eval() self.mutox_classifier = mutox_classifier.to(device).eval() diff --git a/sonar/models/mutox/__init__.py b/sonar/models/mutox/__init__.py index b7f1604..15d8859 100644 --- a/sonar/models/mutox/__init__.py +++ b/sonar/models/mutox/__init__.py @@ -2,4 +2,4 @@ # All rights reserved. # # This source code is licensed under the license found in the -# MIT_LICENSE file in the root directory of this source tree. \ No newline at end of file +# MIT_LICENSE file in the root directory of this source tree. diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index 32f6cff..d9d86b0 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -5,14 +5,12 @@ # MIT_LICENSE file in the root directory of this source tree. import typing as tp + import torch -from torch import nn from fairseq2.typing import DataType, Device +from torch import nn -from .classifier import ( - MutoxClassifier, - MutoxConfig, -) +from .classifier import MutoxClassifier, MutoxConfig class MutoxClassifierBuilder: @@ -42,21 +40,21 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self) -> MutoxClassifier: + def build_model(self, activation=nn.ReLU) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), ) model_h2 = nn.Sequential( - nn.ReLU(), + activation, nn.Linear(512, 128), ) - model_h3 = nn.Sequential( - nn.ReLU(), - nn.Linear(128, 1), - ) + if self.config.output_prob: + model_h3 = nn.Sequential(activation(), nn.Linear(128, 1), nn.Sigmoid()) + else: + model_h3 = nn.Sequential(activation(), nn.Linear(128, 1)) model_all = nn.Sequential( model_h1, @@ -64,7 +62,9 @@ def build_model(self) -> MutoxClassifier: model_h3, ) - return MutoxClassifier(model_all,).to( + return MutoxClassifier( + model_all, + ).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 5c67cca..1dcd360 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -5,13 +5,12 @@ # MIT_LICENSE file in the root directory of this source tree. from dataclasses import dataclass -import torch -from torch import nn +from typing import Optional -from fairseq2.typing import DataType, Device +import torch from fairseq2.models.utils.arch_registry import ArchitectureRegistry - -from typing import Optional +from fairseq2.typing import DataType, Device +from torch import nn class MutoxClassifier(nn.Module): @@ -33,5 +32,8 @@ class MutoxConfig: # size of the input embedding supported by this model input_size: int + # add sigmoid as last layer to output probability + output_prob: bool = False + mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier") diff --git a/sonar/models/mutox/loader.py b/sonar/models/mutox/loader.py index 013930b..20d764d 100644 --- a/sonar/models/mutox/loader.py +++ b/sonar/models/mutox/loader.py @@ -10,11 +10,9 @@ from fairseq2.models.utils import ConfigLoader, ModelLoader from .builder import create_mutox_model -from .classifier import ( - MutoxClassifier, - MutoxConfig, - mutox_archs, -) +from .classifier import MutoxClassifier, MutoxConfig, mutox_archs + +__import__("sonar") # Import only to update asset_store @mutox_archs.decorator("mutox") diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index 7c2b53a..0e0db33 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -7,16 +7,18 @@ import pytest import torch from torch import nn -from unittest.mock import Mock -from sonar.models.mutox.builder import MutoxConfig, MutoxClassifierBuilder, create_mutox_model -from sonar.models.mutox.classifier import MutoxClassifier -from sonar.models.mutox.loader import ( - convert_mutox_checkpoint, +from sonar.models.mutox.builder import ( + MutoxClassifierBuilder, + MutoxConfig, + create_mutox_model, ) +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import convert_mutox_checkpoint # Builder tests + @pytest.mark.parametrize("input_size", [256, 512, 1024]) @pytest.mark.parametrize("device", [torch.device("cpu")]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @@ -52,9 +54,10 @@ def test_create_mutox_model(input_size): # Classifier tests + def test_mutox_classifier_forward(): """Test that MutoxClassifier forward pass returns expected output shape.""" - test_model= nn.Sequential( + test_model = nn.Sequential( nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1), @@ -63,30 +66,40 @@ def test_mutox_classifier_forward(): test_input = torch.randn(3, 10) output = model(test_input) - assert output.shape == (3, 1), f"Expected output shape (3, 1), but instead got {output.shape}" + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" def test_mutox_config(): """Test that MutoxConfig stores the configuration for a model.""" config = MutoxConfig(input_size=512) - assert config.input_size == 512, f"Config input_size should be 512, but got {config.input_size}" + assert ( + config.input_size == 512 + ), f"Config input_size should be 512, but got {config.input_size}" # Loader tests + def test_convert_mutox_checkpoint(): """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" # Create a mock checkpoint with both 'model_all.' prefixed keys and other keys checkpoint = { "model_all.layer1.weight": torch.tensor([1.0]), "model_all.layer1.bias": torch.tensor([0.5]), - "non_model_key": torch.tensor([3.0]) + "non_model_key": torch.tensor([3.0]), } config = MutoxConfig(input_size=1024) converted = convert_mutox_checkpoint(checkpoint, config) # Verify only 'model_all.' keys are retained in the converted dictionary assert "model" in converted, "Converted checkpoint should contain a 'model' key" - assert "model_all.layer1.weight" in converted["model"], "Expected 'model_all.layer1.weight'" - assert "model_all.layer1.bias" in converted["model"], "Expected 'model_all.layer1.bias'" + assert ( + "model_all.layer1.weight" in converted["model"] + ), "Expected 'model_all.layer1.weight'" + assert ( + "model_all.layer1.bias" in converted["model"] + ), "Expected 'model_all.layer1.bias'" assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'" From d2d7d60777774ad546b5dc71b46464bcf2994206 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Wed, 20 Nov 2024 03:21:45 -0600 Subject: [PATCH 06/11] Resolved comments 2 PR#44: Missing comments, style changes, others --- README.md | 4 ++-- sonar/cards/sonar_mutox.yaml | 8 ++++++++ sonar/inference_pipelines/mutox_speech.py | 7 +++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 300a5f6..4c48afc 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggi ### Classifying the toxicity of sentences with MuTox -[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a probability of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset. +[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a logit of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset. ```Python from sonar.models.mutox.loader import load_mutox_model @@ -175,7 +175,7 @@ with torch.inference_mode(): x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) with torch.inference_mode(): - emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='fra_Latn') + emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn') x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16) with torch.inference_mode(): diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml index 2f26626..760c40d 100644 --- a/sonar/cards/sonar_mutox.yaml +++ b/sonar/cards/sonar_mutox.yaml @@ -4,6 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +""" +This card is a duplicate of the original found at +[Facebook Research's Seamless Communication repository] +(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). +It is included here to prevent circular dependencies between the Seamless Communication +repository and this project. +""" + name: sonar_mutox model_type: mutox_classifier model_arch: mutox diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index c7ea634..535a9e6 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -39,6 +39,13 @@ def __init__( self.model.to(device).eval() self.mutox_classifier = mutox_classifier.to(device).eval() + if isinstance(mutox_classifier, str): + self.mutox_classifier = load_mutox_model(mutox_classifier, device=device,) + else: + self.mutox_classifier = mutox_classifier + + self.mutox_classifier.to(device).eval() + @classmethod def load_model_from_name( cls, From 1baaf4949720d208882860b44f8f782f1a620ffe Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Wed, 20 Nov 2024 05:21:55 -0600 Subject: [PATCH 07/11] Resolved comments 3 PR#44: opt sigmoid layer change, card edit, other linter/mypy related --- examples/mutox_example.ipynb | 2 +- sonar/cards/sonar_mutox.yaml | 11 ++++------- sonar/inference_pipelines/mutox_speech.py | 19 +++++++++++++------ sonar/models/mutox/builder.py | 8 +------- sonar/models/mutox/classifier.py | 10 ++++++---- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index f511a68..4cbe982 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -19,7 +19,7 @@ "source": [ "# MUTOX toxicity classification\n", "\n", - "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that." + "Mutox enables toxicity scoring for speech and text using sonar embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid layer. This notebook demonstrates encoding speech and text into sonar embeddings and classifying their toxicity." ] }, { diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml index 760c40d..f1e60ee 100644 --- a/sonar/cards/sonar_mutox.yaml +++ b/sonar/cards/sonar_mutox.yaml @@ -4,13 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -This card is a duplicate of the original found at -[Facebook Research's Seamless Communication repository] -(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). -It is included here to prevent circular dependencies between the Seamless Communication -repository and this project. -""" +#This card is a duplicate of the original found at +#[Facebook Research's Seamless Communication repository] +#(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). +#It is included here to prevent circular dependencies between the Seamless Communication name: sonar_mutox model_type: mutox_classifier diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index 535a9e6..fa34f0c 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -30,17 +30,21 @@ def __init__( device: Device = CPU_DEVICE, ) -> None: if isinstance(encoder, str): - model = self.load_model_from_name("sonar_mutox", encoder, device=device) + self.model = self.load_model_from_name( + "sonar_mutox", encoder, device=device + ) else: - model = encoder + self.model = encoder - super().__init__(model) + super().__init__(self.model) self.model.to(device).eval() - self.mutox_classifier = mutox_classifier.to(device).eval() if isinstance(mutox_classifier, str): - self.mutox_classifier = load_mutox_model(mutox_classifier, device=device,) + self.mutox_classifier = load_mutox_model( + mutox_classifier, + device=device, + ) else: self.mutox_classifier = mutox_classifier @@ -65,4 +69,7 @@ def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuild @torch.inference_mode() def _run_classifier(self, data: dict): - return self.mutox_classifier(data.sentence_embeddings) + sentence_embeddings = data.get("sentence_embeddings") + if sentence_embeddings is None: + raise ValueError("Missing sentence embeddings in the data.") + return self.mutox_classifier(sentence_embeddings) diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index d9d86b0..f021c85 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -40,7 +40,7 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self, activation=nn.ReLU) -> MutoxClassifier: + def build_model(self, activation=nn.ReLU()) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), @@ -51,15 +51,9 @@ def build_model(self, activation=nn.ReLU) -> MutoxClassifier: nn.Linear(512, 128), ) - if self.config.output_prob: - model_h3 = nn.Sequential(activation(), nn.Linear(128, 1), nn.Sigmoid()) - else: - model_h3 = nn.Sequential(activation(), nn.Linear(128, 1)) - model_all = nn.Sequential( model_h1, model_h2, - model_h3, ) return MutoxClassifier( diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 1dcd360..ada2efe 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -21,7 +21,12 @@ def __init__( super().__init__() self.model_all = model_all - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor: + if output_prob: + self.model_all.add_module("sigmoid", nn.Sigmoid()) + else: + self.model_all.add_module("linear", nn.Linear(128, 1)) + return self.model_all(inputs) @@ -32,8 +37,5 @@ class MutoxConfig: # size of the input embedding supported by this model input_size: int - # add sigmoid as last layer to output probability - output_prob: bool = False - mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier") From 96f5d6fa550882c575406eee6fcfd2457b82e687 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Thu, 21 Nov 2024 15:33:13 -0600 Subject: [PATCH 08/11] Resolved comments 4 PR#44 --- sonar/models/mutox/builder.py | 14 +++++++++----- sonar/models/mutox/classifier.py | 9 +++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index f021c85..7e6577f 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -40,25 +40,29 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self, activation=nn.ReLU()) -> MutoxClassifier: + def build_model(self) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), ) model_h2 = nn.Sequential( - activation, + nn.ReLU(), nn.Linear(512, 128), ) + model_h3 = nn.Sequential( + nn.ReLU(), + nn.Linear(128, 1), + ) + model_all = nn.Sequential( model_h1, model_h2, + model_h3, ) - return MutoxClassifier( - model_all, - ).to( + return MutoxClassifier(model_all,).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index ada2efe..9ae7ebe 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -22,12 +22,13 @@ def __init__( self.model_all = model_all def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor: + outputs = self.model_all(inputs) + if output_prob: - self.model_all.add_module("sigmoid", nn.Sigmoid()) - else: - self.model_all.add_module("linear", nn.Linear(128, 1)) + outputs = torch.sigmoid(outputs) + + return outputs - return self.model_all(inputs) @dataclass From 7c3b3f9a1c9615c43e67b058ee4d122de60649dc Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Thu, 21 Nov 2024 17:05:30 -0600 Subject: [PATCH 09/11] Resolved comments 5 PR#44: new integration and unit tests for mutox, other related changes --- README.md | 4 +- sonar/models/mutox/builder.py | 4 +- sonar/models/mutox/classifier.py | 1 - tests/integration_tests/test_mutox.py | 123 ++++++++++++++++++++++++++ tests/unit_tests/test_mutox.py | 24 ++++- 5 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 tests/integration_tests/test_mutox.py diff --git a/README.md b/README.md index 4c48afc..a085ab0 100644 --- a/README.md +++ b/README.md @@ -176,11 +176,11 @@ with torch.inference_mode(): with torch.inference_mode(): emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn') - x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-53.5938]], device='cuda:0', dtype=torch.float16) with torch.inference_mode(): emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn') - x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-21.4062]], device='cuda:0', dtype=torch.float16) ``` For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index 7e6577f..b308cb9 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -62,7 +62,9 @@ def build_model(self) -> MutoxClassifier: model_h3, ) - return MutoxClassifier(model_all,).to( + return MutoxClassifier( + model_all, + ).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 9ae7ebe..524a386 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -30,7 +30,6 @@ def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tens return outputs - @dataclass class MutoxConfig: """Holds the configuration of a Mutox Classifier model.""" diff --git a/tests/integration_tests/test_mutox.py b/tests/integration_tests/test_mutox.py new file mode 100644 index 0000000..3466301 --- /dev/null +++ b/tests/integration_tests/test_mutox.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline +from sonar.models.mutox.loader import load_mutox_model + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_outputs", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [-19.7812], + ), + ( + ["She worked hard and made a significant contribution to the team."], + "eng_Latn", + [-53.5938], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [-21.4062], + ), + ], +) +def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_outputs): + """Integration test to compare classifier outputs with expected values.""" + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + with torch.inference_mode(): + embeddings = t2vec_model.predict(input_texts, source_lang=source_lang) + outputs = classifier(embeddings.to(device).to(dtype)).squeeze() + + if outputs.dim() == 0: + outputs = [outputs.item()] + else: + outputs = outputs.tolist() + + # Compare the outputs to expected values within a small tolerance + for output, expected in zip(outputs, expected_outputs): + assert abs(output - expected) < 0.1, ( + f"Expected output {expected}, but got {output}. " + "Outputs should be close to expected values." + ) + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_probabilities", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [0.0], + ), + ( + ["She worked hard and made a significant contribution to the team."], + "eng_Latn", + [0.0], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [0.0], + ), + ], +) +def test_sonar_mutox_classifier_probability_integration( + input_texts, source_lang, expected_probabilities +): + """Integration test to verify classifier output probabilities.""" + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + for text, lang, expected_prob in zip( + input_texts, [source_lang] * len(input_texts), expected_probabilities + ): + with torch.inference_mode(): + emb = t2vec_model.predict([text], source_lang=lang) + + prob = classifier(emb.to(device).to(dtype), output_prob=True) + + assert abs(prob.item() - expected_prob) < 0.001, ( + f"Expected probability {expected_prob}, but got {prob.item()}. " + "Output probability should be within a reasonable range." + ) diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index 0e0db33..ebb73eb 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -72,6 +72,29 @@ def test_mutox_classifier_forward(): ), f"Expected output shape (3, 1), but instead got {output.shape}" +def test_mutox_classifier_forward_with_output_prob(): + """Test that MutoxClassifier forward pass applies sigmoid when output_prob=True.""" + test_model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 1), + ) + model = MutoxClassifier(test_model) + + test_input = torch.randn(3, 10) + + output = model(test_input, output_prob=True) + + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" + + assert (output >= 0).all() and ( + output <= 1 + ).all(), "Expected output values to be within the range [0, 1]" + + def test_mutox_config(): """Test that MutoxConfig stores the configuration for a model.""" config = MutoxConfig(input_size=512) @@ -85,7 +108,6 @@ def test_mutox_config(): def test_convert_mutox_checkpoint(): """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" - # Create a mock checkpoint with both 'model_all.' prefixed keys and other keys checkpoint = { "model_all.layer1.weight": torch.tensor([1.0]), "model_all.layer1.bias": torch.tensor([0.5]), From 8de0914a8431207f8aa4bd3d481580fe6c065aef Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Fri, 22 Nov 2024 21:01:25 -0600 Subject: [PATCH 10/11] Resolved comments 6 PR#44: modifying integration test to increase coverage, other changes to satisfy mypy --- sonar/inference_pipelines/mutox_speech.py | 26 ++++++++++++++++--- tests/integration_tests/test_mutox.py | 8 +++--- .../unit_tests/huggingface_pipelines/text.py | 2 +- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index fa34f0c..d2f9e4f 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -11,9 +11,12 @@ from fairseq2.typing import Device from sonar.inference_pipelines.speech import ( + AudioToFbankDataPipelineBuilder, SpeechInferenceParams, + SpeechInferencePipeline, SpeechToEmbeddingPipeline, ) +from sonar.inference_pipelines.utils import extract_sequence_batch from sonar.models.encoder_model import SonarEncoderModel from sonar.models.mutox.classifier import MutoxClassifier from sonar.models.mutox.loader import load_mutox_model @@ -22,7 +25,9 @@ CPU_DEVICE = torch.device("cpu") -class MutoxSpeechClassifierPipeline(SpeechToEmbeddingPipeline): +class MutoxSpeechClassifierPipeline(SpeechInferencePipeline): + model: SonarEncoderModel + def __init__( self, mutox_classifier: Union[str, MutoxClassifier], @@ -36,7 +41,7 @@ def __init__( else: self.model = encoder - super().__init__(self.model) + super().__init__() self.model.to(device).eval() @@ -56,7 +61,7 @@ def load_model_from_name( mutox_classifier_name: str, encoder_name: str, device: Device = CPU_DEVICE, - ) -> "SpeechToEmbeddingPipeline": + ) -> "MutoxSpeechClassifierPipeline": encoder = load_sonar_speech_model(encoder_name, device=device, progress=False) mutox_classifier = load_mutox_model( mutox_classifier_name, device=device, progress=False @@ -64,9 +69,22 @@ def load_model_from_name( return cls(mutox_classifier=mutox_classifier, encoder=encoder, device=device) def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder: - pipeline_builder = super().prebuild_pipeline(context) + audio_to_fbank_dp_builder = AudioToFbankDataPipelineBuilder() + pipeline_builder = ( + audio_to_fbank_dp_builder.prebuild_pipeline(context) + .map( + lambda fbank: extract_sequence_batch(fbank, context.device), + selector="audio.data.fbank", + ) + .map(self.run_inference, selector="audio.data") + ) return pipeline_builder.map(self._run_classifier, selector="audio.data") + @torch.inference_mode() + def run_inference(self, fbank: torch.Tensor) -> dict: + """Runs the encoder model on the extracted FBANK features.""" + return {"sentence_embeddings": self.model(fbank)} + @torch.inference_mode() def _run_classifier(self, data: dict): sentence_embeddings = data.get("sentence_embeddings") diff --git a/tests/integration_tests/test_mutox.py b/tests/integration_tests/test_mutox.py index 3466301..a7e5b45 100644 --- a/tests/integration_tests/test_mutox.py +++ b/tests/integration_tests/test_mutox.py @@ -76,9 +76,11 @@ def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_o [0.0], ), ( - ["She worked hard and made a significant contribution to the team."], + [ + "Dammit, that was a terrible launch, it will piss the director and make the mission fail." + ], "eng_Latn", - [0.0], + [0.23], ), ( [ @@ -117,7 +119,7 @@ def test_sonar_mutox_classifier_probability_integration( prob = classifier(emb.to(device).to(dtype), output_prob=True) - assert abs(prob.item() - expected_prob) < 0.001, ( + assert abs(prob.item() - expected_prob) < 0.01, ( f"Expected probability {expected_prob}, but got {prob.item()}. " "Output probability should be within a reasonable range." ) diff --git a/tests/unit_tests/huggingface_pipelines/text.py b/tests/unit_tests/huggingface_pipelines/text.py index a354c10..689349d 100644 --- a/tests/unit_tests/huggingface_pipelines/text.py +++ b/tests/unit_tests/huggingface_pipelines/text.py @@ -52,7 +52,7 @@ def test_embedding_to_text_process_batch(embedding_to_text_config): embedding_dim = 1024 num_embeddings = 4 - embeddings = [ + embeddings: List[np.ndarray] = [ np.random.rand(embedding_dim).astype(np.float32) for _ in range(num_embeddings) ] From 9648d9dc38743fb6b0499244ee5ffb4af5b4be41 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Wed, 4 Dec 2024 09:47:46 -0600 Subject: [PATCH 11/11] Added line #type: ignore to pass mypy check --- sonar/inference_pipelines/mutox_speech.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index d2f9e4f..8e7ff71 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -37,7 +37,7 @@ def __init__( if isinstance(encoder, str): self.model = self.load_model_from_name( "sonar_mutox", encoder, device=device - ) + ) # type: ignore else: self.model = encoder