Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pt compress commad line #4300

Merged
merged 33 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
13f8620
add pt compress commad line
cherryWangY Nov 1, 2024
cfca381
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
28e0f95
seperate compress
cherryWangY Nov 2, 2024
5ba079f
add compress suffix support
cherryWangY Nov 2, 2024
6e20942
fix compress argument
cherryWangY Nov 2, 2024
33e6417
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
7985bc0
remove redundant code
cherryWangY Nov 2, 2024
2d0bbbc
add file suffix for compress
cherryWangY Nov 2, 2024
b9acf04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2024
0afd3f3
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 2, 2024
14ac57b
add pt command line end-to-end test
cherryWangY Nov 5, 2024
c91b8bb
add main for test
cherryWangY Nov 5, 2024
f087949
Merge remote-tracking branch 'upstream/devel' into addCommandLine
cherryWangY Nov 5, 2024
b186980
fix(pt): store `min_nbor_dist` in the state dict
njzjz Nov 5, 2024
d84e32c
Merge pull request #1 from njzjz/pt-fix-min-nbor-dist-state-dict
cherryWangY Nov 5, 2024
41b5c2b
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 5, 2024
0ed83bb
info type error version
cherryWangY Nov 6, 2024
232511e
add ParameterDict for compression
cherryWangY Nov 6, 2024
1e6d388
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 6, 2024
9d36d83
this seems work
njzjz Nov 7, 2024
73d7414
Merge branch 'addCommandLine' of https://github.com/cherryWangY/deepm…
cherryWangY Nov 7, 2024
03ed3c7
vectorized tabulate
cherryWangY Nov 7, 2024
1463745
add ParameterList for pt compression command line
cherryWangY Nov 7, 2024
ce089b9
fix device error
cherryWangY Nov 7, 2024
7d8d190
try to fix min_nbor_list assign error
cherryWangY Nov 7, 2024
fe552fb
change compress examples and default value
cherryWangY Nov 7, 2024
b38a62e
remove redundant notaion
cherryWangY Nov 7, 2024
5dd73df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 7, 2024
2016a94
set default value and examples
cherryWangY Nov 8, 2024
d4fc7ef
remove redundant code
cherryWangY Nov 8, 2024
a84f161
Merge branch 'devel' into addCommandLine
njzjz Nov 8, 2024
33d5fd9
fix tests
njzjz Nov 8, 2024
dbae8d0
fix failing tests
njzjz Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser:
parser_compress = subparsers.add_parser(
"compress",
parents=[parser_log, parser_mpi_log],
help="(Supported backend: TensorFlow) compress a model",
help="Compress a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp compress
dp compress -i graph.pb -o compressed.pb
dp --tf compress -i frozen_model.pb -o compressed_model.pb
dp --pt compress -i frozen_model.pth -o compressed_model.pth
"""
),
)
parser_compress.add_argument(
"-i",
"--input",
default="frozen_model.pb",
default="frozen_model",
type=str,
help="The original frozen model, which will be compressed by the code",
help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-o",
"--output",
default="frozen_model_compressed.pb",
default="frozen_model_compressed",
type=str,
help="The compressed model",
help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-s",
Expand Down
31 changes: 31 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import torch

from deepmd.pt.model.model import (
get_model,
)


def enable_compression(
input_file: str,
output: str,
stride: float = 0.01,
extrapolate: int = 5,
check_frequency: int = -1,
):
saved_model = torch.jit.load(input_file, map_location="cpu")
model_def_script = json.loads(saved_model.model_def_script)
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())

model.enable_compression(
extrapolate,
stride,
stride * 10,
check_frequency,
)

model = torch.jit.script(model)
torch.jit.save(model, output)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
21 changes: 19 additions & 2 deletions deepmd/pt/entrypoints/main.py
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.entrypoints.compress import (
enable_compression,
)
from deepmd.pt.infer import (
inference,
)
Expand Down Expand Up @@ -346,10 +349,14 @@ def train(
# save min_nbor_dist
if min_nbor_dist is not None:
if not multi_task:
trainer.model.min_nbor_dist = min_nbor_dist
trainer.model.min_nbor_dist = torch.tensor(
min_nbor_dist, dtype=torch.float64, device=DEVICE
)
else:
for model_item in min_nbor_dist:
trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
trainer.model[model_item].min_nbor_dist = torch.tensor(
min_nbor_dist[model_item], dtype=torch.float64, device=DEVICE
)
trainer.run()


Expand Down Expand Up @@ -549,6 +556,16 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
enable_compression(
input_file=FLAGS.input,
output=FLAGS.output,
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
88 changes: 53 additions & 35 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
import torch.nn as nn

from deepmd.dpmodel.utils.seed import (
child_seed,
Expand Down Expand Up @@ -437,10 +438,6 @@ def update_sel(
class DescrptBlockSeA(DescriptorBlock):
ndescrpt: Final[int]
__constants__: ClassVar[list] = ["ndescrpt"]
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]

def __init__(
self,
Expand Down Expand Up @@ -500,13 +497,6 @@ def __init__(
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)

# add for compression
self.compress = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
Expand All @@ -529,6 +519,21 @@ def __init__(
for param in self.parameters():
param.requires_grad = trainable

# add for compression
self.compress = False
self.compress_info = nn.ParameterList(
[
nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))
for _ in range(len(self.filter_layers.networks))
]
)
self.compress_data = nn.ParameterList(
[
nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))
for _ in range(len(self.filter_layers.networks))
]
)

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
Expand Down Expand Up @@ -667,16 +672,39 @@ def reinit_exclude(

def enable_compression(
self,
table_data,
table_config,
lower,
upper,
table_data: dict[str, torch.Tensor],
table_config: list[Union[int, float]],
lower: dict[str, int],
upper: dict[str, int],
) -> None:
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
ti = -1
else:
# ti: center atom type, ii: neighbor type...
ii = embedding_idx // self.ntypes
ti = embedding_idx % self.ntypes
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info_ii = torch.as_tensor(
[
lower[net],
upper[net],
upper[net] * table_config[0],
table_config[1],
table_config[2],
table_config[3],
],
dtype=self.prec,
device="cpu",
)
tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec)
self.compress_data[embedding_idx] = tensor_data_ii
self.compress_info[embedding_idx] = info_ii
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper

def forward(
self,
Expand Down Expand Up @@ -724,7 +752,9 @@ def forward(
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)
for embedding_idx, ll in enumerate(self.filter_layers.networks):
for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate(
zip(self.filter_layers.networks, self.compress_data, self.compress_info)
):
if self.type_one_side:
ii = embedding_idx
ti = -1
Expand All @@ -751,23 +781,11 @@ def forward(
ss = rr[:, :, :1]

if self.compress:
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf
tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)

gr = torch.ops.deepmd.tabulate_fusion_se_a(
tensor_data.contiguous(),
torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
compress_data_ii.contiguous(),
compress_info_ii.cpu().contiguous(),
ss.contiguous(),
rr.contiguous(),
self.filter_neuron[-1],
Expand Down
56 changes: 26 additions & 30 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def tabulate_fusion_se_atten(

@DescriptorBlock.register("se_atten")
class DescrptBlockSeAtten(DescriptorBlock):
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]

def __init__(
self,
rcut: float,
Expand Down Expand Up @@ -202,14 +197,6 @@ def __init__(
ln_eps = 1e-5
self.ln_eps = ln_eps

# add for compression
self.compress = False
self.is_sorted = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

if isinstance(sel, int):
sel = [sel]

Expand Down Expand Up @@ -282,6 +269,16 @@ def __init__(
self.filter_layers_strip = filter_layers_strip
self.stats = None

# add for compression
self.compress = False
self.is_sorted = False
self.compress_info = nn.ParameterList(
[nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
)
self.compress_data = nn.ParameterList(
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
)

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
Expand Down Expand Up @@ -431,11 +428,21 @@ def enable_compression(
lower,
upper,
) -> None:
net = "filter_net"
self.compress_info[0] = torch.as_tensor(
[
lower[net],
upper[net],
upper[net] * table_config[0],
table_config[1],
table_config[2],
table_config[3],
],
dtype=self.prec,
device="cpu",
)
self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
cherryWangY marked this conversation as resolved.
Show resolved Hide resolved
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper

def forward(
self,
Expand Down Expand Up @@ -544,15 +551,6 @@ def forward(
xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg)
elif self.tebd_input_mode in ["strip"]:
if self.compress:
net = "filter_net"
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
ss = ss.reshape(-1, 1)
# nfnl x nnei x ng
# gg_s = self.filter_layers.networks[0](ss)
Expand All @@ -569,14 +567,12 @@ def forward(
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
# gg = gg_s * gg_t + gg_s
tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec)
info_tensor = torch.tensor(info, dtype=self.prec, device="cpu")
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
tensor_data.contiguous(),
info_tensor.contiguous(),
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
ss.contiguous(),
rr.contiguous(),
gg_t.contiguous(),
Expand Down
Loading
Loading