Skip to content

Commit

Permalink
refactor convert
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Sep 21, 2023
1 parent 544875e commit 3849117
Showing 1 changed file with 47 additions and 70 deletions.
117 changes: 47 additions & 70 deletions deepmd/utils/convert.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import textwrap
from typing import (
Optional,
)

from google.protobuf import (
text_format,
)
from packaging.specifiers import (
SpecifierSet,
)
from packaging.version import parse as parse_version

from deepmd import (
__version__,
)
from deepmd.env import (
tf,
)

log = logging.getLogger(__name__)


def detect_model_version(input_model: str):
"""Detect DP graph version.
Expand All @@ -20,33 +33,33 @@ def detect_model_version(input_model: str):
filename of the input graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
version = "undetected"
version = None
with open("frozen_model.pbtxt") as fp:
file_content = fp.read()
if file_content.find("DescrptNorot") > -1:
version = "<= 0.12"
version = parse_version("0.12")
elif (
file_content.find("fitting_attr/dfparam") > -1
and file_content.find("fitting_attr/daparam") == -1
):
version = "1.0"
version = parse_version("1.0")
elif file_content.find("model_attr/model_version") == -1:
name_dsea = file_content.find('name: "DescrptSeA"')
post_dsea = file_content[name_dsea:]
post_dsea2 = post_dsea[:300].find(r"}")
search_double = post_dsea[:post_dsea2]
if search_double.find("DT_DOUBLE") == -1:
version = "1.2"
version = parse_version("1.2")
else:
version = "1.3"
version = parse_version("1.3")
elif file_content.find('string_val: "1.0"') > -1:
version = "2.0"
version = parse_version("2.0")
elif file_content.find('string_val: "1.1"') > -1:
version = ">= 2.1"
version = parse_version("2.1")
return version


def convert_to_21(input_model: str, output_model: str):
def convert_to_21(input_model: str, output_model: str, version: Optional[str] = None):
"""Convert DP graph to 2.1 graph.
Parameters
Expand All @@ -55,37 +68,36 @@ def convert_to_21(input_model: str, output_model: str):
filename of the input graph
output_model : str
filename of the output graph
version : str
version of the input graph, if not specified, it will be detected automatically
"""
version = detect_model_version(input_model)
if version == "<= 0.12":
if version is None:
version = detect_model_version(input_model)
else:
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
if version is None:
raise ValueError(
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
% (input_model)
)
if version in SpecifierSet("<1.0"):
convert_dp012_to_dp10("frozen_model.pbtxt")
if version in SpecifierSet("<1.1"):
convert_dp10_to_dp11("frozen_model.pbtxt")
if version in SpecifierSet("<1.3"):
convert_dp12_to_dp13("frozen_model.pbtxt")
if version in SpecifierSet("<2.0"):
convert_dp13_to_dp20("frozen_model.pbtxt")
if version in SpecifierSet("<2.1"):
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.0":
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.2":
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "1.3":
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "2.0":
convert_dp20_to_dp21("frozen_model.pbtxt")
elif version == "undetected":
raise ValueError(
"The version of the DP graph %s cannot be detected. Please do the conversion manually."
% (input_model)
)
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
log.info(
"the converted output model (%s support) is saved in %s",
__version__,
output_model,
)


def convert_13_to_21(input_model: str, output_model: str):
Expand All @@ -98,13 +110,7 @@ def convert_13_to_21(input_model: str, output_model: str):
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.3")


def convert_12_to_21(input_model: str, output_model: str):
Expand All @@ -117,14 +123,7 @@ def convert_12_to_21(input_model: str, output_model: str):
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.2")


def convert_10_to_21(input_model: str, output_model: str):
Expand All @@ -137,15 +136,7 @@ def convert_10_to_21(input_model: str, output_model: str):
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="1.0")


def convert_012_to_21(input_model: str, output_model: str):
Expand All @@ -158,16 +149,7 @@ def convert_012_to_21(input_model: str, output_model: str):
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp012_to_dp10("frozen_model.pbtxt")
convert_dp10_to_dp11("frozen_model.pbtxt")
convert_dp12_to_dp13("frozen_model.pbtxt")
convert_dp13_to_dp20("frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="0.12")


def convert_20_to_21(input_model: str, output_model: str):
Expand All @@ -180,12 +162,7 @@ def convert_20_to_21(input_model: str, output_model: str):
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, "frozen_model.pbtxt")
convert_dp20_to_dp21("frozen_model.pbtxt")
convert_pbtxt_to_pb("frozen_model.pbtxt", output_model)
if os.path.isfile("frozen_model.pbtxt"):
os.remove("frozen_model.pbtxt")
print("the converted output model (2.1 support) is saved in %s" % output_model)
convert_to_21(input_model, output_model, version="2.0")


def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
Expand Down

0 comments on commit 3849117

Please sign in to comment.