diff --git a/dataset.py b/dataset.py
index 7aa175c..2e8e403 100644
--- a/dataset.py
+++ b/dataset.py
@@ -86,5 +86,4 @@ def __getitem__(self, idx):
         }
 
 def causal_mask(size):
-    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
-    return mask == 0
\ No newline at end of file
+    return torch.tril(torch.ones((1, size, size), dtype=torch.int))
\ No newline at end of file