diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e5a7632ac4..5988de1cf2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -950,7 +950,7 @@ def get_data(self, is_train=True, task_key="Default"): batch_data = next(iter(self.validation_data[task_key])) for key in batch_data.keys(): - if key == "sid" or key == "fid": + if key == "sid" or key == "fid" or key == "box": continue elif not isinstance(batch_data[key], list): if batch_data[key] is not None: diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 7c8b1ae2b8..b60a2eed8f 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -27,14 +27,16 @@ def extend_input_and_build_neighbor_list( ): nframes, nloc = atype.shape[:2] if box is not None: + box_gpu = box.to(coord.device, non_blocking=True) coord_normalized = normalize_coord( coord.view(nframes, nloc, 3), - box.reshape(nframes, 3, 3), + box_gpu.reshape(nframes, 3, 3), ) else: + box_gpu = None coord_normalized = coord.clone() extended_coord, extended_atype, mapping = extend_coord_with_ghosts( - coord_normalized, atype, box, rcut + coord_normalized, atype, box_gpu, rcut, box ) nlist = build_neighbor_list( extended_coord, @@ -262,6 +264,7 @@ def extend_coord_with_ghosts( atype: torch.Tensor, cell: Optional[torch.Tensor], rcut: float, + cell_cpu: Optional[torch.Tensor] = None, ): """Extend the coordinates of the atoms by appending peridoc images. The number of images is large enough to ensure all the neighbors @@ -277,6 +280,8 @@ def extend_coord_with_ghosts( simulation cell tensor of shape [-1, 9]. rcut : float the cutoff radius + cell_cpu : torch.Tensor + cell on cpu for performance Returns ------- @@ -299,8 +304,9 @@ def extend_coord_with_ghosts( else: coord = coord.view([nf, nloc, 3]) cell = cell.view([nf, 3, 3]) + cell_cpu = cell_cpu.view([nf, 3, 3]) if cell_cpu is not None else cell # nf x 3 - to_face = to_face_distance(cell) + to_face = to_face_distance(cell_cpu) # nf x 3 # *2: ghost copies on + and - directions # +1: central cell