Skip to content

Commit

Permalink
Set SDPBackend explicitly for validation (#1030)
Browse files Browse the repository at this point in the history
  • Loading branch information
Priya2698 authored Aug 23, 2024
1 parent 11edba8 commit 7c1f94a
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,12 +1057,13 @@ def sdpa_fn(q, k, v, dropout_p, is_causal, scale):
ref_inp.requires_grad = True
ref_tensor_inputs.append(ref_inp)

with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
from torch.nn.attention import SDPBackend, sdpa_kernel

with torch.random.fork_rng(devices=[torch.cuda.current_device()]) and sdpa_kernel(SDPBackend.FLASH_ATTENTION):
ref_attn_out = sdpa_fn(*ref_tensor_inputs, *scalar_inputs)
ref_attn_out.backward(grad_out)

nv_outputs = (attn_out, q.grad, k.grad, v.grad)
ref_outputs = (ref_attn_out, *(inp.grad for inp in ref_tensor_inputs))

for nv_out, ref_out in zip(nv_outputs, ref_outputs):
torch.testing.assert_close(nv_out, ref_out)

0 comments on commit 7c1f94a

Please sign in to comment.