Skip to content

Commit

Permalink
chore: fix test for attention assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Dec 19, 2024
1 parent abfb76a commit 960c99b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/torch/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def simple_loss_fn(logits, labels):

x = torch.randn(5, 10)
y = torch.randn(5, 10)
attention_mask = torch.randn(5, 10)
attention_mask = torch.randint(0, 2, (5, 10))

# Call forward with (input_ids, labels, attention_mask)
loss, _ = lora_training((x, y, attention_mask))
Expand Down Expand Up @@ -495,7 +495,7 @@ def forward(self, x, attention_mask=None, labels=None):
lora_training = LoraTraining(model)
x = torch.randn(5, 10)
y = torch.randn(5, 10)
attention_mask = torch.randn(5, 10)
attention_mask = torch.randint(0, 2, (5, 10))

loss, _ = lora_training((x, y, attention_mask))
assert isinstance(loss, torch.Tensor)
Expand Down

0 comments on commit 960c99b

Please sign in to comment.