Commit
Merge branch 'devel' into chore/stat-refactor
anyangml authored Mar 29, 2024
2 parents b0c9801 + 23f67a1 commit e4545cf
Showing 14 changed files with 462 additions and 266 deletions.
81 changes: 20 additions & 61 deletions backend/dynamic_metadata.py
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
from pathlib import (
Path,
)
from typing import (
Dict,
List,
@@ -12,6 +16,11 @@
get_argument_from_env,
)

if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib

__all__ = ["dynamic_metadata"]


@@ -22,74 +31,24 @@ def __dir__() -> List[str]:
def dynamic_metadata(
field: str,
settings: Optional[Dict[str, object]] = None,
) -> str:
):
assert field in ["optional-dependencies", "entry-points", "scripts"]
_, _, find_libpython_requires, extra_scripts, tf_version = get_argument_from_env()
with Path("pyproject.toml").open("rb") as f:
pyproject = tomllib.load(f)

if field == "scripts":
return {
"dp": "deepmd.main:main",
**pyproject["tool"]["deepmd_build_backend"]["scripts"],
**extra_scripts,
}
elif field == "optional-dependencies":
optional_dependencies = pyproject["tool"]["deepmd_build_backend"][
"optional-dependencies"
]
optional_dependencies["lmp"].extend(find_libpython_requires)
optional_dependencies["ipi"].extend(find_libpython_requires)
return {
"test": [
"dpdata>=0.2.7",
"ase",
"pytest",
"pytest-cov",
"pytest-sugar",
"dpgui",
],
"docs": [
"sphinx>=3.1.1",
"sphinx_rtd_theme>=1.0.0rc1",
"sphinx_markdown_tables",
"myst-nb>=1.0.0rc0",
"myst-parser>=0.19.2",
"sphinx-design",
"breathe",
"exhale",
"numpydoc",
"ase",
"deepmodeling-sphinx>=0.1.0",
"dargs>=0.3.4",
"sphinx-argparse",
"pygments-lammps",
"sphinxcontrib-bibtex",
],
"lmp": [
"lammps~=2023.8.2.3.0",
*find_libpython_requires,
],
"ipi": [
"i-PI",
*find_libpython_requires,
],
"gui": [
"dpgui",
],
**optional_dependencies,
**get_tf_requirement(tf_version),
"cu11": [
"nvidia-cuda-runtime-cu11",
"nvidia-cublas-cu11",
"nvidia-cufft-cu11",
"nvidia-curand-cu11",
"nvidia-cusolver-cu11",
"nvidia-cusparse-cu11",
"nvidia-cudnn-cu11<9",
"nvidia-cuda-nvcc-cu11",
],
"cu12": [
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cufft-cu12",
"nvidia-curand-cu12",
"nvidia-cusolver-cu12",
"nvidia-cusparse-cu12",
"nvidia-cudnn-cu12<9",
"nvidia-cuda-nvcc-cu12",
],
"torch": [
"torch>=2a",
],
}
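
For context on the backend/dynamic_metadata.py change above: the hard-coded scripts and optional-dependencies tables are replaced by values read from pyproject.toml, parsed with the standard-library tomllib on Python 3.11+ and the tomli backport otherwise. The sketch below shows only that version-gated read pattern; the pyproject.toml path and the tool.deepmd_build_backend tables come from this diff, and the surrounding build-backend plumbing is omitted.

# Minimal sketch of the version-gated TOML read used above; everything here
# is standard tomllib/tomli usage, not deepmd-specific API.
import sys
from pathlib import Path

if sys.version_info >= (3, 11):
    import tomllib            # stdlib TOML parser (Python 3.11+)
else:
    import tomli as tomllib   # backport exposing the same load() API

with Path("pyproject.toml").open("rb") as f:   # tomllib requires binary mode
    pyproject = tomllib.load(f)

# The build backend then merges these tables with values computed from the
# environment (extra scripts, libpython requirements, TF/CUDA/torch pins).
scripts = pyproject["tool"]["deepmd_build_backend"]["scripts"]
optional = pyproject["tool"]["deepmd_build_backend"]["optional-dependencies"]
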
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
@@ -13,7 +13,7 @@
env,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
ActivationFn,
)


@@ -332,7 +332,7 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.do_bn_mode = do_bn_mode
self.bn_momentum = bn_momentum
self.act = get_activation_fn(activation_function)
self.act = ActivationFn(activation_function)
self.update_g1_has_grrg = update_g1_has_grrg
self.update_g1_has_drrd = update_g1_has_drrd
self.update_g1_has_conv = update_g1_has_conv
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
@@ -29,7 +29,7 @@
PairExcludeMask,
)
from deepmd.pt.utils.utils import (
get_activation_fn,
ActivationFn,
)
from deepmd.utils.env_mat_stat import (
StatItem,
@@ -117,7 +117,7 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.g1_dim = g1_dim
self.g2_dim = g2_dim
self.act = get_activation_fn(activation_function)
self.act = ActivationFn(activation_function)
self.direct_dist = direct_dist
self.add_type_ebd_to_seq = add_type_ebd_to_seq
# order matters, placed after the assignment of self.ntypes
7 changes: 3 additions & 4 deletions deepmd/pt/model/network/network.py
@@ -27,7 +27,6 @@

from deepmd.pt.utils.utils import (
ActivationFn,
get_activation_fn,
)


@@ -470,7 +469,7 @@ class MaskLMHead(nn.Module):
def __init__(self, embed_dim, output_dim, activation_fn, weight=None):
super().__init__()
self.dense = SimpleLinear(embed_dim, embed_dim)
self.activation_fn = get_activation_fn(activation_fn)
self.activation_fn = ActivationFn(activation_fn)
self.layer_norm = nn.LayerNorm(embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION)

if weight is None:
@@ -818,7 +817,7 @@ def __init__(
self.fc1 = nn.Linear(
self.embed_dim, self.ffn_embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
self.activation_fn = get_activation_fn(activation)
self.activation_fn = ActivationFn(activation)
self.fc2 = nn.Linear(
self.ffn_embed_dim, self.embed_dim, dtype=env.GLOBAL_PT_FLOAT_PRECISION
)
@@ -1387,7 +1386,7 @@ def __init__(
self.ffn_dim = ffn_dim
self.attn_head = attn_head
self.activation_fn = (
get_activation_fn(activation_fn) if activation_fn is not None else None
ActivationFn(activation_fn) if activation_fn is not None else None
)
self.post_ln = post_ln
self.self_attn_layer_norm = nn.LayerNorm(
78 changes: 41 additions & 37 deletions deepmd/pt/train/training.py
@@ -558,46 +558,50 @@ def update_single_finetune_params(
]
self.wrapper.load_state_dict(state_dict)

def single_model_finetune(
_model,
_model_params,
_sample_func,
):
old_type_map, new_type_map = (
_model_params["type_map"],
_model_params["new_type_map"],
)
if isinstance(_model, EnergyModel):
_model.change_out_bias(
_sample_func,
bias_adjust_mode=_model_params.get(
"bias_adjust_mode", "change-by-statistic"
),
origin_type_map=new_type_map,
full_type_map=old_type_map,
)
else:
# need to updated
pass
if finetune_model is not None:

# finetune
if not self.multi_task:
single_model_finetune(
self.model, model_params, self.get_sample_func
)
else:
for model_key in self.model_keys:
if model_key in self.finetune_links:
log.info(
f"Model branch {model_key} will be fine-tuned. This may take a long time..."
)
single_model_finetune(
self.model[model_key],
model_params["model_dict"][model_key],
self.get_sample_func[model_key],
def single_model_finetune(
_model,
_model_params,
_sample_func,
):
old_type_map, new_type_map = (
_model_params["type_map"],
_model_params["new_type_map"],
)
if isinstance(_model, EnergyModel):
_model.change_out_bias(
_sample_func,
bias_adjust_mode=_model_params.get(
"bias_adjust_mode", "change-by-statistic"
),
origin_type_map=new_type_map,
full_type_map=old_type_map,
)
else:
log.info(f"Model branch {model_key} will resume training.")
# need to updated
pass

# finetune
if not self.multi_task:
single_model_finetune(
self.model, model_params, self.get_sample_func
)
else:
for model_key in self.model_keys:
if model_key in self.finetune_links:
log.info(
f"Model branch {model_key} will be fine-tuned. This may take a long time..."
)
single_model_finetune(
self.model[model_key],
model_params["model_dict"][model_key],
self.get_sample_func[model_key],
)
else:
log.info(
f"Model branch {model_key} will resume training."
)

if init_frz_model is not None:
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
38 changes: 22 additions & 16 deletions deepmd/pt/utils/dataloader.py
@@ -14,6 +14,7 @@
)

import h5py
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing
@@ -106,29 +107,34 @@ def construct_dataset(system):

self.dataloaders = []
self.batch_sizes = []
for system in self.systems:
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
else:
rule = None
log.error("Unsupported batch size type")
for ii in self.systems:
ni = ii._natoms
bsi = rule // ni
if bsi * ni < rule:
bsi += 1
self.batch_sizes.append(bsi)
elif isinstance(batch_size, list):
self.batch_sizes = batch_size
else:
self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
assert len(self.systems) == len(self.batch_sizes)
for system, batch_size in zip(self.systems, self.batch_sizes):
if dist.is_initialized():
system_sampler = DistributedSampler(system)
self.sampler_list.append(system_sampler)
else:
system_sampler = None
if isinstance(batch_size, str):
if batch_size == "auto":
rule = 32
elif batch_size.startswith("auto:"):
rule = int(batch_size.split(":")[1])
else:
rule = None
log.error("Unsupported batch size type")
self.batch_size = rule // system._natoms
if self.batch_size * system._natoms < rule:
self.batch_size += 1
else:
self.batch_size = batch_size
self.batch_sizes.append(self.batch_size)
system_dataloader = DataLoader(
dataset=system,
batch_size=self.batch_size,
batch_size=int(batch_size),
num_workers=0, # Should be 0 to avoid too many threads forked
sampler=system_sampler,
collate_fn=collate_batch,
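
For context on the deepmd/pt/utils/dataloader.py change above: per-system batch sizes are now resolved once, before the per-system DataLoader loop — "auto" targets 32 atoms per batch, "auto:N" targets N, a list is used as given, and a plain integer is broadcast to every system. The helper below mirrors that rule; the name resolve_batch_sizes and its arguments are illustrative, not part of the deepmd API, and it raises where the refactored code logs an error instead.

# Hypothetical helper mirroring the batch-size resolution in the loop above.
from typing import List, Union

def resolve_batch_sizes(batch_size: Union[str, int, List[int]],
                        natoms_per_system: List[int]) -> List[int]:
    if isinstance(batch_size, str):
        # "auto" targets 32 atoms per batch, "auto:N" targets N atoms.
        if batch_size == "auto":
            rule = 32
        elif batch_size.startswith("auto:"):
            rule = int(batch_size.split(":")[1])
        else:
            raise ValueError(f"Unsupported batch size type: {batch_size}")
        # Ceiling division: enough frames so that frames * natoms >= rule.
        return [-(-rule // n) for n in natoms_per_system]
    if isinstance(batch_size, list):
        return batch_size                       # one explicit size per system
    return [int(batch_size)] * len(natoms_per_system)   # broadcast an integer

# Example: "auto:64" on systems with 10, 64, and 100 atoms per frame.
print(resolve_batch_sizes("auto:64", [10, 64, 100]))    # [7, 1, 1]
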
2 changes: 1 addition & 1 deletion deepmd/pt/utils/multi_task.py
@@ -115,7 +115,7 @@ def replace_one_item(params_dict, key_type, key_in_dict, suffix="", index=None):
type_map_keys.append(key_in_dict)
else:
if shared_key not in shared_links:
class_name = get_class_name(shared_type, shared_dict[key_in_dict])
class_name = get_class_name(shared_type, shared_dict[shared_key])
shared_links[shared_key] = {"type": class_name, "links": []}
link_item = {
"model_key": model_key,
21 changes: 0 additions & 21 deletions deepmd/pt/utils/utils.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Callable,
Optional,
overload,
)
@@ -18,26 +17,6 @@
from .env import PRECISION_DICT as PT_PRECISION_DICT


def get_activation_fn(activation: str) -> Callable:
"""Returns the activation function corresponding to `activation`."""
if activation.lower() == "relu":
return F.relu
elif activation.lower() == "gelu" or activation.lower() == "gelu_tf":
return lambda x: F.gelu(x, approximate="tanh")
elif activation.lower() == "tanh":
return torch.tanh
elif activation.lower() == "relu6":
return F.relu6
elif activation.lower() == "softplus":
return F.softplus
elif activation.lower() == "sigmoid":
return torch.sigmoid
elif activation.lower() == "linear" or activation.lower() == "none":
return lambda x: x
else:
raise RuntimeError(f"activation function {activation} not supported")


class ActivationFn(torch.nn.Module):
def __init__(self, activation: Optional[str]):
super().__init__()
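
For context on the deepmd/pt/utils/utils.py change above: the free function get_activation_fn is removed, and every call site touched by this commit now constructs ActivationFn, a torch.nn.Module kept in the same file (its body is truncated in this view). The sketch below is an illustrative module-style equivalent built from the deleted name-to-function mapping; it is not the actual ActivationFn implementation, so details may differ.

# Illustrative module-style replacement for the removed get_activation_fn.
from typing import Optional

import torch
import torch.nn.functional as F

class ActivationFnSketch(torch.nn.Module):
    def __init__(self, activation: Optional[str]):
        super().__init__()
        # Normalize the name once; None falls back to the identity.
        self.activation = activation.lower() if activation is not None else "linear"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation == "relu":
            return F.relu(x)
        if self.activation in ("gelu", "gelu_tf"):
            return F.gelu(x, approximate="tanh")
        if self.activation == "tanh":
            return torch.tanh(x)
        if self.activation == "relu6":
            return F.relu6(x)
        if self.activation == "softplus":
            return F.softplus(x)
        if self.activation == "sigmoid":
            return torch.sigmoid(x)
        if self.activation in ("linear", "none"):
            return x
        raise RuntimeError(f"activation function {self.activation} not supported")

Wrapping the choice in a module means the activation is registered as a submodule at the call sites changed above (e.g. self.act = ActivationFn(activation_function)) and plays well with TorchScript, which the bare lambdas returned by the removed helper generally do not.
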
(Diffs for the remaining 6 changed files are not shown.)
