Skip to content

Commit

Permalink
Add pt compress command line (#4300)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced model compression functionality, allowing users to compress
models directly via the command-line interface.
	- Added a new command option `"compress"` to trigger model compression.
- Enhanced help messages and examples for the `"compress"` command to
clarify usage with different backends.
- Added a comprehensive JSON configuration file for model compression
parameters.
- Improved handling of compression parameters within descriptor classes
for better organization and efficiency.

- **Bug Fixes**
- Improved error handling for unsupported file formats during model
loading.

- **Tests**
- Introduced a suite of unit tests to evaluate the functionality of
model compression, ensuring accuracy and performance across different
configurations.
- Enhanced tests for loading model parameters to ensure all required
attributes are correctly handled.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2024
1 parent 3701566 commit 0c5ab07
Show file tree
Hide file tree
Showing 25 changed files with 1,003 additions and 199 deletions.
13 changes: 7 additions & 6 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,29 +424,30 @@ def main_parser() -> argparse.ArgumentParser:
parser_compress = subparsers.add_parser(
"compress",
parents=[parser_log, parser_mpi_log],
help="(Supported backend: TensorFlow) compress a model",
help="Compress a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp compress
dp compress -i graph.pb -o compressed.pb
dp --tf compress -i frozen_model.pb -o compressed_model.pb
dp --pt compress -i frozen_model.pth -o compressed_model.pth
"""
),
)
parser_compress.add_argument(
"-i",
"--input",
default="frozen_model.pb",
default="frozen_model",
type=str,
help="The original frozen model, which will be compressed by the code",
help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-o",
"--output",
default="frozen_model_compressed.pb",
default="frozen_model_compressed",
type=str,
help="The compressed model",
help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
)
parser_compress.add_argument(
"-s",
Expand Down
31 changes: 31 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json

import torch

from deepmd.pt.model.model import (
get_model,
)


def enable_compression(
    input_file: str,
    output: str,
    stride: float = 0.01,
    extrapolate: int = 5,
    check_frequency: int = -1,
):
    """Compress a frozen TorchScript model and write the result to disk.

    The frozen model is loaded onto the CPU, a fresh model is rebuilt from
    the JSON stored in its ``model_def_script`` attribute, and the saved
    weights are copied into it. Compression is then enabled on the rebuilt
    model, which is scripted again and saved to ``output``.

    Parameters
    ----------
    input_file : str
        Path of the frozen TorchScript model to read.
    output : str
        Path where the compressed TorchScript model is written.
    stride : float
        Forwarded to ``model.enable_compression`` both as-is and multiplied
        by 10 (presumably the fine and coarse tabulation strides -- confirm
        against the model's ``enable_compression`` signature).
    extrapolate : int
        Forwarded as the first argument of ``model.enable_compression``.
    check_frequency : int
        Forwarded as the last argument of ``model.enable_compression``.
    """
    # Always deserialize on CPU, regardless of where the model was frozen.
    frozen = torch.jit.load(input_file, map_location="cpu")

    # Rebuild an eager model from the embedded definition, then restore weights.
    model = get_model(json.loads(frozen.model_def_script))
    model.load_state_dict(frozen.state_dict())

    # The second stride is fixed at ten times the user-supplied one.
    coarse_stride = stride * 10
    model.enable_compression(extrapolate, stride, coarse_stride, check_frequency)

    torch.jit.save(torch.jit.script(model), output)
21 changes: 19 additions & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from deepmd.pt.cxx_op import (
ENABLE_CUSTOMIZED_OP,
)
from deepmd.pt.entrypoints.compress import (
enable_compression,
)
from deepmd.pt.infer import (
inference,
)
Expand Down Expand Up @@ -346,10 +349,14 @@ def train(
# save min_nbor_dist
if min_nbor_dist is not None:
if not multi_task:
trainer.model.min_nbor_dist = min_nbor_dist
trainer.model.min_nbor_dist = torch.tensor(
min_nbor_dist, dtype=torch.float64, device=DEVICE
)
else:
for model_item in min_nbor_dist:
trainer.model[model_item].min_nbor_dist = min_nbor_dist[model_item]
trainer.model[model_item].min_nbor_dist = torch.tensor(
min_nbor_dist[model_item], dtype=torch.float64, device=DEVICE
)
trainer.run()


Expand Down Expand Up @@ -549,6 +556,16 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
model_branch=FLAGS.model_branch,
output=FLAGS.output,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
enable_compression(
input_file=FLAGS.input,
output=FLAGS.output,
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
88 changes: 53 additions & 35 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
import torch
import torch.nn as nn

from deepmd.dpmodel.utils.seed import (
child_seed,
Expand Down Expand Up @@ -437,10 +438,6 @@ def update_sel(
class DescrptBlockSeA(DescriptorBlock):
ndescrpt: Final[int]
__constants__: ClassVar[list] = ["ndescrpt"]
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]

def __init__(
self,
Expand Down Expand Up @@ -500,13 +497,6 @@ def __init__(
self.register_buffer("mean", mean)
self.register_buffer("stddev", stddev)

# add for compression
self.compress = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

ndim = 1 if self.type_one_side else 2
filter_layers = NetworkCollection(
ndim=ndim, ntypes=len(sel), network_type="embedding_network"
Expand All @@ -529,6 +519,21 @@ def __init__(
for param in self.parameters():
param.requires_grad = trainable

# add for compression
self.compress = False
self.compress_info = nn.ParameterList(
[
nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))
for _ in range(len(self.filter_layers.networks))
]
)
self.compress_data = nn.ParameterList(
[
nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))
for _ in range(len(self.filter_layers.networks))
]
)

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
Expand Down Expand Up @@ -667,16 +672,39 @@ def reinit_exclude(

def enable_compression(
self,
table_data,
table_config,
lower,
upper,
table_data: dict[str, torch.Tensor],
table_config: list[Union[int, float]],
lower: dict[str, int],
upper: dict[str, int],
) -> None:
for embedding_idx, ll in enumerate(self.filter_layers.networks):
if self.type_one_side:
ii = embedding_idx
ti = -1
else:
# ti: center atom type, ii: neighbor type...
ii = embedding_idx // self.ntypes
ti = embedding_idx % self.ntypes
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info_ii = torch.as_tensor(
[
lower[net],
upper[net],
upper[net] * table_config[0],
table_config[1],
table_config[2],
table_config[3],
],
dtype=self.prec,
device="cpu",
)
tensor_data_ii = table_data[net].to(device=env.DEVICE, dtype=self.prec)
self.compress_data[embedding_idx] = tensor_data_ii
self.compress_info[embedding_idx] = info_ii
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper

def forward(
self,
Expand Down Expand Up @@ -724,7 +752,9 @@ def forward(
)
# nfnl x nnei
exclude_mask = self.emask(nlist, extended_atype).view(nfnl, self.nnei)
for embedding_idx, ll in enumerate(self.filter_layers.networks):
for embedding_idx, (ll, compress_data_ii, compress_info_ii) in enumerate(
zip(self.filter_layers.networks, self.compress_data, self.compress_info)
):
if self.type_one_side:
ii = embedding_idx
ti = -1
Expand All @@ -751,23 +781,11 @@ def forward(
ss = rr[:, :, :1]

if self.compress:
if self.type_one_side:
net = "filter_-1_net_" + str(ii)
else:
net = "filter_" + str(ti) + "_net_" + str(ii)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf
tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)

gr = torch.ops.deepmd.tabulate_fusion_se_a(
tensor_data.contiguous(),
torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
compress_data_ii.contiguous(),
compress_info_ii.cpu().contiguous(),
ss.contiguous(),
rr.contiguous(),
self.filter_neuron[-1],
Expand Down
56 changes: 26 additions & 30 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ def tabulate_fusion_se_atten(

@DescriptorBlock.register("se_atten")
class DescrptBlockSeAtten(DescriptorBlock):
lower: dict[str, int]
upper: dict[str, int]
table_data: dict[str, torch.Tensor]
table_config: list[Union[int, float]]

def __init__(
self,
rcut: float,
Expand Down Expand Up @@ -202,14 +197,6 @@ def __init__(
ln_eps = 1e-5
self.ln_eps = ln_eps

# add for compression
self.compress = False
self.is_sorted = False
self.lower = {}
self.upper = {}
self.table_data = {}
self.table_config = []

if isinstance(sel, int):
sel = [sel]

Expand Down Expand Up @@ -282,6 +269,16 @@ def __init__(
self.filter_layers_strip = filter_layers_strip
self.stats = None

# add for compression
self.compress = False
self.is_sorted = False
self.compress_info = nn.ParameterList(
[nn.Parameter(torch.zeros(0, dtype=self.prec, device="cpu"))]
)
self.compress_data = nn.ParameterList(
[nn.Parameter(torch.zeros(0, dtype=self.prec, device=env.DEVICE))]
)

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return self.rcut
Expand Down Expand Up @@ -431,11 +428,21 @@ def enable_compression(
lower,
upper,
) -> None:
net = "filter_net"
self.compress_info[0] = torch.as_tensor(
[
lower[net],
upper[net],
upper[net] * table_config[0],
table_config[1],
table_config[2],
table_config[3],
],
dtype=self.prec,
device="cpu",
)
self.compress_data[0] = table_data[net].to(device=env.DEVICE, dtype=self.prec)
self.compress = True
self.table_data = table_data
self.table_config = table_config
self.lower = lower
self.upper = upper

def forward(
self,
Expand Down Expand Up @@ -544,15 +551,6 @@ def forward(
xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg)
elif self.tebd_input_mode in ["strip"]:
if self.compress:
net = "filter_net"
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
ss = ss.reshape(-1, 1)
# nfnl x nnei x ng
# gg_s = self.filter_layers.networks[0](ss)
Expand All @@ -569,14 +567,12 @@ def forward(
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
# gg = gg_s * gg_t + gg_s
tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec)
info_tensor = torch.tensor(info, dtype=self.prec, device="cpu")
gg_t = gg_t.reshape(-1, gg_t.size(-1))
# Convert all tensors to the required precision at once
ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t))
xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten(
tensor_data.contiguous(),
info_tensor.contiguous(),
self.compress_data[0].contiguous(),
self.compress_info[0].cpu().contiguous(),
ss.contiguous(),
rr.contiguous(),
gg_t.contiguous(),
Expand Down
Loading

0 comments on commit 0c5ab07

Please sign in to comment.