diff --git a/examples/mutox_example.ipynb b/examples/mutox_example.ipynb
index f511a68..4cbe982 100644
--- a/examples/mutox_example.ipynb
+++ b/examples/mutox_example.ipynb
@@ -19,7 +19,7 @@
    "source": [
     "# MUTOX toxicity classification\n",
     "\n",
-    "Mutox lets you score speech and text toxicity using a classifier that can score sonar embeddings. In this notebook, we provide an example of encoding speech and text and classifying that."
+    "Mutox enables toxicity scoring for speech and text using sonar embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid layer. This notebook demonstrates encoding speech and text into sonar embeddings and classifying their toxicity."
    ]
   },
   {
diff --git a/sonar/cards/sonar_mutox.yaml b/sonar/cards/sonar_mutox.yaml
index 760c40d..f1e60ee 100644
--- a/sonar/cards/sonar_mutox.yaml
+++ b/sonar/cards/sonar_mutox.yaml
@@ -4,13 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-This card is a duplicate of the original found at
-[Facebook Research's Seamless Communication repository]
-(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml).
-It is included here to prevent circular dependencies between the Seamless Communication
-repository and this project.
-"""
+#This card is a duplicate of the original found at
+#[Facebook Research's Seamless Communication repository]
+#(https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml).
+#It is included here to prevent circular dependencies between the Seamless Communication
 
 name: sonar_mutox
 model_type: mutox_classifier
diff --git a/sonar/inference_pipelines/mutox_speech.py b/sonar/inference_pipelines/mutox_speech.py
index 535a9e6..04e7807 100644
--- a/sonar/inference_pipelines/mutox_speech.py
+++ b/sonar/inference_pipelines/mutox_speech.py
@@ -30,17 +30,19 @@ def __init__(
         device: Device = CPU_DEVICE,
     ) -> None:
         if isinstance(encoder, str):
-            model = self.load_model_from_name("sonar_mutox", encoder, device=device)
+            self.model = self.load_model_from_name("sonar_mutox", encoder, device=device)
         else:
-            model = encoder
+            self.model = encoder
 
-        super().__init__(model)
+        super().__init__(self.model)
 
         self.model.to(device).eval()
-        self.mutox_classifier = mutox_classifier.to(device).eval()
 
         if isinstance(mutox_classifier, str):
-            self.mutox_classifier = load_mutox_model(mutox_classifier, device=device,)
+            self.mutox_classifier = load_mutox_model(
+                mutox_classifier,
+                device=device,
+            )
         else:
             self.mutox_classifier = mutox_classifier
 
@@ -65,4 +67,8 @@ def prebuild_pipeline(self, context: SpeechInferenceParams) -> DataPipelineBuild
 
     @torch.inference_mode()
     def _run_classifier(self, data: dict):
-        return self.mutox_classifier(data.sentence_embeddings)
+        sentence_embeddings = data.get("sentence_embeddings")
+        if sentence_embeddings is None:
+            raise ValueError("Missing sentence embeddings in the data.")
+        return self.mutox_classifier(sentence_embeddings)
+
diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py
index d9d86b0..f021c85 100644
--- a/sonar/models/mutox/builder.py
+++ b/sonar/models/mutox/builder.py
@@ -40,7 +40,7 @@ def __init__(
         self.config = config
         self.device, self.dtype = device, dtype
 
-    def build_model(self, activation=nn.ReLU) -> MutoxClassifier:
+    def build_model(self, activation=nn.ReLU()) -> MutoxClassifier:
         model_h1 = nn.Sequential(
             nn.Dropout(0.01),
             nn.Linear(self.config.input_size, 512),
@@ -51,15 +51,9 @@ def build_model(self, activation=nn.ReLU) -> MutoxClassifier:
             nn.Linear(512, 128),
         )
 
-        if self.config.output_prob:
-            model_h3 = nn.Sequential(activation(), nn.Linear(128, 1), nn.Sigmoid())
-        else:
-            model_h3 = nn.Sequential(activation(), nn.Linear(128, 1))
-
         model_all = nn.Sequential(
             model_h1,
             model_h2,
-            model_h3,
         )
 
         return MutoxClassifier(
diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py
index 1dcd360..ada2efe 100644
--- a/sonar/models/mutox/classifier.py
+++ b/sonar/models/mutox/classifier.py
@@ -21,7 +21,12 @@ def __init__(
         super().__init__()
         self.model_all = model_all
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor:
+        if output_prob:
+            self.model_all.add_module("sigmoid", nn.Sigmoid())
+        else:
+            self.model_all.add_module("linear", nn.Linear(128, 1))
+
         return self.model_all(inputs)
 
 
@@ -32,8 +37,5 @@ class MutoxConfig:
     # size of the input embedding supported by this model
     input_size: int
 
-    # add sigmoid as last layer to output probability
-    output_prob: bool = False
-
 
 mutox_archs = ArchitectureRegistry[MutoxConfig]("mutox_classifier")