diff --git a/torchbenchmark/operators/low_mem_dropout/operator.py b/torchbenchmark/operators/low_mem_dropout/operator.py
index 82d50a6b0..d171eb8f4 100644
--- a/torchbenchmark/operators/low_mem_dropout/operator.py
+++ b/torchbenchmark/operators/low_mem_dropout/operator.py
@@ -38,7 +38,7 @@ def triton_dropout(self, p, x):
         n_elements = x.numel()
         grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
 
-        x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
+        x_keep = (torch.rand(size=(n_elements,)) > p).to(torch.int32).cuda()
 
         def _inner():
             return _triton_dropout[grid](
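
The change sizes the precomputed keep-mask to match the input: the kernel indexes `x_keep` with the same per-element offsets as `x`, so a hard-coded 10-element mask only covers the first 10 positions of a larger input. Below is a minimal, self-contained sketch of how such a mask-based dropout kernel and its launch typically fit together; the kernel body and the names `_dropout_sketch` / `dropout_sketch` are assumptions for illustration (the diff does not show `_triton_dropout` itself), following the standard Triton dropout pattern.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _dropout_sketch(x_ptr, x_keep_ptr, out_ptr, n_elements, p, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE chunk of the flattened input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # The keep-mask is loaded at the same offsets as x, so it must have
    # n_elements entries; a fixed size of 10 would not cover larger inputs.
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # Scale kept elements by 1 / (1 - p); zero out the dropped ones.
    out = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(out_ptr + offsets, out, mask=mask)


def dropout_sketch(x: torch.Tensor, p: float) -> torch.Tensor:
    n_elements = x.numel()
    out = torch.empty_like(x)
    # Keep-mask sized to the input, as in the patched benchmark code.
    x_keep = (torch.rand(size=(n_elements,)) > p).to(torch.int32).cuda()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _dropout_sketch[grid](x, x_keep, out, n_elements, p, BLOCK_SIZE=1024)
    return out
```

With the undersized mask, elements past index 9 would read whatever memory happens to follow the 10-element tensor, so the benchmark would not exercise (or validate) a correct dropout over the full input.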