From a24490c032a75e6ac31f7d33221e032552715eca Mon Sep 17 00:00:00 2001 From: nobu-g Date: Sat, 18 Nov 2023 22:57:51 +0900 Subject: [PATCH] fix a bug --- src/kwja/modules/functions/loss.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/kwja/modules/functions/loss.py b/src/kwja/modules/functions/loss.py index bdbebd80..893301c7 100644 --- a/src/kwja/modules/functions/loss.py +++ b/src/kwja/modules/functions/loss.py @@ -35,6 +35,7 @@ def compute_multi_label_token_mean_loss( if input_.isnan().any().item() is True: return torch.tensor(float("nan"), dtype=input_.dtype, device=input_.device) else: + target = torch.where(mask, target, torch.zeros_like(target)) losses = nn.functional.binary_cross_entropy(input_, target.float(), reduction="none") # (b, seq, num_features) # features の軸は和をとる losses = (losses * mask).sum(dim=2) # (b, seq)