From 08b3814a3849537f02fee78157ac5a145b5c34ce Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 26 Mar 2024 05:09:13 -0400 Subject: [PATCH 1/3] feat: Support bfloat16 and ensure valid precision and activation functions consistent everywhere (#3601) Fix #3553. To support bfloat16, - DP: As NumPy doesn't plan to support bfloat16, use `ml_dtypes.bfloat16`. The `ml_dtypes` package, developed by Google, is required. - TF: use `tf.make_ndarray` instead of manually extracting information from a tensor node. - PT: As ml_dtypes is not supported yet (https://github.com/pytorch/pytorch/issues/109873), use float32 as a bridge. --------- Signed-off-by: Jinzhe Zeng --- deepmd/common.py | 31 ++++-- deepmd/dpmodel/common.py | 12 ++ deepmd/pt/model/network/mlp.py | 10 +- deepmd/pt/utils/env.py | 7 ++ deepmd/pt/utils/utils.py | 7 ++ deepmd/tf/common.py | 4 + deepmd/tf/utils/graph.py | 111 +++++-------------- deepmd/utils/argcheck.py | 26 +---- pyproject.toml | 1 + source/tests/consistent/fitting/test_ener.py | 6 +- source/tests/consistent/test_activation.py | 15 +-- 11 files changed, 96 insertions(+), 134 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index 84f98c6318..098bb0ed11 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -17,8 +17,10 @@ Dict, List, Optional, + Set, TypeVar, Union, + get_args, ) try: @@ -45,23 +47,28 @@ "j_loader", "expand_sys_str", "get_np_precision", + "VALID_PRECISION", + "VALID_ACTIVATION", ] +_PRECISION = Literal["default", "float16", "float32", "float64"] +_ACTIVATION = Literal[ + "relu", + "relu6", + "softplus", + "sigmoid", + "tanh", + "gelu", + "gelu_tf", + "none", + "linear", +] +# get_args is new in py38 +VALID_PRECISION: Set[_PRECISION] = set(get_args(_PRECISION)) +VALID_ACTIVATION: Set[_ACTIVATION] = set(get_args(_ACTIVATION)) if TYPE_CHECKING: _DICT_VAL = TypeVar("_DICT_VAL") - _PRECISION = Literal["default", "float16", "float32", "float64"] - _ACTIVATION = Literal[ - "relu", - "relu6", - "softplus", - "sigmoid", - "tanh", - "gelu", - "gelu_tf", - "none", - "linear", - ] __all__.extend( [ "_DICT_VAL", diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index 761db2f6aa..8030432385 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -4,8 +4,12 @@ abstractmethod, ) +import ml_dtypes import numpy as np +from deepmd.common import ( + VALID_PRECISION, +) from deepmd.env import ( GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, @@ -21,14 +25,22 @@ "int32": np.int32, "int64": np.int64, "default": GLOBAL_NP_FLOAT_PRECISION, + # NumPy doesn't have bfloat16 (and does't plan to add) + # ml_dtypes is a solution, but it seems not supporting np.save/np.load + # hdf5 hasn't supported bfloat16 as well (see https://forum.hdfgroup.org/t/11975) + "bfloat16": ml_dtypes.bfloat16, } +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) + RESERVED_PRECISON_DICT = { np.float16: "float16", np.float32: "float32", np.float64: "float64", np.int32: "int32", np.int64: "int64", + ml_dtypes.bfloat16: "bfloat16", } +assert set(RESERVED_PRECISON_DICT.keys()) == set(PRECISION_DICT.values()) DEFAULT_PRECISION = "float64" diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 4af1d00df8..762461111e 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -30,6 +30,8 @@ ) from deepmd.pt.utils.utils import ( ActivationFn, + to_numpy_array, + to_torch_tensor, ) try: @@ -151,9 +153,9 @@ def serialize(self) -> dict: precision=self.precision, ) nl.w, nl.b, nl.idt = ( - self.matrix.detach().cpu().numpy(), - self.bias.detach().cpu().numpy() if self.bias is not None else None, - self.idt.detach().cpu().numpy() if self.idt is not None else None, + to_numpy_array(self.matrix), + to_numpy_array(self.bias), + to_numpy_array(self.idt), ) return nl.serialize() @@ -180,7 +182,7 @@ def deserialize(cls, data: dict) -> "MLPLayer": def check_load_param(ss): return ( - nn.Parameter(data=torch.tensor(nl[ss], dtype=prec, device=device)) + nn.Parameter(data=to_torch_tensor(nl[ss])) if nl[ss] is not None else None ) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 0b92953255..d841a9b73c 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -4,6 +4,9 @@ import numpy as np import torch +from deepmd.common import ( + VALID_PRECISION, +) from deepmd.env import ( GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, @@ -40,12 +43,14 @@ "double": torch.float64, "int32": torch.int32, "int64": torch.int64, + "bfloat16": torch.bfloat16, } GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[ np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name ] PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) # cannot automatically generated RESERVED_PRECISON_DICT = { torch.float16: "float16", @@ -53,7 +58,9 @@ torch.float64: "float64", torch.int32: "int32", torch.int64: "int64", + torch.bfloat16: "bfloat16", } +assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISON_DICT.keys()) DEFAULT_PRECISION = "float64" # throw warnings if threads not set diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py index ac251648af..4b64a7231c 100644 --- a/deepmd/pt/utils/utils.py +++ b/deepmd/pt/utils/utils.py @@ -5,6 +5,7 @@ overload, ) +import ml_dtypes import numpy as np import torch import torch.nn.functional as F @@ -85,6 +86,9 @@ def to_numpy_array( prec = NP_PRECISION_DICT.get(prec, None) if prec is None: raise ValueError(f"unknown precision {xx.dtype}") + if xx.dtype == torch.bfloat16: + # https://github.com/pytorch/pytorch/issues/109873 + xx = xx.float() return xx.detach().cpu().numpy().astype(prec) @@ -109,6 +113,9 @@ def to_torch_tensor( prec = PT_PRECISION_DICT.get(prec, None) if prec is None: raise ValueError(f"unknown precision {xx.dtype}") + if xx.dtype == ml_dtypes.bfloat16: + # https://github.com/pytorch/pytorch/issues/109873 + xx = xx.astype(np.float32) return torch.tensor(xx, dtype=prec, device=DEVICE) diff --git a/deepmd/tf/common.py b/deepmd/tf/common.py index 0d59990a29..5f2d0d882e 100644 --- a/deepmd/tf/common.py +++ b/deepmd/tf/common.py @@ -18,6 +18,8 @@ ) from deepmd.common import ( + VALID_ACTIVATION, + VALID_PRECISION, add_data_requirement, data_requirement, expand_sys_str, @@ -69,6 +71,7 @@ "float64": tf.float64, "bfloat16": tf.bfloat16, } +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) def gelu(x: tf.Tensor) -> tf.Tensor: @@ -138,6 +141,7 @@ def gelu_wrapper(x): "linear": lambda x: x, "none": lambda x: x, } +assert VALID_ACTIVATION.issubset(ACTIVATION_FN_DICT.keys()) def get_activation_func( diff --git a/deepmd/tf/utils/graph.py b/deepmd/tf/utils/graph.py index 3ed43343fa..65f4a743f5 100644 --- a/deepmd/tf/utils/graph.py +++ b/deepmd/tf/utils/graph.py @@ -98,30 +98,6 @@ def get_tensor_by_name(model_file: str, tensor_name: str) -> tf.Tensor: return get_tensor_by_name_from_graph(graph, tensor_name) -def get_tensor_by_type(node, data_type: np.dtype) -> tf.Tensor: - """Get the tensor value within the given node according to the input data_type. - - Parameters - ---------- - node - The given tensorflow graph node - data_type - The data type of the node - - Returns - ------- - tf.Tensor - The tensor value of the given node - """ - if data_type == np.float64: - tensor = np.array(node.double_val) - elif data_type == np.float32: - tensor = np.array(node.float_val) - else: - raise RuntimeError("model compression does not support the half precision") - return tensor - - def get_pattern_nodes_from_graph_def(graph_def: tf.GraphDef, pattern: str) -> Dict: """Get the pattern nodes with the given tf.GraphDef object. @@ -214,22 +190,10 @@ def get_embedding_net_variables_from_graph_def( Dict The embedding net variables within the given tf.GraphDef object """ - embedding_net_variables = {} embedding_net_nodes = get_embedding_net_nodes_from_graph_def( graph_def, suffix=suffix ) - for item in embedding_net_nodes: - node = embedding_net_nodes[item] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape) - return embedding_net_variables + return convert_tensor_to_ndarray_in_dict(embedding_net_nodes) def get_extra_embedding_net_suffix(type_one_side: bool): @@ -268,16 +232,7 @@ def get_variables_from_graph_def_as_numpy_array(graph_def: tf.GraphDef, pattern: The numpy array of the variable """ node = get_pattern_nodes_from_graph_def(graph_def, pattern)[pattern] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, - dtype=tf.as_dtype(node.dtype).as_numpy_dtype, - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - return np.reshape(tensor_value, tensor_shape) + return tf.make_ndarray(node) def get_extra_embedding_net_variables_from_graph_def( @@ -403,20 +358,8 @@ def get_fitting_net_variables_from_graph_def( Dict The fitting net variables within the given tf.GraphDef object """ - fitting_net_variables = {} fitting_net_nodes = get_fitting_net_nodes_from_graph_def(graph_def, suffix=suffix) - for item in fitting_net_nodes: - node = fitting_net_nodes[item] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - fitting_net_variables[item] = np.reshape(tensor_value, tensor_shape) - return fitting_net_variables + return convert_tensor_to_ndarray_in_dict(fitting_net_nodes) def get_fitting_net_variables(model_file: str, suffix: str = "") -> Dict: @@ -487,22 +430,10 @@ def get_type_embedding_net_variables_from_graph_def( Dict The embedding net variables within the given tf.GraphDef object """ - type_embedding_net_variables = {} type_embedding_net_nodes = get_type_embedding_net_nodes_from_graph_def( graph_def, suffix=suffix ) - for item in type_embedding_net_nodes: - node = type_embedding_net_nodes[item] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - type_embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape) - return type_embedding_net_variables + return convert_tensor_to_ndarray_in_dict(type_embedding_net_nodes) def get_attention_layer_nodes_from_graph_def( @@ -556,19 +487,27 @@ def get_attention_layer_variables_from_graph_def( Dict The attention layer variables within the given tf.GraphDef object """ - attention_layer_variables = {} attention_layer_net_nodes = get_attention_layer_nodes_from_graph_def( graph_def, suffix=suffix ) - for item in attention_layer_net_nodes: - node = attention_layer_net_nodes[item] - dtype = tf.as_dtype(node.dtype).as_numpy_dtype - tensor_shape = tf.TensorShape(node.tensor_shape).as_list() - if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor_value = np.frombuffer( - node.tensor_content, dtype=tf.as_dtype(node.dtype).as_numpy_dtype - ) - else: - tensor_value = get_tensor_by_type(node, dtype) - attention_layer_variables[item] = np.reshape(tensor_value, tensor_shape) - return attention_layer_variables + return convert_tensor_to_ndarray_in_dict(attention_layer_net_nodes) + + +def convert_tensor_to_ndarray_in_dict( + tensor_dict: Dict[str, tf.Tensor], +) -> Dict[str, np.ndarray]: + """Convert tensor to ndarray in dict. + + Parameters + ---------- + tensor_dict : Dict[str, tf.Tensor] + The input tensor dict + + Returns + ------- + Dict[str, np.ndarray] + The converted tensor dict + """ + for key in tensor_dict: + tensor_dict[key] = tf.make_ndarray(tensor_dict[key]) + return tensor_dict diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 564039ccd0..2a98bee6fe 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -14,6 +14,10 @@ dargs, ) +from deepmd.common import ( + VALID_ACTIVATION, + VALID_PRECISION, +) from deepmd.utils.argcheck_nvnmd import ( nvnmd_args, ) @@ -24,26 +28,8 @@ log = logging.getLogger(__name__) -# TODO: import from a module outside tf/pt -ACTIVATION_FN_DICT = { - "relu": None, - "relu6": None, - "softplus": None, - "sigmoid": None, - "tanh": None, - "gelu": None, - "gelu_tf": None, - "None": None, - "none": None, -} -# TODO: import from a module outside tf/pt -PRECISION_DICT = { - "default": None, - "float16": None, - "float32": None, - "float64": None, - "bfloat16": None, -} +ACTIVATION_FN_DICT = dict.fromkeys(VALID_ACTIVATION) +PRECISION_DICT = dict.fromkeys(VALID_PRECISION) doc_only_tf_supported = "(Supported Backend: TensorFlow) " doc_only_pt_supported = "(Supported Backend: PyTorch) " diff --git a/pyproject.toml b/pyproject.toml index 128364249a..0e449d46af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ 'h5py', 'wcmatch', 'packaging', + 'ml_dtypes', ] requires-python = ">=3.8" keywords = ["deepmd"] diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index ab314cb9af..157b1bab8a 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -40,7 +40,7 @@ @parameterized( (True, False), # resnet_dt - ("float64", "float32"), # precision + ("float64", "float32", "bfloat16"), # precision (True, False), # mixed_types (0, 1), # numb_fparam ([], [-12345.6, None]), # atom_ener @@ -178,6 +178,8 @@ def rtol(self) -> float: return 1e-10 elif precision == "float32": return 1e-4 + elif precision == "bfloat16": + return 1e-1 else: raise ValueError(f"Unknown precision: {precision}") @@ -195,5 +197,7 @@ def atol(self) -> float: return 1e-10 elif precision == "float32": return 1e-4 + elif precision == "bfloat16": + return 1e-1 else: raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index bb06df9082..83b8494729 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -3,6 +3,9 @@ import numpy as np +from deepmd.common import ( + VALID_ACTIVATION, +) from deepmd.dpmodel.utils.network import get_activation_fn as get_activation_fn_dp from .common import ( @@ -25,17 +28,7 @@ @parameterized( - ( - "Relu", - "Relu6", - "Softplus", - "Sigmoid", - "Tanh", - "Gelu", - "Gelu_tf", - "Linear", - "None", - ), + tuple([x.capitalize() for x in VALID_ACTIVATION]), ) class TestActivationFunctionConsistent(unittest.TestCase): def setUp(self): From 3b5b8057c6ff84ea8f2faeea69bcbe42ad466189 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:17:42 +0800 Subject: [PATCH 2/3] pt: fix typo in multitask finetune (#3607) fix #3604 , when doing single-task finetuning from multitask pretrained model and do not define the finetune model branch from command-line or input file. --- deepmd/pt/utils/finetune.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/utils/finetune.py b/deepmd/pt/utils/finetune.py index 3f76454442..2de4214070 100644 --- a/deepmd/pt/utils/finetune.py +++ b/deepmd/pt/utils/finetune.py @@ -146,7 +146,9 @@ def change_finetune_model_params(finetune_model, model_config, model_branch=""): model_branch_from=model_branch, ) finetune_links["Default"] = ( - model_branch if finetune_from_multi_task else "Default" + model_config["model_branch_chosen"] + if finetune_from_multi_task + else "Default" ) else: assert model_branch == "", ( From 625e8939d06b6e3c8c62264927835a9dfe7e857b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 26 Mar 2024 08:26:47 -0400 Subject: [PATCH 3/3] chore: remove incorrect memset TODOs (#3600) Fix #3556. Signed-off-by: Jinzhe Zeng --- source/lib/src/gpu/prod_env_mat.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/lib/src/gpu/prod_env_mat.cu b/source/lib/src/gpu/prod_env_mat.cu index a69e014272..e8909edb44 100644 --- a/source/lib/src/gpu/prod_env_mat.cu +++ b/source/lib/src/gpu/prod_env_mat.cu @@ -486,7 +486,6 @@ __global__ void compute_env_mat_a(FPTYPE* em, std[type[bid] * ndescrpt + idx_value + ii]; } } else { - // TODO: move it to the memset. row_descript[idx_value] -= avg[type[bid] * ndescrpt + idx_value] / std[type[bid] * ndescrpt + idx_value]; } @@ -562,7 +561,6 @@ __global__ void compute_env_mat_r(FPTYPE* em, row_em[idx_value] = (dd - avg[type[bid] * ndescrpt + idx_value]) / std[type[bid] * ndescrpt + idx_value]; } else { - // TODO: move it to the memset. row_em[idx_value] -= avg[type[bid] * ndescrpt + idx_value] / std[type[bid] * ndescrpt + idx_value]; }