From 8e216f5e082ccff398138f2d0a45ce7075dc05fe Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 16:54:35 -0500 Subject: [PATCH] use functions to store constants so it can be read by C++ Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/serialization.py | 48 ++++++++++++++++++++---------- deepmd/jax/jax2tf/tfmodel.py | 24 +++++++-------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 869aa8edeb..50e5f7c7be 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -71,24 +71,42 @@ def call_lower_with_fixed_do_atomic_virial( tf_model.call_lower_atomic_virial = exported_whether_do_atomic_virial( do_atomic_virial=True ) - # set other attributes - tf_model.type_map = tf.Variable(model.get_type_map(), dtype=tf.string) - tf_model.rcut = tf.Variable(model.get_rcut(), dtype=tf.double) - tf_model.dim_fparam = tf.Variable(model.get_dim_fparam(), dtype=tf.int64) - tf_model.dim_aparam = tf.Variable(model.get_dim_aparam(), dtype=tf.int64) - tf_model.sel_type = tf.Variable(model.get_sel_type(), dtype=tf.int64) - tf_model.is_aparam_nall = tf.Variable(model.is_aparam_nall(), dtype=tf.bool) - tf_model.model_output_type = tf.Variable( - model.model_output_type(), dtype=tf.string + # set functions to export other attributes + tf_model.get_type_map = tf.function( + lambda: tf.constant(model.get_type_map(), dtype=tf.string) + ) + tf_model.get_rcut = tf.function( + lambda: tf.constant(model.get_rcut(), dtype=tf.double) + ) + tf_model.get_dim_fparam = tf.function( + lambda: tf.constant(model.get_dim_fparam(), dtype=tf.int64) + ) + tf_model.get_dim_aparam = tf.function( + lambda: tf.constant(model.get_dim_aparam(), dtype=tf.int64) + ) + tf_model.get_sel_type = tf.function( + lambda: tf.constant(model.get_sel_type(), dtype=tf.int64) + ) + tf_model.is_aparam_nall = tf.function( + lambda: tf.constant(model.is_aparam_nall(), dtype=tf.bool) + ) + tf_model.model_output_type = tf.function( + lambda: tf.constant(model.model_output_type(), dtype=tf.string) + ) + tf_model.mixed_types = tf.function( + lambda: tf.constant(model.mixed_types(), dtype=tf.bool) ) - tf_model.mixed_types = tf.Variable(model.mixed_types(), dtype=tf.bool) if model.get_min_nbor_dist() is not None: - tf_model.min_nbor_dist = tf.Variable( - model.get_min_nbor_dist(), dtype=tf.double + tf_model.get_min_nbor_dist = tf.function( + lambda: tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + ) + tf_model.get_sel = tf.function( + lambda: tf.constant(model.get_sel(), dtype=tf.int64) + ) + tf_model.get_model_def_script = tf.function( + lambda: tf.constant( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) - tf_model.sel = tf.Variable(model.get_sel(), dtype=tf.int64) - tf_model.model_def_script = tf.Variable( - json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string ) tf.saved_model.save( tf_model, diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py index 7339835a4b..8f04014a97 100644 --- a/deepmd/jax/jax2tf/tfmodel.py +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -55,22 +55,22 @@ def __init__( self._call_lower_atomic_virial = jax2tf.call_tf( self.model.call_lower_atomic_virial ) - self.type_map = decode_list_of_bytes(self.model.type_map.numpy().tolist()) - self.rcut = self.model.rcut.numpy().item() - self.dim_fparam = self.model.dim_fparam.numpy().item() - self.dim_aparam = self.model.dim_aparam.numpy().item() - self.sel_type = self.model.sel_type.numpy().tolist() - self._is_aparam_nall = self.model.is_aparam_nall.numpy().item() + self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) + self.rcut = self.model.get_rcut().numpy().item() + self.dim_fparam = self.model.get_dim_fparam().numpy().item() + self.dim_aparam = self.model.get_dim_aparam().numpy().item() + self.sel_type = self.model.get_sel_type().numpy().tolist() + self._is_aparam_nall = self.model.is_aparam_nall().numpy().item() self._model_output_type = decode_list_of_bytes( - self.model.model_output_type.numpy().tolist() + self.model.model_output_type().numpy().tolist() ) - self._mixed_types = self.model.mixed_types.numpy().item() - if hasattr(self.model, "min_nbor_dist"): - self.min_nbor_dist = self.model.min_nbor_dist.numpy().item() + self._mixed_types = self.model.mixed_types().numpy().item() + if hasattr(self.model, "get_min_nbor_dist"): + self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item() else: self.min_nbor_dist = None - self.sel = self.model.sel.numpy().tolist() - self.model_def_script = self.model.model_def_script.numpy().decode() + self.sel = self.model.get_sel().numpy().tolist() + self.model_def_script = self.model.get_model_def_script().numpy().decode() def __call__( self,