From b78713b61de1efc725e72e954aa68e44393fde5f Mon Sep 17 00:00:00 2001 From: Alex Athorne Date: Fri, 22 Dec 2023 22:36:25 +0000 Subject: [PATCH] Minor fix --- src/pytfex/transformer/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytfex/transformer/moe.py b/src/pytfex/transformer/moe.py index 861e064..7d1c9f7 100644 --- a/src/pytfex/transformer/moe.py +++ b/src/pytfex/transformer/moe.py @@ -59,7 +59,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ex = x[batch_indices, indices] ex_pred = scores[:, :, None] * expert(ex) new_x[batch_indices, indices] += ex_pred - return x + return new_x def _compute_k(self, l: int) -> int: k = int((l * self.c) / self.num_experts)