Skip to content

Commit

Permalink
access .is_leaf before .grad (#1210)
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored Sep 28, 2024
1 parent 4df9893 commit a15d586
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,7 +1951,7 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
dtype = dtypes.to_dtype(t.dtype)

grad = None
if t.grad is not None:
if t.is_leaf and t.grad is not None:
grad_pr = None
if history is not None:
attr_pr = ProvenanceRecord(inst=PseudoInst.CONSTANT, inputs=[], value="grad")
Expand Down
7 changes: 6 additions & 1 deletion thunder/tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
all_test_executors_and_dynamo = _all_test_executors() + [DynamoThunderExecutor]


# see https://docs.pytest.org/en/stable/how-to/capture-warnings.html#recwarn for the recwarn fixture
@instantiate(dtypes=(thunder.float32,), executors=all_test_executors_and_dynamo)
def test_nanogpt_complete(executor, device, dtype):
def test_nanogpt_complete(executor, device, dtype, recwarn):
tdtype = ttorch.to_torch_dtype(dtype)
make = partial(make_tensor, dtype=torch.int64, device=device)

Expand All @@ -46,6 +47,10 @@ def test_nanogpt_complete(executor, device, dtype):

assert_close(torch_result, thunder_result)

if recwarn:
for r in recwarn:
assert "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed." not in str(r.message)


# TODO Investigate grad inconsistency
# TODO: Add float16 and bfloat16 comparison tests here and to all other tests in
Expand Down

0 comments on commit a15d586

Please sign in to comment.