Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pt compress commad line #4300

Merged
merged 33 commits into from
Nov 8, 2024
Merged
Changes from 6 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
13f8620
add pt compress commad line
cherryWangY Nov 1, 2024
cfca381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
28e0f95
seperate compress
cherryWangY Nov 2, 2024
5ba079f
add compress suffix support
cherryWangY Nov 2, 2024
6e20942
fix compress argument
cherryWangY Nov 2, 2024
33e6417
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
7985bc0
remove redundant code
cherryWangY Nov 2, 2024
2d0bbbc
add file suffix for compress
cherryWangY Nov 2, 2024
b9acf04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
0afd3f3
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 2, 2024
14ac57b
add pt command line end-to-end test
cherryWangY Nov 5, 2024
c91b8bb
add main for test
cherryWangY Nov 5, 2024
f087949
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 5, 2024
b186980
fix(pt): store `min_nbor_dist` in the state dict
njzjz Nov 5, 2024
d84e32c
Merge pull request #1 from njzjz/pt-fix-min-nbor-dist-state-dict
cherryWangY Nov 5, 2024
41b5c2b
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 5, 2024
0ed83bb
info type error version
cherryWangY Nov 6, 2024
232511e
add ParameterDict for compression
cherryWangY Nov 6, 2024
1e6d388
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 6, 2024
9d36d83
this seems work
njzjz Nov 7, 2024
73d7414
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 7, 2024
03ed3c7
vectorized tabulate
cherryWangY Nov 7, 2024
1463745
add ParameterList for pt compression command line
cherryWangY Nov 7, 2024
ce089b9
fix device error
cherryWangY Nov 7, 2024
7d8d190
try to fix min_nbor_list assign error
cherryWangY Nov 7, 2024
fe552fb
change compress examples and default value
cherryWangY Nov 7, 2024
b38a62e
remove redundant notaion
cherryWangY Nov 7, 2024
5dd73df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
2016a94
set default value and examples
cherryWangY Nov 8, 2024
d4fc7ef
remove redundant code
cherryWangY Nov 8, 2024
a84f161
Merge branch 'devel' into addCommandLine
njzjz Nov 8, 2024
33d5fd9
fix tests
njzjz Nov 8, 2024
dbae8d0
fix failing tests
njzjz Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions deepmd/main.py
Original file line number Diff line number Diff line change
@@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser:
parser_compress = subparsers.add_parser(
"compress",
parents=[parser_log, parser_mpi_log],
help="(Supported backend: TensorFlow) compress a model",
help="Compress a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp compress
dp compress -i graph.pb -o compressed.pb
dp --pt compress -i model.pth -o compressed.pth
"""
),
)
parser_compress.add_argument(
"-i",
"--input",
default="frozen_model.pb",
default="frozen_model",
type=str,
help="The original frozen model, which will be compressed by the code",
help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-o",
"--output",
default="frozen_model_compressed.pb",
default="frozen_model_compressed",
type=str,
help="The compressed model",
help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-s",
46 changes: 46 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import torch

from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)


def enable_compression(
input_file: str,
output: str,
stride: float = 0.01,
extrapolate: int = 5,
check_frequency: int = -1,
):
if input_file.endswith(".pth"):
saved_model = torch.jit.load(input_file, map_location="cpu")
model_def_script = json.loads(saved_model.model_def_script)
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
elif input_file.endswith(".pt"):
state_dict = torch.load(input_file, map_location="cpu", weights_only=True)
if "model" in state_dict:
state_dict = state_dict["model"]
model_def_script = state_dict["_extra_state"]["model_params"]
model = get_model(model_def_script)
modelwrapper = ModelWrapper(model)
modelwrapper.load_state_dict(state_dict)
model = modelwrapper.model["Default"]

Check warning on line 34 in deepmd/pt/entrypoints/compress.py

Codecov / codecov/patch

deepmd/pt/entrypoints/compress.py#L21-L34

Added lines #L21 - L34 were not covered by tests
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError("PyTorch backend only supports converting .pth or .pt file")

Check warning on line 36 in deepmd/pt/entrypoints/compress.py

Codecov / codecov/patch

deepmd/pt/entrypoints/compress.py#L36

Added line #L36 was not covered by tests

model.enable_compression(

Check warning on line 38 in deepmd/pt/entrypoints/compress.py

Codecov / codecov/patch

deepmd/pt/entrypoints/compress.py#L38

Added line #L38 was not covered by tests
extrapolate,
stride,
stride * 10,
check_frequency,
)

model = torch.jit.script(model)
torch.jit.save(model, output)

Check warning on line 46 in deepmd/pt/entrypoints/compress.py

Codecov / codecov/patch

deepmd/pt/entrypoints/compress.py#L45-L46

Added lines #L45 - L46 were not covered by tests
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
13 changes: 13 additions & 0 deletions deepmd/pt/entrypoints/main.py
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -38,6 +38,9 @@
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.entrypoints.compress import (
enable_compression,
)
from deepmd.pt.infer import (
inference,
)
@@ -549,6 +552,16 @@
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
enable_compression(

Check warning on line 558 in deepmd/pt/entrypoints/main.py

Codecov / codecov/patch

deepmd/pt/entrypoints/main.py#L555-L558

Added lines #L555 - L558 were not covered by tests
input_file=FLAGS.input,
output=FLAGS.output,
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")