Skip to content

Commit

Permalink
update pd version code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Sep 27, 2024
1 parent c792563 commit 72241ea
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 11 deletions.
4 changes: 4 additions & 0 deletions backend/dynamic_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
Optional,
)

from .find_paddle import (
get_pd_requirement,
)
from .find_pytorch import (
get_pt_requirement,
)
Expand Down Expand Up @@ -57,4 +60,5 @@ def dynamic_metadata(
**optional_dependencies,
**get_tf_requirement(tf_version),
**get_pt_requirement(pt_version),
**get_pd_requirement(pd_version),
}
4 changes: 2 additions & 2 deletions backend/find_paddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_pd_requirement(pd_version: str = "") -> dict:
# https://peps.python.org/pep-0440/#version-matching
f"paddle=={Version(pd_version).base_version}.*"
if pd_version != ""
else "paddle>=3.0.0",
else "paddle>=3b",
],
}

Expand All @@ -138,7 +138,7 @@ def get_pd_version(pd_path: Optional[Union[str, Path]]) -> str:
"""
if pd_path is None or pd_path == "":
return ""
version_file = Path(pd_path) / "version.py"
version_file = Path(pd_path) / "version" / "__init__.py"
spec = importlib.util.spec_from_file_location("paddle.version", version_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
Expand Down
7 changes: 5 additions & 2 deletions backend/read_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .find_paddle import (
find_paddle,
get_pd_version,
)
from .find_pytorch import (
find_pytorch,
Expand All @@ -27,7 +28,7 @@


@lru_cache
def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:
def get_argument_from_env() -> Tuple[str, list, list, dict, str, str, str]:
"""Get the arguments from environment variables.
The environment variables are assumed to be not changed during the build.
Expand All @@ -46,6 +47,8 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:
The TensorFlow version.
str
The PyTorch version.
str
The Paddle version.
"""
cmake_args = []
extra_scripts = {}
Expand Down Expand Up @@ -125,7 +128,7 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str, str]:

if os.environ.get("DP_ENABLE_PADDLE", "0") == "1":
pd_install_dir, _ = find_paddle()
pt_version = get_pt_version(pd_install_dir)
pd_version = get_pd_version(pd_install_dir)
cmake_args.extend(
[
"-DENABLE_PADDLE=ON",
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def get_trainer(
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")
assert paddle.version.nccl() != "0"
dist.init_parallel_env()

def prepare_trainer_input_single(
model_params_single, data_dict_single, rank=0, seed=None
Expand Down
7 changes: 2 additions & 5 deletions deepmd/pd/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,20 @@ def __init__(
super().__init__(type_map, **kwargs)
ntypes = len(type_map)
self.type_map = type_map

self.ntypes = ntypes
self.descriptor = descriptor
self.rcut = self.descriptor.get_rcut()
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
super().init_out_stat()

# specify manually for access by name in C++ inference

# register 'type_map' as buffer
def string_to_array(s: str) -> int:
def _string_to_array(s: str) -> List[int]:
return [ord(c) for c in s]

self.register_buffer(
"buffer_type_map",
paddle.to_tensor(string_to_array(" ".join(self.type_map)), dtype="int32"),
paddle.to_tensor(_string_to_array(" ".join(self.type_map)), dtype="int32"),
)
self.buffer_type_map.name = "buffer_type_map"
# register 'has_message_passing' as buffer(cast to int32 as problems may meets with vector<bool>)
Expand Down

0 comments on commit 72241ea

Please sign in to comment.