From 960c99bfa339671bf4b274e7ee346c8550851cbe Mon Sep 17 00:00:00 2001 From: jfrery Date: Thu, 19 Dec 2024 12:06:41 +0100 Subject: [PATCH] chore: fix test for attention assertion --- tests/torch/test_lora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/torch/test_lora.py b/tests/torch/test_lora.py index dfdff971b..03a38d929 100644 --- a/tests/torch/test_lora.py +++ b/tests/torch/test_lora.py @@ -459,7 +459,7 @@ def simple_loss_fn(logits, labels): x = torch.randn(5, 10) y = torch.randn(5, 10) - attention_mask = torch.randn(5, 10) + attention_mask = torch.randint(0, 2, (5, 10)) # Call forward with (input_ids, labels, attention_mask) loss, _ = lora_training((x, y, attention_mask)) @@ -495,7 +495,7 @@ def forward(self, x, attention_mask=None, labels=None): lora_training = LoraTraining(model) x = torch.randn(5, 10) y = torch.randn(5, 10) - attention_mask = torch.randn(5, 10) + attention_mask = torch.randint(0, 2, (5, 10)) loss, _ = lora_training((x, y, attention_mask)) assert isinstance(loss, torch.Tensor)