Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support type embedding changing for frozen .pb #2827

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
input_map=input_map,
)
self.load_prefix = load_prefix
self.model_file = model_file

# graph_compatable should be called after graph and prefix are set
if not self._graph_compatable():
Expand Down Expand Up @@ -360,3 +361,44 @@
t_typeebd = self._get_tensor("t_typeebd:0")
[typeebd] = run_sess(self.sess, [t_typeebd], feed_dict={})
return typeebd

def update_typeebd(self, new_typeebd: np.ndarray, save_path: str):
"""Change the type embedding of this model and then save to a new one.

Parameters
----------
new_typeebd
The new type embedding to replace the old one.
save_path
The output file to save the new model.

Examples
--------
Change the type embedding of `graph.pb` and save the new model in `graph_new.pb`:

>>> from deepmd.infer import DeepPotential
>>> dp = DeepPotential('graph.pb')
>>> new_tebd = dp.eval_typeebd() # or some np.ndarray has the same shape as new_tebd
>>> dp.update_typeebd(new_tebd, 'graph_new.pb')
"""
with tf.gfile.GFile(self.model_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
for node in graph_def.node:
if node.name in ["type_embed_net/matrix_1"]:
t_matrix = node.attr["value"].tensor
old_typeebd = tf.make_ndarray(t_matrix)
required_shape = (old_typeebd.shape[0] + 1, old_typeebd.shape[1])
assert (

Check warning on line 392 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L384-L392

Added lines #L384 - L392 were not covered by tests
required_shape == new_typeebd.shape
), f"The input type embedding should has shape {required_shape}, but got {new_typeebd.shape} instead!"
new_typeebd = new_typeebd[:-1].astype(old_typeebd.dtype)
new_typeebd_tensor_pb = tf.make_tensor_proto(new_typeebd)
node.attr["value"].tensor.CopyFrom(new_typeebd_tensor_pb)
elif node.name in ["type_embed_net/bias_1"]:
t_bias = node.attr["value"].tensor
old_bias = tf.make_ndarray(t_bias)
new_bias_tensor_pb = tf.make_tensor_proto(np.zeros_like(old_bias))
node.attr["value"].tensor.CopyFrom(new_bias_tensor_pb)
with tf.gfile.GFile(save_path, "wb") as f:
f.write(graph_def.SerializeToString())

Check warning on line 404 in deepmd/infer/deep_eval.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L395-L404

Added lines #L395 - L404 were not covered by tests
Loading