From ed55f216c0d5759c30bd54df531a101403705217 Mon Sep 17 00:00:00 2001 From: David-OC17 Date: Fri, 8 Nov 2024 18:07:58 -0600 Subject: [PATCH] Resolved comments PR#44: Added MutoxConfig opt layer, style changes, repo decoupling, other --- README.md | 22 ++++++++++++-- examples/mutox_example.ipynb | 15 +++++++--- sonar/cards/sonar_mutox.yaml | 11 +++++++ sonar/inference_pipelines/mutox_speech.py | 21 ++++++++------ sonar/models/mutox/__init__.py | 2 +- sonar/models/mutox/builder.py | 31 ++++++++++++-------- sonar/models/mutox/classifier.py | 12 ++++---- sonar/models/mutox/loader.py | 8 ++---- tests/unit_tests/test_mutox.py | 35 ++++++++++++++++------- 9 files changed, 107 insertions(+), 50 deletions(-) create mode 100644 sonar/cards/sonar_mutox.yaml diff --git a/README.md b/README.md index c063155..300a5f6 100644 --- a/README.md +++ b/README.md @@ -144,11 +144,19 @@ Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggi ### Classifying the toxicity of sentences with MuTox -[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. +[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a probability of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset. ```Python from sonar.models.mutox.loader import load_mutox_model from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline +import torch + +if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 +else: + device = torch.device("cpu") + dtype = torch.float32 t2vec_model = TextToEmbeddingModelPipeline( encoder="text_sonar_basic_encoder", @@ -157,14 +165,22 @@ t2vec_model = TextToEmbeddingModelPipeline( ) text_column='lang_txt' classifier = load_mutox_model( - "mutox", + "sonar_mutox", device=device, dtype=dtype, ).eval() with torch.inference_mode(): emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn') - x = classifier(emb.to(device).half()) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) + x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='fra_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16) ``` For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index 291f127..f511a68 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -83,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -91,7 +91,7 @@ "from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n", "\n", "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", - " mutox_classifier_name =\"mutox\",\n", + " mutox_classifier_name =\"sonar_mutox\",\n", " encoder_name=f\"sonar_speech_encoder_eng\",\n", " device=device,\n", ")" @@ -116,6 +116,13 @@ "))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -162,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -184,7 +191,7 @@ ")\n", "text_column='lang_txt'\n", "classifier = load_mutox_model(\n", - " \"mutox\",\n", + " \"sonar_mutox\",\n", " device=device,\n", " dtype=dtype,\n", ").eval()" diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml new file mode 100644 index 0000000..2f26626 --- /dev/null +++ b/sonar/cards/sonar_mutox.yaml @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +name: sonar_mutox +model_type: mutox_classifier +model_arch: mutox +checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" +input_size: 1024 \ No newline at end of file diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index 369e56f..c7ea634 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -5,22 +5,19 @@ # MIT_LICENSE file in the root directory of this source tree. from typing import Union -import torch +import torch +from fairseq2.data import DataPipelineBuilder from fairseq2.typing import Device -from fairseq2.data import ( - DataPipelineBuilder, -) -from sonar.models.sonar_speech.loader import load_sonar_speech_model -from sonar.models.encoder_model import SonarEncoderModel from sonar.inference_pipelines.speech import ( - SpeechToEmbeddingPipeline, SpeechInferenceParams, + SpeechToEmbeddingPipeline, ) - +from sonar.models.encoder_model import SonarEncoderModel from sonar.models.mutox.classifier import MutoxClassifier from sonar.models.mutox.loader import load_mutox_model +from sonar.models.sonar_speech.loader import load_sonar_speech_model CPU_DEVICE = torch.device("cpu") @@ -32,7 +29,13 @@ def __init__( encoder: Union[str, SonarEncoderModel], device: Device = CPU_DEVICE, ) -> None: - super().__init__(encoder) + if isinstance(encoder, str): + model = self.load_model_from_name("sonar_mutox", encoder, device=device) + else: + model = encoder + + super().__init__(model) + self.model.to(device).eval() self.mutox_classifier = mutox_classifier.to(device).eval() diff --git a/sonar/models/mutox/__init__.py b/sonar/models/mutox/__init__.py index b7f1604..15d8859 100644 --- a/sonar/models/mutox/__init__.py +++ b/sonar/models/mutox/__init__.py @@ -2,4 +2,4 @@ # All rights reserved. # # This source code is licensed under the license found in the -# MIT_LICENSE file in the root directory of this source tree. \ No newline at end of file +# MIT_LICENSE file in the root directory of this source tree. diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index 32f6cff..ac9a147 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -5,14 +5,12 @@ # MIT_LICENSE file in the root directory of this source tree. import typing as tp + import torch -from torch import nn from fairseq2.typing import DataType, Device +from torch import nn -from .classifier import ( - MutoxClassifier, - MutoxConfig, -) +from .classifier import MutoxClassifier, MutoxConfig class MutoxClassifierBuilder: @@ -42,21 +40,28 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self) -> MutoxClassifier: + def build_model(self, activation=nn.ReLU) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), ) model_h2 = nn.Sequential( - nn.ReLU(), + activation, nn.Linear(512, 128), ) - model_h3 = nn.Sequential( - nn.ReLU(), - nn.Linear(128, 1), - ) + if self.config.output_prob: + model_h3 = nn.Sequential( + activation(), + nn.Linear(128, 1), + nn.Sigmoid() + ) + else: + model_h3 = nn.Sequential( + activation(), + nn.Linear(128, 1) + ) model_all = nn.Sequential( model_h1, @@ -64,7 +69,9 @@ def build_model(self) -> MutoxClassifier: model_h3, ) - return MutoxClassifier(model_all,).to( + return MutoxClassifier( + model_all, + ).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 5c67cca..1dcd360 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -5,13 +5,12 @@ # MIT_LICENSE file in the root directory of this source tree. from dataclasses import dataclass -import torch -from torch import nn +from typing import Optional -from fairseq2.typing import DataType, Device +import torch from fairseq2.models.utils.arch_registry import ArchitectureRegistry - -from typing import Optional +from fairseq2.typing import DataType, Device +from torch import nn class MutoxClassifier(nn.Module): @@ -33,5 +32,8 @@ class MutoxConfig: # size of the input embedding supported by this model input_size: int + # add sigmoid as last layer to output probability + output_prob: bool = False + mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier") diff --git a/sonar/models/mutox/loader.py b/sonar/models/mutox/loader.py index 013930b..20d764d 100644 --- a/sonar/models/mutox/loader.py +++ b/sonar/models/mutox/loader.py @@ -10,11 +10,9 @@ from fairseq2.models.utils import ConfigLoader, ModelLoader from .builder import create_mutox_model -from .classifier import ( - MutoxClassifier, - MutoxConfig, - mutox_archs, -) +from .classifier import MutoxClassifier, MutoxConfig, mutox_archs + +__import__("sonar") # Import only to update asset_store @mutox_archs.decorator("mutox") diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py index 7c2b53a..0e0db33 100644 --- a/tests/unit_tests/test_mutox.py +++ b/tests/unit_tests/test_mutox.py @@ -7,16 +7,18 @@ import pytest import torch from torch import nn -from unittest.mock import Mock -from sonar.models.mutox.builder import MutoxConfig, MutoxClassifierBuilder, create_mutox_model -from sonar.models.mutox.classifier import MutoxClassifier -from sonar.models.mutox.loader import ( - convert_mutox_checkpoint, +from sonar.models.mutox.builder import ( + MutoxClassifierBuilder, + MutoxConfig, + create_mutox_model, ) +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import convert_mutox_checkpoint # Builder tests + @pytest.mark.parametrize("input_size", [256, 512, 1024]) @pytest.mark.parametrize("device", [torch.device("cpu")]) @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @@ -52,9 +54,10 @@ def test_create_mutox_model(input_size): # Classifier tests + def test_mutox_classifier_forward(): """Test that MutoxClassifier forward pass returns expected output shape.""" - test_model= nn.Sequential( + test_model = nn.Sequential( nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1), @@ -63,30 +66,40 @@ def test_mutox_classifier_forward(): test_input = torch.randn(3, 10) output = model(test_input) - assert output.shape == (3, 1), f"Expected output shape (3, 1), but instead got {output.shape}" + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" def test_mutox_config(): """Test that MutoxConfig stores the configuration for a model.""" config = MutoxConfig(input_size=512) - assert config.input_size == 512, f"Config input_size should be 512, but got {config.input_size}" + assert ( + config.input_size == 512 + ), f"Config input_size should be 512, but got {config.input_size}" # Loader tests + def test_convert_mutox_checkpoint(): """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" # Create a mock checkpoint with both 'model_all.' prefixed keys and other keys checkpoint = { "model_all.layer1.weight": torch.tensor([1.0]), "model_all.layer1.bias": torch.tensor([0.5]), - "non_model_key": torch.tensor([3.0]) + "non_model_key": torch.tensor([3.0]), } config = MutoxConfig(input_size=1024) converted = convert_mutox_checkpoint(checkpoint, config) # Verify only 'model_all.' keys are retained in the converted dictionary assert "model" in converted, "Converted checkpoint should contain a 'model' key" - assert "model_all.layer1.weight" in converted["model"], "Expected 'model_all.layer1.weight'" - assert "model_all.layer1.bias" in converted["model"], "Expected 'model_all.layer1.bias'" + assert ( + "model_all.layer1.weight" in converted["model"] + ), "Expected 'model_all.layer1.weight'" + assert ( + "model_all.layer1.bias" in converted["model"] + ), "Expected 'model_all.layer1.bias'" assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'"