From e9e08dbb6571e92ef0dbb05792cdb8f2ab8c1d75 Mon Sep 17 00:00:00 2001
From: 6V <49866079+6Vvv@users.noreply.github.com>
Date: Wed, 27 Sep 2023 10:20:13 +0800
Subject: [PATCH] [Enhance] Add unit tests for autocast with Ascend device
 (#1363)

---
 tests/test_runner/test_amp.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index 89794f3414..a80c7f35cb 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -5,7 +5,7 @@
 import torch.nn as nn
 
 import mmengine
-from mmengine.device import get_device, is_mlu_available
+from mmengine.device import get_device, is_mlu_available, is_npu_available
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -14,7 +14,22 @@ class TestAmp(unittest.TestCase):
 
     def test_autocast(self):
-        if is_mlu_available():
+        if is_npu_available():
+            device = 'npu'
+            with autocast(device_type=device):
+                # torch.autocast support npu mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+                with autocast(enabled=False, device_type=device):
+                    res = layer(torch.randn(1, 1, 1, 1).to(device))
+                    self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+        elif is_mlu_available():
             device = 'mlu'
             with autocast(device_type=device):
                 # torch.autocast support mlu mode.
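
For reference, the behavior the new test asserts maps to the following usage pattern of mmengine.runner.autocast. This is a minimal sketch, assuming an NPU-enabled PyTorch build (e.g. torch_npu) with at least one visible Ascend device; all API names here (is_npu_available, autocast, its device_type/enabled parameters) appear in the patch itself:

    import torch
    import torch.nn as nn

    from mmengine.device import is_npu_available
    from mmengine.runner import autocast

    if is_npu_available():
        device = 'npu'
        layer = nn.Conv2d(1, 1, 1).to(device)
        data = torch.randn(1, 1, 1, 1).to(device)
        with autocast(device_type=device):
            # Under autocast, eligible ops such as conv run in reduced
            # precision (float16 or bfloat16, backend-dependent) -- the
            # property the assertIn check in the test verifies.
            out = layer(data)
        with autocast(enabled=False, device_type=device):
            # With autocast disabled, computation stays in float32.
            out = layer(data)

Note the dispatch order in the patched test: is_npu_available() is checked before is_mlu_available(), so the NPU branch takes precedence when both backends are present.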