diff --git a/README.md b/README.md index 17e0fbe..a085ab0 100644 --- a/README.md +++ b/README.md @@ -142,6 +142,49 @@ print(blaser_qe(src=src_embs, mt=mt_embs).item()) # 4.708 Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref), [facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe). +### Classifying the toxicity of sentences with MuTox + +[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox), the first highly multilingual audio-based classifier (binary) and dataset with toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages, and uses the multi-model and multilingual encoders from SONAR. The output of the MuTox classifier is a logit of the evaluated being _"toxic"_, according to the definition adopted in the corresponding dataset. + +```Python +from sonar.models.mutox.loader import load_mutox_model +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline +import torch + +if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 +else: + device = torch.device("cpu") + dtype = torch.float32 + +t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, +) +text_column='lang_txt' +classifier = load_mutox_model( + "sonar_mutox", + device=device, + dtype=dtype, +).eval() + +with torch.inference_mode(): + emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-53.5938]], device='cuda:0', dtype=torch.float16) + +with torch.inference_mode(): + emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn') + x = classifier(emb.to(device).to(dtype)) # tensor([[-21.4062]], device='cuda:0', dtype=torch.float16) +``` + +For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox). + ### Demo notebooks See more complete demo notebooks : diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb new file mode 100644 index 0000000..4cbe982 --- /dev/null +++ b/examples/mutox_example.ipynb @@ -0,0 +1,246 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright (c) Meta Platforms, Inc. and affiliates\n", + "# All rights reserved.\n", + "#\n", + "# This source code is licensed under the license found in the\n", + "# MIT_LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# MUTOX toxicity classification\n", + "\n", + "Mutox enables toxicity scoring for speech and text using sonar embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid layer. This notebook demonstrates encoding speech and text into sonar embeddings and classifying their toxicity." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from pathlib import Path\n", + "\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda:0\")\n", + " dtype = torch.float16\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + " dtype = torch.float32" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Speech Scoring" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. download some demo audio segments\n", + "2. create a tsv file to feed to the speech scoring pipeline\n", + "3. load the model and build the pipeline\n", + "4. go through the batches in the pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# get demo file\n", + "import urllib.request\n", + "import tempfile\n", + "\n", + "files = [\n", + " (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n", + " (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n", + "]\n", + "\n", + "tmpdir = Path(tempfile.mkdtemp())\n", + "tsv_file = (tmpdir / 'data.tsv')\n", + "with tsv_file.open('w') as tsv_file_p:\n", + " print('path', file=tsv_file_p)\n", + " for (uri, name) in files:\n", + " dl = tmpdir / name\n", + " urllib.request.urlretrieve(uri, dl)\n", + " print(dl, file=tsv_file_p)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sonar.inference_pipelines.speech import SpeechInferenceParams\n", + "from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n", + "\n", + "pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", + " mutox_classifier_name =\"sonar_mutox\",\n", + " encoder_name=f\"sonar_speech_encoder_eng\",\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n", + " data_file=tsv_file,\n", + " audio_root_dir=None,\n", + " audio_path_index=0,\n", + " target_lang=\"eng\",\n", + " batch_size=4,\n", + " pad_idx=0,\n", + " device=device,\n", + " fbank_dtype=torch.float32,\n", + " n_parallel=4\n", + "))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n", + "/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n" + ] + } + ], + "source": [ + "for batch in pipeline:\n", + " ex = batch['audio']\n", + " for idx, path in enumerate(ex['path']):\n", + " print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# cleanup tmp dir\n", + "import shutil\n", + "shutil.rmtree(tmpdir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Text Scoring\n", + "\n", + "1. load the sonar text encoder\n", + "2. load the mutox classifier model\n", + "3. compute embedding for a sentence\n", + "4. score this embedding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n" + ] + } + ], + "source": [ + "from sonar.models.mutox.loader import load_mutox_model\n", + "from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", + "\n", + "t2vec_model = TextToEmbeddingModelPipeline(\n", + " encoder=\"text_sonar_basic_encoder\",\n", + " tokenizer=\"text_sonar_basic_encoder\",\n", + " device=device,\n", + ")\n", + "text_column='lang_txt'\n", + "classifier = load_mutox_model(\n", + " \"sonar_mutox\",\n", + " device=device,\n", + " dtype=dtype,\n", + ").eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with torch.inference_mode():\n", + " emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n", + " x = classifier(emb.to(device).half())\n", + "\n", + "x" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "SONAR", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml new file mode 100644 index 0000000..f1e60ee --- /dev/null +++ b/sonar/cards/sonar_mutox.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +#This card is a duplicate of the original found at +#[Facebook Research's Seamless Communication repository] +#(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). +#It is included here to prevent circular dependencies between the Seamless Communication + +name: sonar_mutox +model_type: mutox_classifier +model_arch: mutox +checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" +input_size: 1024 \ No newline at end of file diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py new file mode 100644 index 0000000..8e7ff71 --- /dev/null +++ b/sonar/inference_pipelines/mutox_speech.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +from typing import Union + +import torch +from fairseq2.data import DataPipelineBuilder +from fairseq2.typing import Device + +from sonar.inference_pipelines.speech import ( + AudioToFbankDataPipelineBuilder, + SpeechInferenceParams, + SpeechInferencePipeline, + SpeechToEmbeddingPipeline, +) +from sonar.inference_pipelines.utils import extract_sequence_batch +from sonar.models.encoder_model import SonarEncoderModel +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import load_mutox_model +from sonar.models.sonar_speech.loader import load_sonar_speech_model + +CPU_DEVICE = torch.device("cpu") + + +class MutoxSpeechClassifierPipeline(SpeechInferencePipeline): + model: SonarEncoderModel + + def __init__( + self, + mutox_classifier: Union[str, MutoxClassifier], + encoder: Union[str, SonarEncoderModel], + device: Device = CPU_DEVICE, + ) -> None: + if isinstance(encoder, str): + self.model = self.load_model_from_name( + "sonar_mutox", encoder, device=device + ) # type: ignore + else: + self.model = encoder + + super().__init__() + + self.model.to(device).eval() + + if isinstance(mutox_classifier, str): + self.mutox_classifier = load_mutox_model( + mutox_classifier, + device=device, + ) + else: + self.mutox_classifier = mutox_classifier + + self.mutox_classifier.to(device).eval() + + @classmethod + def load_model_from_name( + cls, + mutox_classifier_name: str, + encoder_name: str, + device: Device = CPU_DEVICE, + ) -> "MutoxSpeechClassifierPipeline": + encoder = load_sonar_speech_model(encoder_name, device=device, progress=False) + mutox_classifier = load_mutox_model( + mutox_classifier_name, device=device, progress=False + ) + return cls(mutox_classifier=mutox_classifier, encoder=encoder, device=device) + + def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuilder: + audio_to_fbank_dp_builder = AudioToFbankDataPipelineBuilder() + pipeline_builder = ( + audio_to_fbank_dp_builder.prebuild_pipeline(context) + .map( + lambda fbank: extract_sequence_batch(fbank, context.device), + selector="audio.data.fbank", + ) + .map(self.run_inference, selector="audio.data") + ) + return pipeline_builder.map(self._run_classifier, selector="audio.data") + + @torch.inference_mode() + def run_inference(self, fbank: torch.Tensor) -> dict: + """Runs the encoder model on the extracted FBANK features.""" + return {"sentence_embeddings": self.model(fbank)} + + @torch.inference_mode() + def _run_classifier(self, data: dict): + sentence_embeddings = data.get("sentence_embeddings") + if sentence_embeddings is None: + raise ValueError("Missing sentence embeddings in the data.") + return self.mutox_classifier(sentence_embeddings) diff --git a/sonar/models/mutox/__init__.py b/sonar/models/mutox/__init__.py new file mode 100644 index 0000000..15d8859 --- /dev/null +++ b/sonar/models/mutox/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py new file mode 100644 index 0000000..b308cb9 --- /dev/null +++ b/sonar/models/mutox/builder.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import typing as tp + +import torch +from fairseq2.typing import DataType, Device +from torch import nn + +from .classifier import MutoxClassifier, MutoxConfig + + +class MutoxClassifierBuilder: + """ + Builder module for MutoxClassifier model + """ + + config: MutoxConfig + device: tp.Optional[Device] + dtype: tp.Optional[DataType] + + def __init__( + self, + config: MutoxConfig, + *, + device: tp.Optional[Device] = None, + dtype: tp.Optional[DataType] = None, + ) -> None: + """ + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + self.config = config + self.device, self.dtype = device, dtype + + def build_model(self) -> MutoxClassifier: + model_h1 = nn.Sequential( + nn.Dropout(0.01), + nn.Linear(self.config.input_size, 512), + ) + + model_h2 = nn.Sequential( + nn.ReLU(), + nn.Linear(512, 128), + ) + + model_h3 = nn.Sequential( + nn.ReLU(), + nn.Linear(128, 1), + ) + + model_all = nn.Sequential( + model_h1, + model_h2, + model_h3, + ) + + return MutoxClassifier( + model_all, + ).to( + device=self.device, + dtype=self.dtype, + ) + + +def create_mutox_model( + config: MutoxConfig, + device: tp.Optional[Device] = None, + dtype: tp.Optional[DataType] = None, +) -> MutoxClassifier: + """Create a Mutox Classifier model. + + :param config: + The configuration to use. + :param device: + The device on which to initialize modules. + :param dtype: + The data type of module parameters and buffers. + """ + + return MutoxClassifierBuilder( + config, + device=device, + dtype=dtype, + ).build_model() diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py new file mode 100644 index 0000000..524a386 --- /dev/null +++ b/sonar/models/mutox/classifier.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional + +import torch +from fairseq2.models.utils.arch_registry import ArchitectureRegistry +from fairseq2.typing import DataType, Device +from torch import nn + + +class MutoxClassifier(nn.Module): + def __init__( + self, + model_all, + ): + super().__init__() + self.model_all = model_all + + def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor: + outputs = self.model_all(inputs) + + if output_prob: + outputs = torch.sigmoid(outputs) + + return outputs + + +@dataclass +class MutoxConfig: + """Holds the configuration of a Mutox Classifier model.""" + + # size of the input embedding supported by this model + input_size: int + + +mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier") diff --git a/sonar/models/mutox/loader.py b/sonar/models/mutox/loader.py new file mode 100644 index 0000000..20d764d --- /dev/null +++ b/sonar/models/mutox/loader.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# MIT_LICENSE file in the root directory of this source tree. + +import typing as tp + +from fairseq2.assets import asset_store, download_manager +from fairseq2.models.utils import ConfigLoader, ModelLoader + +from .builder import create_mutox_model +from .classifier import MutoxClassifier, MutoxConfig, mutox_archs + +__import__("sonar") # Import only to update asset_store + + +@mutox_archs.decorator("mutox") +def _base_mutox() -> MutoxConfig: + return MutoxConfig( + input_size=1024, + ) + + +def convert_mutox_checkpoint( + checkpoint: tp.Mapping[str, tp.Any], config: MutoxConfig +) -> tp.Mapping[str, tp.Any]: + new_dict = {} + for key in checkpoint: + if key.startswith("model_all."): + new_dict[key] = checkpoint[key] + return {"model": new_dict} + + +load_mutox_config = ConfigLoader[MutoxConfig](asset_store, mutox_archs) + + +load_mutox_model = ModelLoader[MutoxClassifier, MutoxConfig]( + asset_store, + download_manager, + load_mutox_config, + create_mutox_model, + convert_mutox_checkpoint, +) diff --git a/tests/integration_tests/test_mutox.py b/tests/integration_tests/test_mutox.py new file mode 100644 index 0000000..a7e5b45 --- /dev/null +++ b/tests/integration_tests/test_mutox.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch + +from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline +from sonar.models.mutox.loader import load_mutox_model + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_outputs", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [-19.7812], + ), + ( + ["She worked hard and made a significant contribution to the team."], + "eng_Latn", + [-53.5938], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [-21.4062], + ), + ], +) +def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_outputs): + """Integration test to compare classifier outputs with expected values.""" + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + with torch.inference_mode(): + embeddings = t2vec_model.predict(input_texts, source_lang=source_lang) + outputs = classifier(embeddings.to(device).to(dtype)).squeeze() + + if outputs.dim() == 0: + outputs = [outputs.item()] + else: + outputs = outputs.tolist() + + # Compare the outputs to expected values within a small tolerance + for output, expected in zip(outputs, expected_outputs): + assert abs(output - expected) < 0.1, ( + f"Expected output {expected}, but got {output}. " + "Outputs should be close to expected values." + ) + + +@pytest.mark.parametrize( + "input_texts, source_lang, expected_probabilities", + [ + ( + ["De peur que le pays ne se prostitue et ne se remplisse de crimes."], + "fra_Latn", + [0.0], + ), + ( + [ + "Dammit, that was a terrible launch, it will piss the director and make the mission fail." + ], + "eng_Latn", + [0.23], + ), + ( + [ + "El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones." + ], + "spa_Latn", + [0.0], + ), + ], +) +def test_sonar_mutox_classifier_probability_integration( + input_texts, source_lang, expected_probabilities +): + """Integration test to verify classifier output probabilities.""" + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + dtype = torch.float16 + else: + device = torch.device("cpu") + dtype = torch.float32 + + t2vec_model = TextToEmbeddingModelPipeline( + encoder="text_sonar_basic_encoder", + tokenizer="text_sonar_basic_encoder", + device=device, + ) + + classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval() + + for text, lang, expected_prob in zip( + input_texts, [source_lang] * len(input_texts), expected_probabilities + ): + with torch.inference_mode(): + emb = t2vec_model.predict([text], source_lang=lang) + + prob = classifier(emb.to(device).to(dtype), output_prob=True) + + assert abs(prob.item() - expected_prob) < 0.01, ( + f"Expected probability {expected_prob}, but got {prob.item()}. " + "Output probability should be within a reasonable range." + ) diff --git a/tests/unit_tests/huggingface_pipelines/text.py b/tests/unit_tests/huggingface_pipelines/text.py index a354c10..689349d 100644 --- a/tests/unit_tests/huggingface_pipelines/text.py +++ b/tests/unit_tests/huggingface_pipelines/text.py @@ -52,7 +52,7 @@ def test_embedding_to_text_process_batch(embedding_to_text_config): embedding_dim = 1024 num_embeddings = 4 - embeddings = [ + embeddings: List[np.ndarray] = [ np.random.rand(embedding_dim).astype(np.float32) for _ in range(num_embeddings) ] diff --git a/tests/unit_tests/test_mutox.py b/tests/unit_tests/test_mutox.py new file mode 100644 index 0000000..ebb73eb --- /dev/null +++ b/tests/unit_tests/test_mutox.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from torch import nn + +from sonar.models.mutox.builder import ( + MutoxClassifierBuilder, + MutoxConfig, + create_mutox_model, +) +from sonar.models.mutox.classifier import MutoxClassifier +from sonar.models.mutox.loader import convert_mutox_checkpoint + +# Builder tests + + +@pytest.mark.parametrize("input_size", [256, 512, 1024]) +@pytest.mark.parametrize("device", [torch.device("cpu")]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) +def test_mutox_classifier_builder(input_size, device, dtype): + """Test MutoxClassifierBuilder initializes a model with correct configuration and dtype.""" + config = MutoxConfig(input_size=input_size) + builder = MutoxClassifierBuilder(config, device=device, dtype=dtype) + model = builder.build_model() + + # Check if model layers are correctly initialized with shapes + assert isinstance(model, nn.Module), "Model should be an instance of nn.Module" + assert all( + isinstance(layer, nn.Sequential) for layer in model.model_all.children() + ), "All layers should be instances of nn.Sequential" + + test_input = torch.zeros((5, input_size), device=device, dtype=dtype) + result = model(test_input) + assert result.shape == (5, 1), f"Expected output shape (5, 1), got {result.shape}" + + +@pytest.mark.parametrize("input_size", [256, 512]) +def test_create_mutox_model(input_size): + """Test create_mutox_model function to confirm it creates a model with the specified config.""" + config = MutoxConfig(input_size=input_size) + model = create_mutox_model(config, device=torch.device("cpu")) + + # Check if the created model has the expected structure and behavior + test_input = torch.zeros((3, input_size)) + result = model(test_input) + assert result.shape == (3, 1), f"Expected output shape (3, 1), got {result.shape}" + assert isinstance(model, nn.Module), "Model should be an instance of nn.Module" + + +# Classifier tests + + +def test_mutox_classifier_forward(): + """Test that MutoxClassifier forward pass returns expected output shape.""" + test_model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 1), + ) + model = MutoxClassifier(test_model) + + test_input = torch.randn(3, 10) + output = model(test_input) + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" + + +def test_mutox_classifier_forward_with_output_prob(): + """Test that MutoxClassifier forward pass applies sigmoid when output_prob=True.""" + test_model = nn.Sequential( + nn.Linear(10, 5), + nn.ReLU(), + nn.Linear(5, 1), + ) + model = MutoxClassifier(test_model) + + test_input = torch.randn(3, 10) + + output = model(test_input, output_prob=True) + + assert output.shape == ( + 3, + 1, + ), f"Expected output shape (3, 1), but instead got {output.shape}" + + assert (output >= 0).all() and ( + output <= 1 + ).all(), "Expected output values to be within the range [0, 1]" + + +def test_mutox_config(): + """Test that MutoxConfig stores the configuration for a model.""" + config = MutoxConfig(input_size=512) + assert ( + config.input_size == 512 + ), f"Config input_size should be 512, but got {config.input_size}" + + +# Loader tests + + +def test_convert_mutox_checkpoint(): + """Test convert_mutox_checkpoint correctly filters keys in the checkpoint.""" + checkpoint = { + "model_all.layer1.weight": torch.tensor([1.0]), + "model_all.layer1.bias": torch.tensor([0.5]), + "non_model_key": torch.tensor([3.0]), + } + config = MutoxConfig(input_size=1024) + converted = convert_mutox_checkpoint(checkpoint, config) + + # Verify only 'model_all.' keys are retained in the converted dictionary + assert "model" in converted, "Converted checkpoint should contain a 'model' key" + assert ( + "model_all.layer1.weight" in converted["model"] + ), "Expected 'model_all.layer1.weight'" + assert ( + "model_all.layer1.bias" in converted["model"] + ), "Expected 'model_all.layer1.bias'" + assert "non_model_key" not in converted["model"], "Unexpected 'non_model_key'"