-
Notifications
You must be signed in to change notification settings - Fork 523
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: correct
exclude_types
in descriptors (#3841)
1. make `exclude_types` consistent with mask in nlist for all descriptors. (bugs fixed in dpa1 and dpa2) 2. add universal tests for descriptor. (now only test_exclude_types) 3. `TestCaseSingleFrameWithNlist` in source/tests/pt/model/test_env_mat.py will be removed in a seperate PR. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced compatibility parameter `ntypes` in several descriptor classes for improved flexibility. - Added new test case classes and utility functions for simulation testing of atomic models. - **Bug Fixes** - Enhanced logic for exclusion masks and neighbor lists in various descriptor methods to improve accuracy. - **Tests** - Added comprehensive tests for forward propagation and type exclusion in atomic models, ensuring consistency and correctness. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
Showing
16 changed files
with
450 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import numpy as np | ||
|
||
|
||
# originally copied from source/tests/pt/model/test_env_mat.py | ||
class TestCaseSingleFrameWithNlist: | ||
def setUp(self): | ||
# nloc == 3, nall == 4 | ||
self.nloc = 3 | ||
self.nall = 4 | ||
self.nf, self.nt = 2, 2 | ||
self.coord_ext = np.array( | ||
[ | ||
[0, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1], | ||
[0, -2, 0], | ||
], | ||
dtype=np.float64, | ||
).reshape([1, self.nall, 3]) | ||
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall]) | ||
self.mapping = np.array([0, 1, 2, 0], dtype=int).reshape([1, self.nall]) | ||
# sel = [5, 2] | ||
self.sel = [5, 2] | ||
self.sel_mix = [7] | ||
self.natoms = [3, 3, 2, 1] | ||
self.nlist = np.array( | ||
[ | ||
[1, 3, -1, -1, -1, 2, -1], | ||
[0, -1, -1, -1, -1, 2, -1], | ||
[0, 1, -1, -1, -1, -1, -1], | ||
], | ||
dtype=int, | ||
).reshape([1, self.nloc, sum(self.sel)]) | ||
self.rcut = 2.2 | ||
self.rcut_smth = 0.4 | ||
# permutations | ||
self.perm = np.array([2, 0, 1, 3], dtype=np.int32) | ||
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32) | ||
# permute the coord and atype | ||
self.coord_ext = np.concatenate( | ||
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 | ||
).reshape(self.nf, self.nall * 3) | ||
self.atype_ext = np.concatenate( | ||
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0 | ||
) | ||
self.mapping = np.concatenate( | ||
[self.mapping, self.mapping[:, self.perm]], axis=0 | ||
) | ||
|
||
# permute the nlist | ||
nlist1 = self.nlist[:, self.perm[: self.nloc], :] | ||
mask = nlist1 == -1 | ||
nlist1 = inv_perm[nlist1] | ||
nlist1 = np.where(mask, -1, nlist1) | ||
self.nlist = np.concatenate([self.nlist, nlist1], axis=0) | ||
self.atol = 1e-12 | ||
|
||
|
||
class TestCaseSingleFrameWithNlistWithVirtual: | ||
def setUp(self): | ||
# nloc == 3, nall == 4 | ||
self.nloc = 4 | ||
self.nall = 5 | ||
self.nf, self.nt = 2, 2 | ||
self.coord_ext = np.array( | ||
[ | ||
[0, 0, 0], | ||
[0, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1], | ||
[0, -2, 0], | ||
], | ||
dtype=np.float64, | ||
).reshape([1, self.nall, 3]) | ||
self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall]) | ||
# sel = [5, 2] | ||
self.sel = [5, 2] | ||
self.sel_mix = [7] | ||
self.natoms = [3, 3, 2, 1] | ||
self.nlist = np.array( | ||
[ | ||
[2, 4, -1, -1, -1, 3, -1], | ||
[-1, -1, -1, -1, -1, -1, -1], | ||
[0, -1, -1, -1, -1, 3, -1], | ||
[0, 2, -1, -1, -1, -1, -1], | ||
], | ||
dtype=int, | ||
).reshape([1, self.nloc, sum(self.sel)]) | ||
self.rcut = 2.2 | ||
self.rcut_smth = 0.4 | ||
# permutations | ||
self.perm = np.array([3, 0, 1, 2, 4], dtype=np.int32) | ||
inv_perm = np.argsort(self.perm) | ||
# permute the coord and atype | ||
self.coord_ext = np.concatenate( | ||
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 | ||
).reshape(self.nf, self.nall * 3) | ||
self.atype_ext = np.concatenate( | ||
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0 | ||
) | ||
# permute the nlist | ||
nlist1 = self.nlist[:, self.perm[: self.nloc], :] | ||
mask = nlist1 == -1 | ||
nlist1 = inv_perm[nlist1] | ||
nlist1 = np.where(mask, -1, nlist1) | ||
self.nlist = np.concatenate([self.nlist, nlist1], axis=0) | ||
self.get_real_mapping = np.array([[0, 2, 3], [0, 1, 3]], dtype=np.int32) | ||
self.atol = 1e-12 | ||
|
||
|
||
class TestCaseSingleFrameWithoutNlist: | ||
def setUp(self): | ||
# nloc == 3, nall == 4 | ||
self.nloc = 3 | ||
self.nf, self.nt = 1, 2 | ||
self.coord = np.array( | ||
[ | ||
[0, 0, 0], | ||
[0, 1, 0], | ||
[0, 0, 1], | ||
], | ||
dtype=np.float64, | ||
).reshape([1, self.nloc * 3]) | ||
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc]) | ||
self.cell = 2.0 * np.eye(3).reshape([1, 9]) | ||
# sel = [5, 2] | ||
self.sel = [16, 8] | ||
self.sel_mix = [24] | ||
self.natoms = [3, 3, 2, 1] | ||
self.rcut = 2.2 | ||
self.rcut_smth = 0.4 | ||
self.atol = 1e-12 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later |
11 changes: 11 additions & 0 deletions
11
source/tests/universal/common/cases/descriptor/descriptor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
|
||
from .utils import ( | ||
DescriptorTestCase, | ||
) | ||
|
||
|
||
class DescriptorTest(DescriptorTestCase): | ||
def setUp(self) -> None: | ||
DescriptorTestCase.setUp(self) |
Oops, something went wrong.