diff --git a/tests/torch/test_lora.py b/tests/torch/test_lora.py index dfdff971b..03a38d929 100644 --- a/tests/torch/test_lora.py +++ b/tests/torch/test_lora.py @@ -459,7 +459,7 @@ def simple_loss_fn(logits, labels): x = torch.randn(5, 10) y = torch.randn(5, 10) - attention_mask = torch.randn(5, 10) + attention_mask = torch.randint(0, 2, (5, 10)) # Call forward with (input_ids, labels, attention_mask) loss, _ = lora_training((x, y, attention_mask)) @@ -495,7 +495,7 @@ def forward(self, x, attention_mask=None, labels=None): lora_training = LoraTraining(model) x = torch.randn(5, 10) y = torch.randn(5, 10) - attention_mask = torch.randn(5, 10) + attention_mask = torch.randint(0, 2, (5, 10)) loss, _ = lora_training((x, y, attention_mask)) assert isinstance(loss, torch.Tensor)