Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pt compress command line #4300

Merged
merged 33 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
13f8620
add pt compress command line
cherryWangY Nov 1, 2024
cfca381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
28e0f95
separate compress
cherryWangY Nov 2, 2024
5ba079f
add compress suffix support
cherryWangY Nov 2, 2024
6e20942
fix compress argument
cherryWangY Nov 2, 2024
33e6417
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
7985bc0
remove redundant code
cherryWangY Nov 2, 2024
2d0bbbc
add file suffix for compress
cherryWangY Nov 2, 2024
b9acf04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
0afd3f3
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 2, 2024
14ac57b
add pt command line end-to-end test
cherryWangY Nov 5, 2024
c91b8bb
add main for test
cherryWangY Nov 5, 2024
f087949
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 5, 2024
b186980
fix(pt): store `min_nbor_dist` in the state dict
njzjz Nov 5, 2024
d84e32c
Merge pull request #1 from njzjz/pt-fix-min-nbor-dist-state-dict
cherryWangY Nov 5, 2024
41b5c2b
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 5, 2024
0ed83bb
info type error version
cherryWangY Nov 6, 2024
232511e
add ParameterDict for compression
cherryWangY Nov 6, 2024
1e6d388
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 6, 2024
9d36d83
this seems to work
njzjz Nov 7, 2024
73d7414
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 7, 2024
03ed3c7
vectorized tabulate
cherryWangY Nov 7, 2024
1463745
add ParameterList for pt compression command line
cherryWangY Nov 7, 2024
ce089b9
fix device error
cherryWangY Nov 7, 2024
7d8d190
try to fix min_nbor_list assign error
cherryWangY Nov 7, 2024
fe552fb
change compress examples and default value
cherryWangY Nov 7, 2024
b38a62e
remove redundant notation
cherryWangY Nov 7, 2024
5dd73df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
2016a94
set default value and examples
cherryWangY Nov 8, 2024
d4fc7ef
remove redundant code
cherryWangY Nov 8, 2024
a84f161
Merge branch 'devel' into addCommandLine
njzjz Nov 8, 2024
33d5fd9
fix tests
njzjz Nov 8, 2024
dbae8d0
fix failing tests
njzjz Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser:
parser_compress = subparsers.add_parser(
"compress",
parents=[parser_log, parser_mpi_log],
help="(Supported backend: TensorFlow) compress a model",
help="Compress a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp compress
dp compress -i graph.pb -o compressed.pb
dp --pt compress -i model.pth -o compressed.pth
"""
),
)
parser_compress.add_argument(
"-i",
"--input",
default="frozen_model.pb",
default="frozen_model",
type=str,
help="The original frozen model, which will be compressed by the code",
help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-o",
"--output",
default="frozen_model_compressed.pb",
default="frozen_model_compressed",
type=str,
help="The compressed model",
help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-s",
Expand Down
31 changes: 31 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import torch

from deepmd.pt.model.model import (
get_model,
)


def enable_compression(
    input_file: str,
    output: str,
    stride: float = 0.01,
    extrapolate: int = 5,
    check_frequency: int = -1,
):
    """Compress a frozen PyTorch model and write the compressed model to disk.

    Parameters
    ----------
    input_file : str
        Path to the frozen (TorchScript) model to compress.
    output : str
        Path where the compressed TorchScript model is saved.
    stride : float, default 0.01
        Uniform stride of the first-stage tabulation; the second-stage
        stride is fixed at ten times this value.
    extrapolate : int, default 5
        Size of the extrapolation region of the tabulated model.
    check_frequency : int, default -1
        How often to check the tabulation overflow; -1 disables the check.
    """
    # Load the serialized model on CPU and rebuild an eager-mode model
    # from the model definition embedded in the TorchScript archive.
    jit_model = torch.jit.load(input_file, map_location="cpu")
    model_params = json.loads(jit_model.model_def_script)
    rebuilt = get_model(model_params)
    rebuilt.load_state_dict(jit_model.state_dict())

    # Enable tabulated compression on the rebuilt model; the second-stage
    # table stride is tied to 10x the first-stage stride.
    rebuilt.enable_compression(
        extrapolate,
        stride,
        stride * 10,
        check_frequency,
    )

    # Re-script the compressed model and serialize it to the output path.
    torch.jit.save(torch.jit.script(rebuilt), output)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
13 changes: 13 additions & 0 deletions deepmd/pt/entrypoints/main.py
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.entrypoints.compress import (
enable_compression,
)
from deepmd.pt.infer import (
inference,
)
Expand Down Expand Up @@ -549,6 +552,16 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
enable_compression(
input_file=FLAGS.input,
output=FLAGS.output,
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
elif args.command == "transfer":
transfer(**dict_args)
elif args.command == "compress":
dict_args["input"] = format_model_suffix(
dict_args["input"], preferred_backend=args.backend, strict_prefer=True
)
dict_args["output"] = format_model_suffix(
dict_args["output"], preferred_backend=args.backend, strict_prefer=True
)
compress(**dict_args)
elif args.command == "convert-from":
convert(**dict_args)
Expand Down
8 changes: 8 additions & 0 deletions source/tests/pt/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import pathlib
from typing import (
Optional,
Union,
Expand All @@ -7,6 +8,7 @@
import numpy as np
import torch

from deepmd.common import j_loader as dp_j_loader
from deepmd.main import (
main,
)
Expand All @@ -15,6 +17,12 @@
GLOBAL_PT_FLOAT_PRECISION,
)

# Absolute path of the directory containing this test module; used to
# resolve test data files independently of the current working directory.
tests_path = pathlib.Path(__file__).parent.absolute()


def j_loader(filename):
    """Load a test input file, resolving *filename* against the tests directory."""
    resolved = tests_path / filename
    return dp_j_loader(resolved)


def run_dp(cmd: str) -> int:
"""Run DP directly from the entry point instead of the subprocess.
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions source/tests/pt/model_compression/data/type.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
0
1
1
0
1
1
2 changes: 2 additions & 0 deletions source/tests/pt/model_compression/data/type_map.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
O
H
85 changes: 85 additions & 0 deletions source/tests/pt/model_compression/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
{
"_comment1": " model parameters",
"model": {
"type_map": [
"O",
"H"
],
"descriptor": {
"type": "se_e2_a",
"sel": [
46,
92
],
"rcut_smth": 0.50,
"rcut": 6.00,
"_comment": "N2=2N1, N2=N1, and otherwise can be tested",
"neuron": [
4,
8,
17,
17
],
"resnet_dt": false,
"axis_neuron": 16,
"seed": 1,
"_comment2": " that's all"
},
"fitting_net": {
"neuron": [
20,
20,
20
],
"resnet_dt": true,
"seed": 1,
"_comment3": " that's all"
},
"_comment4": " that's all"
},

"learning_rate": {
"type": "exp",
"decay_steps": 5000,
"start_lr": 0.001,
"stop_lr": 3.51e-8,
"_comment5": "that's all"
},

"loss": {
"type": "ener",
"start_pref_e": 0.02,
"limit_pref_e": 1,
"start_pref_f": 1000,
"limit_pref_f": 1,
"start_pref_v": 0,
"limit_pref_v": 0,
"_comment6": " that's all"
},

"training": {
"training_data": {
"systems": [
"model_compression/data"
],
"batch_size": "auto",
"_comment7": "that's all"
},
"validation_data": {
"systems": [
"model_compression/data"
],
"batch_size": 1,
"numb_btch": 3,
"_comment8": "that's all"
},
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
"numb_steps": 1,
"seed": 10,
"disp_file": "lcurve.out",
"disp_freq": 1,
"save_freq": 1,
"_comment9": "that's all"
},
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved

"_comment10": "that's all"
}
Loading
Loading