Skip to content

Commit

Permalink
fix: atomic stat concat, system with diff natoms
Browse files Browse the repository at this point in the history
  • Loading branch information
anyangml committed Nov 17, 2024
1 parent 0ad4289 commit 6d4a6cd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,11 @@ def compute_output_stats_atomic(
]
for kk in keys
}
# shape: (nframes, nloc, ndim)
# reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation
# reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation
natoms = {k: [sys_v.reshape(-1,1) for sys_v in v] for k, v in natoms.items()}
outputs = {k: [sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) for sys_idx, sys in enumerate(v)] for k, v in outputs.items()}

merged_output = {
kk: to_numpy_array(torch.cat(outputs[kk]))
for kk in keys
Expand Down

0 comments on commit 6d4a6cd

Please sign in to comment.