[Enhance] Add unit tests for autocast with Ascend device (#1363)
6Vvv authored Sep 27, 2023
1 parent 88dc1e9 commit e9e08db
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions tests/test_runner/test_amp.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 
 import mmengine
-from mmengine.device import get_device, is_mlu_available
+from mmengine.device import get_device, is_mlu_available, is_npu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@
 class TestAmp(unittest.TestCase):
 
     def test_autocast(self):
-        if is_mlu_available():
+        if is_npu_available():
+            device = 'npu'
+            with autocast(device_type=device):
+                # torch.autocast support npu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif is_mlu_available():
             device = 'mlu'
             with autocast(device_type=device):
                 # torch.autocast support mlu mode.
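For reference, the pattern exercised by the new test corresponds to the usage sketched below. This is a minimal sketch, not part of the commit: it assumes an Ascend (NPU) build of PyTorch so that is_npu_available() returns True; on other hardware get_device() is used as a fallback, and the reduced-precision dtype produced under autocast depends on the device and PyTorch version.

# Minimal usage sketch of mmengine.runner.autocast on an Ascend/NPU device,
# mirroring the behaviour asserted in the new test branch.
import torch
import torch.nn as nn

from mmengine.device import get_device, is_npu_available
from mmengine.runner import autocast

device = 'npu' if is_npu_available() else get_device()
layer = nn.Conv2d(1, 1, 1).to(device)
data = torch.randn(1, 1, 1, 1).to(device)

# Autocast enabled: on supported devices the convolution output is expected
# to come back in a reduced-precision dtype (float16 or bfloat16).
with autocast(device_type=device):
    out = layer(data)

# Autocast disabled: computation stays in float32.
with autocast(enabled=False, device_type=device):
    out_fp32 = layer(data)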
