Skip to content

Commit

Permalink
[BE] use torch.amp.autocast instead of torch.cuda.amp.autocast (pytor…
Browse files Browse the repository at this point in the history
…ch#134291)

torch.cuda.amp.autocast / torch.cpu.amp.autocast are deprecated and spew a ton of warnings when these tests run. This PR: Update to just use torch.amp.autocast(device).

Note: this uncovers a bug in the test: when `device` is CUDA, it actually shows up as "cuda:0" - so previously, this test was _always_ using `torch.cpu.amp.autocast` even for `cuda` device. This PR fixes this, and uncovers additional bugs in `pinverse` and `linalg.pinv`; `linalg.pinv` was already failing before on CPU, but now the test also catches failures on CUDA, (and this PR adds to the skipped-test list).
Pull Request resolved: pytorch#134291
Approved by: https://github.com/YuqingJ
  • Loading branch information
davidberard98 authored and pytorchmergebot committed Aug 24, 2024
1 parent a106100 commit d433a60
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,7 @@ def test_refs_are_in_decomp_table(self, op):

# TODO: investigate/fix
fake_autocast_device_skips["cpu"] = {"linalg.pinv"}
fake_autocast_device_skips["cuda"] = {"linalg.pinv", "pinverse"}


dynamic_output_op_tests = (
Expand Down Expand Up @@ -2575,12 +2576,14 @@ def test_fake(self, device, dtype, op):

@ops(op_db, dtypes=OpDTypes.any_one)
def test_fake_autocast(self, device, dtype, op):
if op.name in fake_autocast_device_skips[device]:
device_type = torch.device(device).type
if op.name in fake_autocast_device_skips[device_type]:
self.skipTest("Skip failing test")
context = (
torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
)
self._test_fake_helper(device, dtype, op, context)

def context_fn():
return torch.amp.autocast(device_type)

self._test_fake_helper(device, dtype, op, context_fn)

def _test_fake_crossref_helper(self, device, dtype, op, context):
samples = op.sample_inputs(device, dtype, requires_grad=True)
Expand Down

0 comments on commit d433a60

Please sign in to comment.