
Fix: LinearModel Stat #3575

Merged: 35 commits, Apr 8, 2024
Changes from 2 commits
Commits (35)
fdb6a3d
feat: jit export in linear model
anyangml Mar 21, 2024
a4201fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2024
c0c1dba
fix: revert atomic change
anyangml Mar 22, 2024
21dc19f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2024
a5cc8b1
Merge branch 'devel' into feat/linear-jit
anyangml Mar 22, 2024
cba5823
fix: revert change
anyangml Mar 22, 2024
83686cf
Merge branch 'devel' into feat/linear-jit
anyangml Mar 23, 2024
08e6053
fix: linear bias
anyangml Mar 24, 2024
83e8a5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
e57ac56
fix: index
anyangml Mar 24, 2024
abb2455
fix: UTs
anyangml Mar 25, 2024
7d8afcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2024
6307ad5
chore: revert changes
anyangml Mar 25, 2024
82f1c39
fix: placeholder
anyangml Mar 25, 2024
798cb3c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2024
0f6f0f2
fix: Jit
anyangml Mar 25, 2024
eb3919c
Merge branch 'devel' into feat/linear-jit
anyangml Mar 25, 2024
8d59871
Merge branch 'devel' into feat/linear-jit
anyangml Apr 7, 2024
6fda3d6
fix: remove get/set_out_bias
anyangml Apr 7, 2024
a5c6f2d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
cbffc2b
fix: log
anyangml Apr 7, 2024
a3ca59e
Merge branch 'devel' into feat/linear-jit
anyangml Apr 7, 2024
33b33be
fix:UTs
anyangml Apr 7, 2024
3b6f65c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
9f0979d
fix: UTs
anyangml Apr 7, 2024
919dad1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
89fe042
chore: remove bias_atom_e
anyangml Apr 7, 2024
90b8d56
chore:use forward_common_atomic
anyangml Apr 7, 2024
1ef891e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
f32614f
feat: add UTs
anyangml Apr 7, 2024
1b233e7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
97bb808
fix: precommit
anyangml Apr 7, 2024
2dec4cd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2024
6669939
fix: cuda
anyangml Apr 8, 2024
c8010bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
10 changes: 10 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
anyangml marked this conversation as resolved.
@@ -92,25 +92,31 @@
"""
return True

@torch.jit.export

Codecov warning: added line #L95 was not covered by tests
def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())

@torch.jit.export

Codecov warning: added line #L100 was not covered by tests
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

@torch.jit.export

Codecov warning: added line #L105 was not covered by tests
def get_model_rcuts(self) -> List[float]:
"""Get the cut-off radius for each individual models."""
return [model.get_rcut() for model in self.models]

@torch.jit.export

Codecov warning: added line #L110 was not covered by tests
def get_sel(self) -> List[int]:
return [max([model.get_nsel() for model in self.models])]

@torch.jit.export

Codecov warning: added line #L114 was not covered by tests
def get_model_nsels(self) -> List[int]:
"""Get the processed sels for each individual models. Not distinguishing types."""
return [model.get_nsel() for model in self.models]

@torch.jit.export

Codecov warning: added line #L119 was not covered by tests
def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]
@@ -289,15 +295,18 @@
for _ in range(nmodels)
]

@torch.jit.export

Codecov warning: added line #L298 was not covered by tests
def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
# tricky...
return max([model.get_dim_fparam() for model in self.models])

@torch.jit.export

Codecov warning: added line #L304 was not covered by tests
def get_dim_aparam(self) -> int:
"""Get the number (dimension) of atomic parameters of this atomic model."""
return max([model.get_dim_aparam() for model in self.models])

@torch.jit.export

Codecov warning: added line #L309 was not covered by tests
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

@@ -318,6 +327,7 @@
)
).tolist()

@torch.jit.export

Codecov warning: added line #L330 was not covered by tests
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

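For context on the decorators added above: `torch.jit.script` compiles only `forward` and the methods it calls, so getters like `get_rcut` must be marked with `@torch.jit.export` to remain callable on the scripted module. A minimal sketch of the pattern (the `ToyLinearModel` class and its attributes are hypothetical stand-ins, not the actual deepmd classes):

```python
from typing import List

import torch


class ToyLinearModel(torch.nn.Module):
    """Hypothetical stand-in for a combined (linear) atomic model."""

    def __init__(self, rcuts: List[float]) -> None:
        super().__init__()
        self.rcuts = rcuts

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def get_model_rcuts(self) -> List[float]:
        # Exported explicitly: forward() never calls this, so without the
        # decorator TorchScript would not compile it.
        return self.rcuts

    @torch.jit.export
    def get_rcut(self) -> float:
        # The combined model reports the largest cutoff among its submodels.
        return max(self.get_model_rcuts())


md = torch.jit.script(ToyLinearModel([4.0, 6.0]))
print(md.get_rcut())  # 6.0
```

The same reasoning applies to `get_type_map`, `get_sel`, and the other getters decorated in this diff.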
10 changes: 4 additions & 6 deletions source/tests/pt/model/test_linear_atomic_model.py
@@ -178,13 +178,11 @@ def test_self_consistency(self):

def test_jit(self):
md1 = torch.jit.script(self.md1)
-        # atomic model no more export methods
-        # self.assertEqual(md1.get_rcut(), self.rcut)
-        # self.assertEqual(md1.get_type_map(), ["foo", "bar"])
+        self.assertEqual(md1.get_rcut(), self.rcut)
+        self.assertEqual(md1.get_type_map(), ["foo", "bar"])
         md3 = torch.jit.script(self.md3)
-        # atomic model no more export methods
-        # self.assertEqual(md3.get_rcut(), self.rcut)
-        # self.assertEqual(md3.get_type_map(), ["foo", "bar"])
+        self.assertEqual(md3.get_rcut(), self.rcut)
+        self.assertEqual(md3.get_type_map(), ["foo", "bar"])


class TestRemmapMethod(unittest.TestCase):
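The re-enabled assertions in `test_jit` follow the standard pattern of scripting a module and checking its exported getters. A self-contained version of that pattern (the `Toy` class and its attribute values are hypothetical, chosen only to mirror the shape of the real test):

```python
import unittest
from typing import List

import torch


class Toy(torch.nn.Module):
    """Hypothetical module whose getters are exported for TorchScript."""

    def __init__(self) -> None:
        super().__init__()
        self.rcut = 5.0
        self.type_map = ["foo", "bar"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def get_rcut(self) -> float:
        return self.rcut

    @torch.jit.export
    def get_type_map(self) -> List[str]:
        return self.type_map


class TestJit(unittest.TestCase):
    def test_jit(self):
        md = torch.jit.script(Toy())
        # Exported methods remain callable on the scripted module.
        self.assertEqual(md.get_rcut(), 5.0)
        self.assertEqual(md.get_type_map(), ["foo", "bar"])
```

If the `@torch.jit.export` decorators were dropped, `torch.jit.script` would still succeed, but calling `md.get_rcut()` on the scripted module would raise, which is what the previously commented-out assertions were working around.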