diff --git a/deepmd/utils/convert.py b/deepmd/utils/convert.py index dd26fa1058..13e07f0885 100644 --- a/deepmd/utils/convert.py +++ b/deepmd/utils/convert.py @@ -1,15 +1,28 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging import os import textwrap +from typing import ( + Optional, +) from google.protobuf import ( text_format, ) +from packaging.specifiers import ( + SpecifierSet, +) +from packaging.version import parse as parse_version +from deepmd import ( + __version__, +) from deepmd.env import ( tf, ) +log = logging.getLogger(__name__) + def detect_model_version(input_model: str): """Detect DP graph version. @@ -20,33 +33,33 @@ def detect_model_version(input_model: str): filename of the input graph """ convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - version = "undetected" + version = None with open("frozen_model.pbtxt") as fp: file_content = fp.read() if file_content.find("DescrptNorot") > -1: - version = "<= 0.12" + version = parse_version("0.12") elif ( file_content.find("fitting_attr/dfparam") > -1 and file_content.find("fitting_attr/daparam") == -1 ): - version = "1.0" + version = parse_version("1.0") elif file_content.find("model_attr/model_version") == -1: name_dsea = file_content.find('name: "DescrptSeA"') post_dsea = file_content[name_dsea:] post_dsea2 = post_dsea[:300].find(r"}") search_double = post_dsea[:post_dsea2] if search_double.find("DT_DOUBLE") == -1: - version = "1.2" + version = parse_version("1.2") else: - version = "1.3" + version = parse_version("1.3") elif file_content.find('string_val: "1.0"') > -1: - version = "2.0" + version = parse_version("2.0") elif file_content.find('string_val: "1.1"') > -1: - version = ">= 2.1" + version = parse_version("2.1") return version -def convert_to_21(input_model: str, output_model: str): +def convert_to_21(input_model: str, output_model: str, version: Optional[str] = None): """Convert DP graph to 2.1 graph. Parameters @@ -55,37 +68,36 @@ def convert_to_21(input_model: str, output_model: str): filename of the input graph output_model : str filename of the output graph + version : str + version of the input graph, if not specified, it will be detected automatically """ - version = detect_model_version(input_model) - if version == "<= 0.12": + if version is None: + version = detect_model_version(input_model) + else: + convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") + if version is None: + raise ValueError( + "The version of the DP graph %s cannot be detected. Please do the conversion manually." + % (input_model) + ) + if version in SpecifierSet("<1.0"): convert_dp012_to_dp10("frozen_model.pbtxt") + if version in SpecifierSet("<1.1"): convert_dp10_to_dp11("frozen_model.pbtxt") + if version in SpecifierSet("<1.3"): convert_dp12_to_dp13("frozen_model.pbtxt") + if version in SpecifierSet("<2.0"): convert_dp13_to_dp20("frozen_model.pbtxt") + if version in SpecifierSet("<2.1"): convert_dp20_to_dp21("frozen_model.pbtxt") - elif version == "1.0": - convert_dp10_to_dp11("frozen_model.pbtxt") - convert_dp12_to_dp13("frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - elif version == "1.2": - convert_dp12_to_dp13("frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - elif version == "1.3": - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - elif version == "2.0": - convert_dp20_to_dp21("frozen_model.pbtxt") - elif version == "undetected": - raise ValueError( - "The version of the DP graph %s cannot be detected. Please do the conversion manually." - % (input_model) - ) convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) if os.path.isfile("frozen_model.pbtxt"): os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + log.info( + "the converted output model (%s support) is saved in %s", + __version__, + output_model, + ) def convert_13_to_21(input_model: str, output_model: str): @@ -98,13 +110,7 @@ def convert_13_to_21(input_model: str, output_model: str): output_model : str filename of the output graph """ - convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) - if os.path.isfile("frozen_model.pbtxt"): - os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + convert_to_21(input_model, output_model, version="1.3") def convert_12_to_21(input_model: str, output_model: str): @@ -117,14 +123,7 @@ def convert_12_to_21(input_model: str, output_model: str): output_model : str filename of the output graph """ - convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - convert_dp12_to_dp13("frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) - if os.path.isfile("frozen_model.pbtxt"): - os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + convert_to_21(input_model, output_model, version="1.2") def convert_10_to_21(input_model: str, output_model: str): @@ -137,15 +136,7 @@ def convert_10_to_21(input_model: str, output_model: str): output_model : str filename of the output graph """ - convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - convert_dp10_to_dp11("frozen_model.pbtxt") - convert_dp12_to_dp13("frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) - if os.path.isfile("frozen_model.pbtxt"): - os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + convert_to_21(input_model, output_model, version="1.0") def convert_012_to_21(input_model: str, output_model: str): @@ -158,16 +149,7 @@ def convert_012_to_21(input_model: str, output_model: str): output_model : str filename of the output graph """ - convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - convert_dp012_to_dp10("frozen_model.pbtxt") - convert_dp10_to_dp11("frozen_model.pbtxt") - convert_dp12_to_dp13("frozen_model.pbtxt") - convert_dp13_to_dp20("frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) - if os.path.isfile("frozen_model.pbtxt"): - os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + convert_to_21(input_model, output_model, version="0.12") def convert_20_to_21(input_model: str, output_model: str): @@ -180,12 +162,7 @@ def convert_20_to_21(input_model: str, output_model: str): output_model : str filename of the output graph """ - convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt") - convert_dp20_to_dp21("frozen_model.pbtxt") - convert_pbtxt_to_pb("frozen_model.pbtxt", output_model) - if os.path.isfile("frozen_model.pbtxt"): - os.remove("frozen_model.pbtxt") - print("the converted output model (2.1 support) is saved in %s" % output_model) + convert_to_21(input_model, output_model, version="2.0") def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 1f43121e65..006b391e49 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -8,6 +8,7 @@ run_dp, tests_path, ) +from packaging.version import parse as parse_version from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -750,33 +751,33 @@ def test_detect(self): new_model_pb = "deeppot_new.pb" convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model) version = detect_model_version(old_model) - self.assertEqual(version, "<= 0.12") + self.assertEqual(version, parse_version("0.12")) os.remove(old_model) shutil.copyfile(str(tests_path / "infer" / "sea_012.pbtxt"), new_model_txt) convert_dp012_to_dp10(new_model_txt) convert_pbtxt_to_pb(new_model_txt, new_model_pb) version = detect_model_version(new_model_pb) - self.assertEqual(version, "1.0") + self.assertEqual(version, parse_version("1.0")) os.remove(new_model_pb) convert_dp10_to_dp11(new_model_txt) convert_pbtxt_to_pb(new_model_txt, new_model_pb) version = detect_model_version(new_model_pb) - self.assertEqual(version, "1.3") + self.assertEqual(version, parse_version("1.3")) os.remove(new_model_pb) convert_dp12_to_dp13(new_model_txt) convert_pbtxt_to_pb(new_model_txt, new_model_pb) version = detect_model_version(new_model_pb) - self.assertEqual(version, "1.3") + self.assertEqual(version, parse_version("1.3")) os.remove(new_model_pb) convert_dp13_to_dp20(new_model_txt) convert_pbtxt_to_pb(new_model_txt, new_model_pb) version = detect_model_version(new_model_pb) - self.assertEqual(version, "2.0") + self.assertEqual(version, parse_version("2.0")) os.remove(new_model_pb) convert_dp20_to_dp21(new_model_txt) convert_pbtxt_to_pb(new_model_txt, new_model_pb) version = detect_model_version(new_model_pb) - self.assertEqual(version, ">= 2.1") + self.assertEqual(version, parse_version("2.1")) os.remove(new_model_pb) os.remove(new_model_txt)