Skip to content

Commit

Permalink
Merge branch 'devel' into USE_PT_PYTHON_LIBS
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Mar 26, 2024
2 parents 38b217d + 625e893 commit ba8317d
Show file tree
Hide file tree
Showing 13 changed files with 99 additions and 137 deletions.
31 changes: 19 additions & 12 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
Dict,
List,
Optional,
Set,
TypeVar,
Union,
get_args,
)

try:
Expand All @@ -45,23 +47,28 @@
"j_loader",
"expand_sys_str",
"get_np_precision",
"VALID_PRECISION",
"VALID_ACTIVATION",
]

# Names accepted for the floating-point precision option.
_PRECISION = Literal[
    "default",
    "float16",
    "float32",
    "float64",
]
# Names accepted for the activation-function option.
_ACTIVATION = Literal[
    "relu", "relu6", "softplus", "sigmoid", "tanh",
    "gelu", "gelu_tf", "none", "linear",
]
# Runtime sets holding the same names as the Literal types above; they are
# derived with typing.get_args (available since Python 3.8) rather than being
# spelled out a second time.
VALID_PRECISION: Set[_PRECISION] = {*get_args(_PRECISION)}
VALID_ACTIVATION: Set[_ACTIVATION] = {*get_args(_ACTIVATION)}

if TYPE_CHECKING:
_DICT_VAL = TypeVar("_DICT_VAL")
_PRECISION = Literal["default", "float16", "float32", "float64"]
_ACTIVATION = Literal[
"relu",
"relu6",
"softplus",
"sigmoid",
"tanh",
"gelu",
"gelu_tf",
"none",
"linear",
]
__all__.extend(
[
"_DICT_VAL",
Expand Down
12 changes: 12 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
abstractmethod,
)

import ml_dtypes
import numpy as np

from deepmd.common import (
VALID_PRECISION,
)
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
Expand All @@ -21,14 +25,22 @@
"int32": np.int32,
"int64": np.int64,
"default": GLOBAL_NP_FLOAT_PRECISION,
# NumPy doesn't have bfloat16 (and doesn't plan to add it).
# ml_dtypes is a solution, but it does not seem to support np.save/np.load.
# HDF5 doesn't support bfloat16 either (see https://forum.hdfgroup.org/t/11975)
"bfloat16": ml_dtypes.bfloat16,
}
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())

# Inverse of PRECISION_DICT: maps each dtype object back to its precision
# name (the assert below enforces that the two tables cover the same dtypes).
RESERVED_PRECISON_DICT = {
np.float16: "float16",
np.float32: "float32",
np.float64: "float64",
np.int32: "int32",
np.int64: "int64",
ml_dtypes.bfloat16: "bfloat16",
}
# Import-time consistency check: the dtypes named here must be exactly the
# dtypes that appear as values of PRECISION_DICT.
assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values())
# NOTE(review): presumably the precision name used when none is given -- confirm against callers.
DEFAULT_PRECISION = "float64"


Expand Down
10 changes: 6 additions & 4 deletions deepmd/pt/model/network/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
)
from deepmd.pt.utils.utils import (
ActivationFn,
to_numpy_array,
to_torch_tensor,
)

try:
Expand Down Expand Up @@ -151,9 +153,9 @@ def serialize(self) -> dict:
precision=self.precision,
)
nl.w, nl.b, nl.idt = (
self.matrix.detach().cpu().numpy(),
self.bias.detach().cpu().numpy() if self.bias is not None else None,
self.idt.detach().cpu().numpy() if self.idt is not None else None,
to_numpy_array(self.matrix),
to_numpy_array(self.bias),
to_numpy_array(self.idt),
)
return nl.serialize()

Expand All @@ -180,7 +182,7 @@ def deserialize(cls, data: dict) -> "MLPLayer":

def check_load_param(ss):
return (
nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device))
nn.Parameter(data=to_torch_tensor(nl[ss]))
if nl[ss] is not None
else None
)
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/utils/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import numpy as np
import torch

