diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 3f5dede1ad..0724686ba3 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -64,6 +64,7 @@ def __init__( input_map=input_map, ) self.load_prefix = load_prefix + self.model_file = model_file # graph_compatable should be called after graph and prefix are set if not self._graph_compatable(): @@ -360,3 +361,44 @@ def eval_typeebd(self) -> np.ndarray: t_typeebd = self._get_tensor("t_typeebd:0") [typeebd] = run_sess(self.sess, [t_typeebd], feed_dict={}) return typeebd + + def update_typeebd(self, new_typeebd: np.ndarray, save_path: str): + """Change the type embedding of this model and then save to a new one. + + Parameters + ---------- + new_typeebd + The new type embedding to replace the old one. + save_path + The output file to save the new model. + + Examples + -------- + Change the type embedding of `graph.pb` and save the new model in `graph_new.pb`: + + >>> from deepmd.infer import DeepPotential + >>> dp = DeepPotential('graph.pb') + >>> new_tebd = dp.eval_typeebd() # or some np.ndarray has the same shape as new_tebd + >>> dp.update_typeebd(new_tebd, 'graph_new.pb') + """ + with tf.gfile.GFile(self.model_file, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + for node in graph_def.node: + if node.name in ["type_embed_net/matrix_1"]: + t_matrix = node.attr["value"].tensor + old_typeebd = tf.make_ndarray(t_matrix) + required_shape = (old_typeebd.shape[0] + 1, old_typeebd.shape[1]) + assert ( + required_shape == new_typeebd.shape + ), f"The input type embedding should has shape {required_shape}, but got {new_typeebd.shape} instead!" + new_typeebd = new_typeebd[:-1].astype(old_typeebd.dtype) + new_typeebd_tensor_pb = tf.make_tensor_proto(new_typeebd) + node.attr["value"].tensor.CopyFrom(new_typeebd_tensor_pb) + elif node.name in ["type_embed_net/bias_1"]: + t_bias = node.attr["value"].tensor + old_bias = tf.make_ndarray(t_bias) + new_bias_tensor_pb = tf.make_tensor_proto(np.zeros_like(old_bias)) + node.attr["value"].tensor.CopyFrom(new_bias_tensor_pb) + with tf.gfile.GFile(save_path, "wb") as f: + f.write(graph_def.SerializeToString())