Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor convert #2854

Merged
merged 2 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 47 additions & 70 deletions deepmd/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import textwrap
from typing import (
Optional,
)

from google.protobuf import (
text_format,
)
from packaging.specifiers import (
SpecifierSet,
)
from packaging.version import parse as parse_version

from deepmd import (
__version__,
)
from deepmd.env import (
tf,
)

log = logging.getLogger(__name__)


def detect_model_version(input_model: str):
"""Detect DP graph version.
Expand All @@ -20,33 +33,33 @@
filename of the input graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
version = "undetected"
version = None
with open("frozen_model.pbtxt") as fp:
file_content = fp.read()
if file_content.find("DescrptNorot") > -1:
version = "<= 0.12"
version = parse_version("0.12")
elif (
file_content.find("fitting_attr/dfparam") > -1
and file_content.find("fitting_attr/daparam") == -1
):
version = "1.0"
version = parse_version("1.0")
elif file_content.find("model_attr/model_version") == -1:
name_dsea = file_content.find('name: "DescrptSeA"')
post_dsea = file_content[name_dsea:]
post_dsea2 = post_dsea[:300].find(r"}")
search_double = post_dsea[:post_dsea2]
if search_double.find("DT_DOUBLE") == -1:
version = "1.2"
version = parse_version("1.2")

Check warning on line 52 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L52

Added line #L52 was not covered by tests
else:
version = "1.3"
version = parse_version("1.3")
elif file_content.find('string_val: "1.0"') > -1:
version = "2.0"
version = parse_version("2.0")
elif file_content.find('string_val: "1.1"') > -1:
version = ">= 2.1"
version = parse_version("2.1")
return version


def convert_to_21(input_model: str, output_model: str):
def convert_to_21(input_model: str, output_model: str, version: Optional[str] = None):
"""Convert DP graph to 2.1 graph.

Parameters
Expand All @@ -55,37 +68,36 @@
filename of the input graph
output_model : str
filename of the output graph
version : str
version of the input graph, if not specified, it will be detected automatically
"""
version = detect_model_version(input_model)
if version == "<= 0.12":
if version is None:
version = detect_model_version(input_model)
else:
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
if version is None:
raise ValueError(

Check warning on line 79 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L79

Added line #L79 was not covered by tests
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
% (input_model)
)
if version in SpecifierSet("<1.0"):
convert_dp012_to_dp10("frozen_model.pbtxt")
if version in SpecifierSet("<1.1"):
convert_dp10_to_dp11("frozen_model.pbtxt")
if version in SpecifierSet("<1.3"):
convert_dp12_to_dp13("frozen_model.pbtxt")
if version in SpecifierSet("<2.0"):
convert_dp13_to_dp20("frozen_model.pbtxt")
if version in SpecifierSet("<2.1"):
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.0":
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.2":
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.3":
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "2.0":
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "undetected":
raise ValueError(
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
% (input_model)
)
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
log.info(
"the converted output model (%s support) is saved in %s",
__version__,
output_model,
)


def convert_13_to_21(input_model: str, output_model: str):
Expand All @@ -98,13 +110,7 @@
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.3")

Check warning on line 113 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L113

Added line #L113 was not covered by tests


def convert_12_to_21(input_model: str, output_model: str):
Expand All @@ -117,14 +123,7 @@
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.2")

Check warning on line 126 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L126

Added line #L126 was not covered by tests


def convert_10_to_21(input_model: str, output_model: str):
Expand All @@ -137,15 +136,7 @@
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.0")

Check warning on line 139 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L139

Added line #L139 was not covered by tests


def convert_012_to_21(input_model: str, output_model: str):
Expand All @@ -158,16 +149,7 @@
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp012_to_dp10("frozen_model.pbtxt")
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="0.12")


def convert_20_to_21(input_model: str, output_model: str):
Expand All @@ -180,12 +162,7 @@
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="2.0")

Check warning on line 165 in deepmd/utils/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/convert.py#L165

Added line #L165 was not covered by tests


def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
Expand Down
13 changes: 7 additions & 6 deletions source/tests/test_deeppot_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
run_dp,
tests_path,
)
from packaging.version import parse as parse_version

from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
Expand Down Expand Up @@ -750,33 +751,33 @@ def test_detect(self):
new_model_pb = "deeppot_new.pb"
convert_pbtxt_to_pb(str(tests_path / "infer" / "sea_012.pbtxt"), old_model)
version = detect_model_version(old_model)
self.assertEqual(version, "<= 0.12")
self.assertEqual(version, parse_version("0.12"))
os.remove(old_model)
shutil.copyfile(str(tests_path / "infer" / "sea_012.pbtxt"), new_model_txt)
convert_dp012_to_dp10(new_model_txt)
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
version = detect_model_version(new_model_pb)
self.assertEqual(version, "1.0")
self.assertEqual(version, parse_version("1.0"))
os.remove(new_model_pb)
convert_dp10_to_dp11(new_model_txt)
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
version = detect_model_version(new_model_pb)
self.assertEqual(version, "1.3")
self.assertEqual(version, parse_version("1.3"))
os.remove(new_model_pb)
convert_dp12_to_dp13(new_model_txt)
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
version = detect_model_version(new_model_pb)
self.assertEqual(version, "1.3")
self.assertEqual(version, parse_version("1.3"))
os.remove(new_model_pb)
convert_dp13_to_dp20(new_model_txt)
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
version = detect_model_version(new_model_pb)
self.assertEqual(version, "2.0")
self.assertEqual(version, parse_version("2.0"))
os.remove(new_model_pb)
convert_dp20_to_dp21(new_model_txt)
convert_pbtxt_to_pb(new_model_txt, new_model_pb)
version = detect_model_version(new_model_pb)
self.assertEqual(version, ">= 2.1")
self.assertEqual(version, parse_version("2.1"))
os.remove(new_model_pb)
os.remove(new_model_txt)

Expand Down
Loading