diff --git a/extra/models/bert.py b/extra/models/bert.py index 80ab3a8d74fb3..c1288c8fa2310 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -64,7 +64,7 @@ def __call__(self, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, onehot = counter == masked_positions.unsqueeze(2).expand(*masked_positions.shape, output.shape[1]) h_masked = onehot @ output - h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked))) + h_masked = self.lm_norm(self.lm_transform_activation(self.lm_transform(h_masked).float())) lm_logits = self.lm_output(h_masked) + self.lm_output_bias return lm_logits, clsf_logits