[MPS] Add regression test for sync deadlock (pytorch#141296)
See pytorch#140725 (comment)
Running `torch.mps.synchronize()` after a Metal kernel resulted in an infinite wait inside `[_MTLCommandBuffer waitUntilCompleted]`:
```
(lldb) bt
* thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP
  * frame #0: 0x00000001aa919084 Metal`pthread_cond_wait + 12
    frame #1: 0x00000001aa78b1b4 Metal`-[_MTLCommandBuffer waitUntilCompleted] + 84
    frame #2: 0x00000001032bf358 libtorch_python.dylib`torch::mps::MPSModule_deviceSynchronize(_object*, _object*) + 40
    frame #3: 0x0000000100e94c20 Python`cfunction_vectorcall_NOARGS + 100
    frame #4: 0x0000000100e389b8 Python`PyObject_Vectorcall + 92
    frame #5: 0x0000000100f61e38 Python`_PyEval_EvalFrameDefault + 19040
    frame #6: 0x0000000100f5d180 Python`PyEval_EvalCode + 200
    frame #7: 0x0000000100fcd1a4 Python`run_eval_code_obj + 104
    frame #8: 0x0000000100fccbe4 Python`run_mod + 168
    frame #9: 0x0000000100fcb518 Python`pyrun_file + 164
    frame #10: 0x0000000100fca854 Python`_PyRun_SimpleFileObject + 256
    frame #11: 0x0000000100fca4e8 Python`_PyRun_AnyFileObject + 80
    frame #12: 0x0000000100ff2028 Python`pymain_run_file_obj + 164
    frame #13: 0x0000000100ff1ce4 Python`pymain_run_file + 72
    frame #14: 0x0000000100ff0f74 Python`Py_RunMain + 988
    frame #15: 0x0000000100ff1564 Python`pymain_main + 304
    frame #16: 0x0000000100ff1604 Python`Py_BytesMain + 40
    frame #17: 0x000000019f630274 dyld`start + 2840
```
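
A hang like the one in this backtrace blocks the process indefinitely instead of failing. The following is a minimal sketch, not part of this commit, of how a potentially deadlocking call could be bounded with a timeout. `call_with_timeout`, `healthy_sync`, and `stuck_sync` are hypothetical names; in a real run the callable would be `torch.mps.synchronize`. The approach only helps if the blocked call releases the GIL; otherwise a subprocess with a timeout would be needed.

```python
import threading
import time


def call_with_timeout(fn, timeout_s):
    """Run fn in a daemon thread; return True if it finished within timeout_s seconds."""
    done = threading.Event()

    def runner():
        fn()
        done.set()

    # Daemon thread, so a stuck call cannot keep the interpreter alive at exit.
    threading.Thread(target=runner, daemon=True).start()
    return done.wait(timeout_s)


# Hypothetical stand-ins for a healthy and a deadlocked synchronize call.
def healthy_sync():
    time.sleep(0.01)


def stuck_sync():
    threading.Event().wait()  # the event is never set, so this blocks forever


finished_ok = call_with_timeout(healthy_sync, timeout_s=1.0)
finished_stuck = call_with_timeout(stuck_sync, timeout_s=0.2)
```

With such a wrapper, a deadlock shows up as a `False` return value that a test can assert on, rather than as a stalled CI job.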

Pull Request resolved: pytorch#141296
Approved by: https://github.com/huydhn
malfet authored and pytorchmergebot committed Nov 22, 2024
1 parent 25c0b91 commit 65166d8
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions test/test_mps.py
@@ -8385,6 +8385,14 @@ def test_cumprod_dim_check(self):
        self.assertRaises(IndexError, lambda: x.cumprod(2))
        self.assertRaises(IndexError, lambda: x.cumprod(-3))

    def test_do_sync_thrice_its_all_right(self):
        # Regression test for https://github.com/pytorch/pytorch/commit/9bc9d4cdb4355a385a7d7959f07d04d1648d6904
        # That caused sync calls to deadlock
        x = torch.nextafter(torch.ones(1024, device='mps'), torch.zeros(1024, device='mps'))
        for _ in range(3):
            torch.mps.synchronize()
        self.assertLess(x.sum().item(), x.numel())

class TestLogical(TestCaseMPS):
    def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
        return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)
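
The final assertion in the new test relies on every element of `x` being slightly below 1, so that the sum falls short of `numel()`. A rough sketch of the same arithmetic in plain Python follows. It uses `math.nextafter` (Python 3.9+) on float64 values rather than float32 MPS tensors, an assumption made so it runs without a GPU; the variable names are illustrative only.

```python
import math

n = 1024
# Largest float64 strictly below 1.0 (the test itself uses float32 tensors on MPS).
just_below_one = math.nextafter(1.0, 0.0)
values = [just_below_one] * n

# math.fsum returns the correctly rounded sum, avoiding accumulation error.
total = math.fsum(values)

print(total < n)
```

Because the comparison depends on the tensor actually having been computed on the device, the check also confirms that the kernel produced a result before and after the repeated `synchronize()` calls.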