From 8a11329576fff501ba6f5817c0746f40bfc3c614 Mon Sep 17 00:00:00 2001 From: SAE1RNG Date: Mon, 22 Apr 2024 17:52:45 +0200 Subject: [PATCH] Refactor causal mask generation to use torch.tril for simplicity --- dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dataset.py b/dataset.py index 7aa175c..2e8e403 100644 --- a/dataset.py +++ b/dataset.py @@ -86,5 +86,4 @@ def __getitem__(self, idx): } def causal_mask(size): - mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int) - return mask == 0 \ No newline at end of file + return torch.tril(torch.ones((1, size, size), dtype=torch.int)) \ No newline at end of file