Skip to content

Commit

Permalink
bump python to 3.12 in the test environment (deepmodeling#3343)
Browse files Browse the repository at this point in the history
Fix a bug caused by the breaking change in Keras 3 (shipped by TF 2.16).

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 27, 2024
1 parent 3e6b507 commit 473cc0a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- python: 3.8
tf:
torch:
- python: "3.11"
- python: "3.12"
tf:
torch:

Expand Down
5 changes: 5 additions & 0 deletions backend/find_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ def get_tf_requirement(tf_version: str = "") -> dict:
extra_select = {}
if not (tf_version == "" or tf_version in SpecifierSet(">=2.12", prereleases=True)):
extra_requires.append("protobuf<3.20")
# keras 3 is not compatible with tf.compat.v1
if tf_version == "" or tf_version in SpecifierSet(">=2.15.0rc0", prereleases=True):
extra_requires.append("tf-keras; python_version>='3.9'")
# only TF>=2.16 is compatible with Python 3.12
extra_requires.append("tf-keras>=2.16.0rc0; python_version>='3.12'")
if tf_version == "" or tf_version in SpecifierSet(">=1.15", prereleases=True):
extra_select["mpi"] = [
"horovod",
Expand Down
1 change: 1 addition & 0 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ def _attention_layers(
input_xyz = tf.keras.layers.LayerNormalization(
beta_initializer=tf.constant_initializer(self.beta[i]),
gamma_initializer=tf.constant_initializer(self.gamma[i]),
dtype=self.filter_precision,
)(input_xyz)
# input_xyz = self._feedforward(input_xyz, outputs_size[-1], self.att_n)
return input_xyz
Expand Down
4 changes: 3 additions & 1 deletion deepmd/tf/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ def dlopen_library(module: str, filename: str):
dlopen_library("nvidia.cusparse.lib", "libcusparse.so*")
dlopen_library("nvidia.cudnn.lib", "libcudnn.so*")


# keras 3 is incompatible with tf.compat.v1
# https://keras.io/getting_started/#tensorflow--keras-2-backwards-compatibility
os.environ["TF_USE_LEGACY_KERAS"] = "1"
# import tensorflow v1 compatability
try:
import tensorflow.compat.v1 as tf
Expand Down

0 comments on commit 473cc0a

Please sign in to comment.