diff --git a/deepmd/entrypoints/freeze.py b/deepmd/entrypoints/freeze.py index 11e0d55645..22f3cb80b4 100755 --- a/deepmd/entrypoints/freeze.py +++ b/deepmd/entrypoints/freeze.py @@ -511,9 +511,13 @@ def freeze( # We import the meta graph and retrieve a Saver try: # In case paralle training - import horovod.tensorflow as _ # noqa: F401 + import horovod.tensorflow as HVD except ImportError: pass + else: + HVD.init() + if HVD.rank() > 0: + return saver = tf.train.import_meta_graph( f"{input_checkpoint}.meta", clear_devices=clear_devices )