diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e5a7632ac4..5988de1cf2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -950,7 +950,7 @@ def get_data(self, is_train=True, task_key="Default"): batch_data = next(iter(self.validation_data[task_key])) for key in batch_data.keys(): - if key == "sid" or key == "fid": + if key == "sid" or key == "fid" or key == "box": continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index cfc75d9438..56a062f1b8 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -27,14 +27,16 @@ def extend_input_and_build_neighbor_list( ): nframes, nloc = atype.shape[:2] if box is not None: + box_gpu = box.to(coord.device, non_blocking=True) coord_normalized = normalize_coord( coord.view(nframes, nloc, 3), - box.reshape(nframes, 3, 3), + box_gpu.reshape(nframes, 3, 3), ) else: + box_gpu = None coord_normalized = coord.clone() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, rcut + coord_normalized, atype, box_gpu, rcut, box ) nlist = build_neighbor_list( extended_coord, @@ -262,6 +264,7 @@ def extend_coord_with_ghosts( atype: torch.Tensor, cell: Optional[torch.Tensor], rcut: float, + cell_cpu: Optional[torch.Tensor] = None, ): """Extend the coordinates of the atoms by appending peridoc images. The number of images is large enough to ensure all the neighbors @@ -277,6 +280,8 @@ def extend_coord_with_ghosts( simulation cell tensor of shape [-1, 9]. rcut : float the cutoff radius + cell_cpu : torch.Tensor + cell on cpu for performance Returns ------- @@ -299,27 +304,25 @@ def extend_coord_with_ghosts( else: coord = coord.view([nf, nloc, 3]) cell = cell.view([nf, 3, 3]) + cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell # nf x 3 - to_face = to_face_distance(cell) + to_face = to_face_distance(cell_cpu) # nf x 3 # *2: ghost copies on + and - directions # +1: central cell nbuff = torch.ceil(rcut / to_face).to(torch.long) # 3 nbuff = torch.max(nbuff, dim=0, keepdim=False).values - xi = torch.arange(-nbuff[0], nbuff[0] + 1, 1, device=device) - yi = torch.arange(-nbuff[1], nbuff[1] + 1, 1, device=device) - zi = torch.arange(-nbuff[2], nbuff[2] + 1, 1, device=device) - xyz = xi.view(-1, 1, 1, 1) * torch.tensor( - [1, 0, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + yi.view(1, -1, 1, 1) * torch.tensor( - [0, 1, 0], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) - xyz = xyz + zi.view(1, 1, -1, 1) * torch.tensor( - [0, 0, 1], dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device - ) + nbuff_cpu = nbuff.cpu() + xi = torch.arange(-nbuff_cpu[0], nbuff_cpu[0] + 1, 1, device="cpu") + yi = torch.arange(-nbuff_cpu[1], nbuff_cpu[1] + 1, 1, device="cpu") + zi = torch.arange(-nbuff_cpu[2], nbuff_cpu[2] + 1, 1, device="cpu") + eye_3 = torch.eye(3, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device="cpu") + xyz = xi.view(-1, 1, 1, 1) * eye_3[0] + xyz = xyz + yi.view(1, -1, 1, 1) * eye_3[1] + xyz = xyz + zi.view(1, 1, -1, 1) * eye_3[2] xyz = xyz.view(-1, 3) + xyz = xyz.to(device=device, non_blocking=True) # ns x 3 shift_idx = xyz[torch.argsort(torch.norm(xyz, dim=1))] ns, _ = shift_idx.shape diff --git a/deepmd/pt/utils/region.py b/deepmd/pt/utils/region.py index b07d2f73bf..9d811acb9b 100644 --- a/deepmd/pt/utils/region.py +++ b/deepmd/pt/utils/region.py @@ -21,7 +21,7 @@ def phys2inter( the internal coordinates """ - rec_cell = torch.linalg.inv(cell) + rec_cell, _ = torch.linalg.inv_ex(cell) return torch.matmul(coord, rec_cell)