from deepmd.common import (
VALID_PRECISION,
)
from deepmd.env import (
GLOBAL_ENER_FLOAT_PRECISION,
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -40,20 +43,24 @@
"double": torch.float64,
"int32": torch.int32,
"int64": torch.int64,
"bfloat16": torch.bfloat16,
}
GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name]
GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[
np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name
]
PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())
# This reverse table cannot be generated automatically from PRECISION_DICT,
# so it is written out by hand.
# It maps each torch dtype back to its precision name.
RESERVED_PRECISON_DICT = {
torch.float16: "float16",
torch.float32: "float32",
torch.float64: "float64",
torch.int32: "int32",
torch.int64: "int64",
torch.bfloat16: "bfloat16",
}
# Import-time consistency check: the dtypes that appear as values of
# PRECISION_DICT must be exactly the dtypes named in the reverse table.
assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISON_DICT.keys())
# NOTE(review): presumably the precision name used when none is given -- confirm against callers.
DEFAULT_PRECISION = "float64"

# throw warnings if threads not set
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/utils/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def change_finetune_model_params(finetune_model, model_config, model_branch=""):
model_branch_from=model_branch,
)
finetune_links["Default"] = (
model_branch if finetune_from_multi_task else "Default"
model_config["model_branch_chosen"]
if finetune_from_multi_task
else "Default"
)
else:
assert model_branch == "", (
Expand Down
7 changes: 7 additions & 0 deletions deepmd/pt/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
overload,
)

import ml_dtypes
import numpy as np
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -85,6 +86,9 @@ def to_numpy_array(
prec = NP_PRECISION_DICT.get(prec, None)
if prec is None:
raise ValueError(f"unknown precision {xx.dtype}")
if xx.dtype == torch.bfloat16:
# https://github.com/pytorch/pytorch/issues/109873
xx = xx.float()
return xx.detach().cpu().numpy().astype(prec)


Expand All @@ -109,6 +113,9 @@ def to_torch_tensor(
prec = PT_PRECISION_DICT.get(prec, None)
if prec is None:
raise ValueError(f"unknown precision {xx.dtype}")
if xx.dtype == ml_dtypes.bfloat16:
# https://github.com/pytorch/pytorch/issues/109873
xx = xx.astype(np.float32)
return torch.tensor(xx, dtype=prec, device=DEVICE)


Expand Down
4 changes: 4 additions & 0 deletions deepmd/tf/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)

from deepmd.common import (
VALID_ACTIVATION,
VALID_PRECISION,
add_data_requirement,
data_requirement,
expand_sys_str,
Expand Down Expand Up @@ -69,6 +71,7 @@
"float64": tf.float64,
"bfloat16": tf.bfloat16,
}
assert VALID_PRECISION.issubset(PRECISION_DICT.keys())


def gelu(x: tf.Tensor) -> tf.Tensor:
Expand Down Expand Up @@ -138,6 +141,7 @@ def gelu_wrapper(x):
"linear": lambda x: x,
"none": lambda x: x,
}
assert VALID_ACTIVATION.issubset(ACTIVATION_FN_DICT.keys())


