From cb5089f6ecaa3976143aa43bce12fb0ad4f08c50 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Mar 2024 03:03:25 -0500 Subject: [PATCH] pt: improve nlist performance 1. use inv_ex instead of inv. `inv_ex` does not check errors. We can assume the input is correct. 2. copy nbuff from device to host once other than 6 times (although copying once is still slow); 3. avoid torch.tensor. Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/nlist.py | 20 ++++++++------------ deepmd/pt/utils/region.py | 2 +- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index cfc75d9438..7c8b1ae2b8 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -307,18 +307,14 @@ def extend_coord_with_ghosts( nbuff = torch.ceil(rcut / to_face).to(torch.long) # 3 nbuff = torch.max(nbuff, dim=0, keepdim=False).values - xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device) - yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device) - zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device) - xyz = xi.view(-1, 1, 1, 1) * torch.tensor( - [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( - [0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( - [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) + nbuff_cpu = nbuff.cpu() + xi = torch.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device=device) + yi = torch.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, device=device) + zi = torch.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1, device=device) + eye_3 = torch.eye(3, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device) + xyz = xi.view(-1, 1, 1, 1) * eye_3[0] + xyz = xyz + yi.view(1, -1, 1, 1) * eye_3[1] + xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2] xyz = xyz.view(-1, 3) # ns x 3 shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))] diff --git a/deepmd/pt/utils/region.py b/deepmd/pt/utils/region.py index b07d2f73bf..9d811acb9b 100644 --- a/deepmd/pt/utils/region.py +++ b/deepmd/pt/utils/region.py @@ -21,7 +21,7 @@ def phys2inter( the internal coordinates """ - rec_cell = torch.linalg.inv(cell) + rec_cell, _ = torch.linalg.inv_ex(cell) return torch.matmul(coord, rec_cell)