Skip to content

Commit

Permalink
[Fix] Adapt to PyTorch v2.1 on Ascend (#1332)
Browse files Browse the repository at this point in the history
  • Loading branch information
LRJKD authored Sep 1, 2023
1 parent 762c9a2 commit 5671b53
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions mmengine/model/base_model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,17 @@ def to(self, *args, **kwargs) -> nn.Module:
# directly will cause the NPU to not be found.
# Here, the input parameters are processed to avoid errors.
if args and isinstance(args[0], str) and 'npu' in args[0]:
args = tuple(
[list(args)[0].replace('npu', torch.npu.native_device)])
import torch_npu
args = tuple([
list(args)[0].replace(
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')
])
if kwargs and 'npu' in str(kwargs.get('device', '')):
import torch_npu
kwargs['device'] = kwargs['device'].replace(
'npu', torch.npu.native_device)
'npu', torch_npu.npu.native_device if hasattr(
torch_npu.npu, 'native_device') else 'privateuseone')

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device is not None:
Expand Down

0 comments on commit 5671b53

Please sign in to comment.