def get_activation_func(
Expand Down
111 changes: 25 additions & 86 deletions deepmd/tf/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,30 +98,6 @@ def get_tensor_by_name(model_file: str, tensor_name: str) -> tf.Tensor:
return get_tensor_by_name_from_graph(graph, tensor_name)


def get_tensor_by_type(node, data_type: np.dtype) -> tf.Tensor:
    """Read the value held by a graph node, choosing the field by data type.

    Parameters
    ----------
    node
        The given tensorflow graph node
    data_type
        The data type of the node

    Returns
    -------
    tf.Tensor
        The tensor value of the given node, built with ``np.array`` from
        ``node.double_val`` (float64) or ``node.float_val`` (float32)

    Raises
    ------
    RuntimeError
        If ``data_type`` is neither ``np.float64`` nor ``np.float32``
    """
    # Double precision values are read from the ``double_val`` field.
    if data_type == np.float64:
        return np.array(node.double_val)
    # Single precision values are read from the ``float_val`` field.
    if data_type == np.float32:
        return np.array(node.float_val)
    # Every other data type is rejected.
    raise RuntimeError("model compression does not support the half precision")


def get_pattern_nodes_from_graph_def(graph_def: tf.GraphDef, pattern: str) -> Dict:
"""Get the pattern nodes with the given tf.GraphDef object.
Expand Down Expand Up @@ -214,22 +190,10 @@ def get_embedding_net_variables_from_graph_def(
Dict
The embedding net variables within the given tf.GraphDef object
"""
embedding_net_variables = {}
embedding_net_nodes = get_embedding_net_nodes_from_graph_def(
graph_def, suffix=suffix
)
for item in embedding_net_nodes:
node = embedding_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(
node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype
)
else:
tensor_value = get_tensor_by_type(node, dtype)
embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return embedding_net_variables
return convert_tensor_to_ndarray_in_dict(embedding_net_nodes)


def get_extra_embedding_net_suffix(type_one_side: bool):
Expand Down Expand Up @@ -268,16 +232,7 @@ def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern:
The numpy array of the variable
"""
node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(
node.tensor_content,
dtype=tf.as_dtype(node.dtype).as_numpy_dtype,
)
else:
tensor_value = get_tensor_by_type(node, dtype)
return np.reshape(tensor_value, tensor_shape)
return tf.make_ndarray(node)


def get_extra_embedding_net_variables_from_graph_def(
Expand Down Expand Up @@ -403,20 +358,8 @@ def get_fitting_net_variables_from_graph_def(
Dict
The fitting net variables within the given tf.GraphDef object
"""
fitting_net_variables = {}
fitting_net_nodes = get_fitting_net_nodes_from_graph_def(graph_def, suffix=suffix)
for item in fitting_net_nodes:
node = fitting_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(
node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype
)
else:
tensor_value = get_tensor_by_type(node, dtype)
fitting_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return fitting_net_variables
return convert_tensor_to_ndarray_in_dict(fitting_net_nodes)


def get_fitting_net_variables(model_file: str, suffix: str = "") -> Dict:
Expand Down Expand Up @@ -487,22 +430,10 @@ def get_type_embedding_net_variables_from_graph_def(
Dict
The embedding net variables within the given tf.GraphDef object
"""
type_embedding_net_variables = {}
type_embedding_net_nodes = get_type_embedding_net_nodes_from_graph_def(
graph_def, suffix=suffix
)
for item in type_embedding_net_nodes:
node = type_embedding_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(
node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype
)
else:
tensor_value = get_tensor_by_type(node, dtype)
type_embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape)
return type_embedding_net_variables
return convert_tensor_to_ndarray_in_dict(type_embedding_net_nodes)


def get_attention_layer_nodes_from_graph_def(
Expand Down Expand Up @@ -556,19 +487,27 @@ def get_attention_layer_variables_from_graph_def(
Dict
The attention layer variables within the given tf.GraphDef object
"""
attention_layer_variables = {}
attention_layer_net_nodes = get_attention_layer_nodes_from_graph_def(
graph_def, suffix=suffix
)
for item in attention_layer_net_nodes:
node = attention_layer_net_nodes[item]
dtype = tf.as_dtype(node.dtype).as_numpy_dtype
tensor_shape = tf.TensorShape(node.tensor_shape).as_list()
if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
tensor_value = np.frombuffer(
node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype
)
else:
tensor_value = get_tensor_by_type(node, dtype)
attention_layer_variables[item] = np.reshape(tensor_value, tensor_shape)
return attention_layer_variables
return convert_tensor_to_ndarray_in_dict(attention_layer_net_nodes)


def convert_tensor_to_ndarray_in_dict(
    tensor_dict: Dict[str, tf.Tensor],
) -> Dict[str, np.ndarray]:
    """Replace each value of a dict with its ``tf.make_ndarray`` result.

    The conversion happens in place: the dict that is passed in is modified
    and that same object is returned.

    Parameters
    ----------
    tensor_dict : Dict[str, tf.Tensor]
        The input tensor dict

    Returns
    -------
    Dict[str, np.ndarray]
        The converted tensor dict
    """
    # NOTE(review): tf.make_ndarray is documented as taking a TensorProto;
    # confirm that the tf.Tensor annotation matches what callers pass in.
    # Iterate over a snapshot of the items so the loop never walks the dict
    # while its values are being reassigned.
    for name, tensor in tuple(tensor_dict.items()):
        tensor_dict[name] = tf.make_ndarray(tensor)
    return tensor_dict
Loading

0 comments on commit ba8317d

Please sign in to comment.