From 6d4a6cdfea582d1c5800524ab78e0fb0f290266a Mon Sep 17 00:00:00 2001 From: anyangml Date: Sun, 17 Nov 2024 15:01:18 +0800 Subject: [PATCH] fix: atomic stat concat, system with diff natoms --- deepmd/pt/utils/stat.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 2bcc87ce42..175167edcb 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -546,7 +546,11 @@ def compute_output_stats_atomic( ] for kk in keys } - # shape: (nframes, nloc, ndim) + # reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation + # reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation + natoms = {k: [sys_v.reshape(-1,1) for sys_v in v] for k, v in natoms.items()} + outputs = {k: [sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) for sys_idx, sys in enumerate(v)] for k, v in outputs.items()} + merged_output = { kk: to_numpy_array(torch.cat(outputs[kk])) for kk in keys