Skip to content

Commit

Permalink
use functions to store constants so it can be read by C++
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 3, 2024
1 parent 980b4a9 commit 8e216f5
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 27 deletions.
48 changes: 33 additions & 15 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,42 @@ def call_lower_with_fixed_do_atomic_virial(
tf_model.call_lower_atomic_virial = exported_whether_do_atomic_virial(
do_atomic_virial=True
)
# set other attributes
tf_model.type_map = tf.Variable(model.get_type_map(), dtype=tf.string)
tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.double)
tf_model.dim_fparam = tf.Variable(model.get_dim_fparam(), dtype=tf.int64)
tf_model.dim_aparam = tf.Variable(model.get_dim_aparam(), dtype=tf.int64)
tf_model.sel_type = tf.Variable(model.get_sel_type(), dtype=tf.int64)
tf_model.is_aparam_nall = tf.Variable(model.is_aparam_nall(), dtype=tf.bool)
tf_model.model_output_type = tf.Variable(
model.model_output_type(), dtype=tf.string
# set functions to export other attributes
tf_model.get_type_map = tf.function(
lambda: tf.constant(model.get_type_map(), dtype=tf.string)
)
tf_model.get_rcut = tf.function(
lambda: tf.constant(model.get_rcut(), dtype=tf.double)
)
tf_model.get_dim_fparam = tf.function(
lambda: tf.constant(model.get_dim_fparam(), dtype=tf.int64)
)
tf_model.get_dim_aparam = tf.function(
lambda: tf.constant(model.get_dim_aparam(), dtype=tf.int64)
)
tf_model.get_sel_type = tf.function(
lambda: tf.constant(model.get_sel_type(), dtype=tf.int64)
)
tf_model.is_aparam_nall = tf.function(
lambda: tf.constant(model.is_aparam_nall(), dtype=tf.bool)
)
tf_model.model_output_type = tf.function(
lambda: tf.constant(model.model_output_type(), dtype=tf.string)
)
tf_model.mixed_types = tf.function(
lambda: tf.constant(model.mixed_types(), dtype=tf.bool)
)
tf_model.mixed_types = tf.Variable(model.mixed_types(), dtype=tf.bool)
if model.get_min_nbor_dist() is not None:
tf_model.min_nbor_dist = tf.Variable(
model.get_min_nbor_dist(), dtype=tf.double
tf_model.get_min_nbor_dist = tf.function(
lambda: tf.constant(model.get_min_nbor_dist(), dtype=tf.double)
)
tf_model.get_sel = tf.function(
lambda: tf.constant(model.get_sel(), dtype=tf.int64)
)
tf_model.get_model_def_script = tf.function(
lambda: tf.constant(
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
)
tf_model.sel = tf.Variable(model.get_sel(), dtype=tf.int64)
tf_model.model_def_script = tf.Variable(
json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string
)
tf.saved_model.save(
tf_model,
Expand Down
24 changes: 12 additions & 12 deletions deepmd/jax/jax2tf/tfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,22 @@ def __init__(
self._call_lower_atomic_virial = jax2tf.call_tf(
self.model.call_lower_atomic_virial
)
self.type_map = decode_list_of_bytes(self.model.type_map.numpy().tolist())
self.rcut = self.model.rcut.numpy().item()
self.dim_fparam = self.model.dim_fparam.numpy().item()
self.dim_aparam = self.model.dim_aparam.numpy().item()
self.sel_type = self.model.sel_type.numpy().tolist()
self._is_aparam_nall = self.model.is_aparam_nall.numpy().item()
self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist())
self.rcut = self.model.get_rcut().numpy().item()
self.dim_fparam = self.model.get_dim_fparam().numpy().item()
self.dim_aparam = self.model.get_dim_aparam().numpy().item()
self.sel_type = self.model.get_sel_type().numpy().tolist()
self._is_aparam_nall = self.model.is_aparam_nall().numpy().item()
self._model_output_type = decode_list_of_bytes(
self.model.model_output_type.numpy().tolist()
self.model.model_output_type().numpy().tolist()
)
self._mixed_types = self.model.mixed_types.numpy().item()
if hasattr(self.model, "min_nbor_dist"):
self.min_nbor_dist = self.model.min_nbor_dist.numpy().item()
self._mixed_types = self.model.mixed_types().numpy().item()
if hasattr(self.model, "get_min_nbor_dist"):
self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item()
else:
self.min_nbor_dist = None
self.sel = self.model.sel.numpy().tolist()
self.model_def_script = self.model.model_def_script.numpy().decode()
self.sel = self.model.get_sel().numpy().tolist()
self.model_def_script = self.model.get_model_def_script().numpy().decode()

def __call__(
self,
Expand Down

0 comments on commit 8e216f5

Please sign in to comment.