diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb index f511a68..4cbe982 100644 --- a/examples/mutox_example.ipynb +++ b/examples/mutox_example.ipynb @@ -19,7 +19,7 @@ "source": [ "# MUTOX toxicity classification\n", "\n", - "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that." + "Mutox enables toxicity scoring for speech and text using sonar embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid layer. This notebook demonstrates encoding speech and text into sonar embeddings and classifying their toxicity." ] }, { diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml index 760c40d..f1e60ee 100644 --- a/sonar/cards/sonar_mutox.yaml +++ b/sonar/cards/sonar_mutox.yaml @@ -4,13 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -This card is a duplicate of the original found at -[Facebook Research's Seamless Communication repository] -(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). -It is included here to prevent circular dependencies between the Seamless Communication -repository and this project. -""" +#This card is a duplicate of the original found at +#[Facebook Research's Seamless Communication repository] +#(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). 
+#It is included here to prevent circular dependencies between the Seamless Communication repository and this project. name: sonar_mutox model_type: mutox_classifier diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py index 535a9e6..04e7807 100644 --- a/sonar/inference_pipelines/mutox_speech.py +++ b/sonar/inference_pipelines/mutox_speech.py @@ -30,17 +30,19 @@ def __init__( device: Device = CPU_DEVICE, ) -> None: if isinstance(encoder, str): - model = self.load_model_from_name("sonar_mutox", encoder, device=device) + self.model = self.load_model_from_name("sonar_mutox", encoder, device=device) else: - model = encoder + self.model = encoder - super().__init__(model) + super().__init__(self.model) self.model.to(device).eval() - self.mutox_classifier = mutox_classifier.to(device).eval() if isinstance(mutox_classifier, str): - self.mutox_classifier = load_mutox_model(mutox_classifier, device=device,) + self.mutox_classifier = load_mutox_model( + mutox_classifier, + device=device, + ) else: self.mutox_classifier = mutox_classifier @@ -65,4 +67,8 @@ def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuild @torch.inference_mode() def _run_classifier(self, data: dict): - return self.mutox_classifier(data.sentence_embeddings) + sentence_embeddings = data.get("sentence_embeddings") + if sentence_embeddings is None: + raise ValueError("Missing sentence embeddings in the data.") + return self.mutox_classifier(sentence_embeddings) + diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index d9d86b0..f021c85 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -40,7 +40,7 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self, activation=nn.ReLU) -> MutoxClassifier: + def build_model(self, activation=nn.ReLU()) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), @@ -51,15 +51,9 @@ def
build_model(self, activation=nn.ReLU) -> MutoxClassifier: nn.Linear(512, 128), ) - if self.config.output_prob: - model_h3 = nn.Sequential(activation(), nn.Linear(128, 1), nn.Sigmoid()) - else: - model_h3 = nn.Sequential(activation(), nn.Linear(128, 1)) - model_all = nn.Sequential( model_h1, model_h2, - model_h3, ) return MutoxClassifier( diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index 1dcd360..ada2efe 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -21,7 +21,12 @@ def __init__( super().__init__() self.model_all = model_all - def forward(self, inputs: torch.Tensor) -> torch.Tensor: + def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor: + if output_prob: + self.model_all.add_module("sigmoid", nn.Sigmoid()) + else: + self.model_all.add_module("linear", nn.Linear(128, 1)) + return self.model_all(inputs) @@ -32,8 +37,5 @@ class MutoxConfig: # size of the input embedding supported by this model input_size: int - # add sigmoid as last layer to output probability - output_prob: bool = False - mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")