diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py
new file mode 100644
index 0000000000..f039b6f289
--- /dev/null
+++ b/backend/find_pytorch.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import os
+import site
+from functools import (
+    lru_cache,
+)
+from importlib.machinery import (
+    FileFinder,
+)
+from importlib.util import (
+    find_spec,
+)
+from pathlib import (
+    Path,
+)
+from sysconfig import (
+    get_path,
+)
+from typing import (
+    Optional,
+)
+
+
+@lru_cache
+def find_pytorch() -> Optional[str]:
+    """Find PyTorch library.
+
+    Tries to find PyTorch in the order of:
+
+    1. Environment variable `PYTORCH_ROOT` if set
+    2. The current Python environment.
+    3. user site packages directory if enabled
+    4. system site packages directory (purelib)
+
+    Considering the default PyTorch package still uses old CXX11 ABI, we
+    cannot install it automatically.
+
+    Returns
+    -------
+    str, optional
+        PyTorch library path if found. None if PyTorch cannot be found or
+        the environment variable `DP_ENABLE_PYTORCH` is unset or set to "0".
+    """
+    if os.environ.get("DP_ENABLE_PYTORCH", "0") == "0":
+        return None
+    pt_spec = None
+
+    pytorch_root = os.environ.get("PYTORCH_ROOT")
+    if pytorch_root is not None:
+        # PYTORCH_ROOT is expected to be the `torch` package directory, so
+        # search for the package in its parent directory
+        site_packages = Path(pytorch_root).parent.absolute()
+        pt_spec = FileFinder(str(site_packages)).find_spec("torch")
+
+    # get pytorch spec
+    # note: isolated build will not work for backend
+    if not pt_spec:
+        pt_spec = find_spec("torch")
+
+    if not pt_spec and site.ENABLE_USER_SITE:
+        # first search PyTorch from user site-packages before global site-packages
+        site_packages = site.getusersitepackages()
+        if site_packages:
+            pt_spec = FileFinder(site_packages).find_spec("torch")
+
+    if not pt_spec:
+        # purelib gets site-packages path
+        site_packages = get_path("purelib")
+        if site_packages:
+            pt_spec = FileFinder(site_packages).find_spec("torch")
+
+    # get install dir from spec
+    try:
+        pt_install_dir = pt_spec.submodule_search_locations[0]  # type: ignore
+        # AttributeError if pt_spec is None
+        # TypeError if submodule_search_locations are None
+        # IndexError if submodule_search_locations is an empty list
+    except (AttributeError, TypeError, IndexError):
+        pt_install_dir = None
+    return pt_install_dir
diff --git a/backend/read_env.py b/backend/read_env.py
index bee5d607e3..c97c854a13 100644
--- a/backend/read_env.py
+++ b/backend/read_env.py
@@ -13,6 +13,9 @@
     Version,
 )
 
+from .find_pytorch import (
+    find_pytorch,
+)
 from .find_tensorflow import (
     find_tensorflow,
     get_tf_version,
@@ -99,6 +102,19 @@ def get_argument_from_env() -> Tuple[str, list, list, dict, str]:
         cmake_args.append("-DENABLE_TENSORFLOW=OFF")
         tf_version = None
 
+    if os.environ.get("DP_ENABLE_PYTORCH", "0") == "1":
+        pt_install_dir = find_pytorch()
+        if pt_install_dir is None:
+            raise RuntimeError("Cannot find installed PyTorch.")
+        cmake_args.extend(
+            [
+                "-DENABLE_PYTORCH=ON",
+                f"-DCMAKE_PREFIX_PATH={pt_install_dir}",
+            ]
+        )
+    else:
+        cmake_args.append("-DENABLE_PYTORCH=OFF")
+
     cmake_args = [
         "-DBUILD_PY_IF:BOOL=TRUE",
         *cmake_args,
diff --git a/deepmd/pt/__init__.py b/deepmd/pt/__init__.py
index 6ceb116d85..ab61736198 100644
--- a/deepmd/pt/__init__.py
+++ b/deepmd/pt/__init__.py
@@ -1 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+
+# import customized OPs globally
+from deepmd.pt.cxx_op import (
+    ENABLE_CUSTOMIZED_OP,
+)
+
+__all__ = [
+    "ENABLE_CUSTOMIZED_OP",
+]
diff --git a/deepmd/pt/cxx_op.py b/deepmd/pt/cxx_op.py
new file mode 100644
index 0000000000..7887b5722c
--- /dev/null
+++ b/deepmd/pt/cxx_op.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import platform
+
+import torch
+
+from deepmd.env import (
+    SHARED_LIB_DIR,
+)
+
+
+def load_library(module_name: str) -> bool:
+    """Load OP library.
+
+    Parameters
+    ----------
+    module_name : str
+        Name of the module, without the `lib` prefix and the file extension
+
+    Returns
+    -------
+    bool
+        True if the library file exists and has been loaded, False if the
+        file does not exist
+    """
+    if platform.system() == "Windows":
+        ext = ".dll"
+        prefix = ""
+    else:
+        ext = ".so"
+        prefix = "lib"
+
+    # build the file name directly: Path.with_suffix would replace anything
+    # after a dot in module_name instead of appending the extension
+    module_file = (SHARED_LIB_DIR / f"{prefix}{module_name}{ext}").resolve()
+
+    if module_file.is_file():
+        torch.ops.load_library(str(module_file))
+        return True
+    return False
+
+
+ENABLE_CUSTOMIZED_OP = load_library("deepmd_op_pt")
+
+__all__ = [
+    "ENABLE_CUSTOMIZED_OP",
+]
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 7b1463a3b2..adaec0968a 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -32,6 +32,9 @@
 from deepmd.main import (
     parse_args,
 )
+from deepmd.pt.cxx_op import (
+    ENABLE_CUSTOMIZED_OP,
+)
 from deepmd.pt.infer import (
     inference,
 )
@@ -224,6 +227,7 @@ def get_backend_info(self) -> dict:
         return {
             "Backend": "PyTorch",
             "PT ver": f"v{torch.__version__}-g{torch.version.git_version[:11]}",
+            "Enable custom OP": ENABLE_CUSTOMIZED_OP,
         }
 
 
diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md
index 9e86ee33b0..5195992853 100644
--- a/doc/install/install-from-source.md
+++ b/doc/install/install-from-source.md
@@ -118,6 +118,16 @@ Note that TensorFlow may have specific requirements for the compiler version to
 
 :::
 
+:::{tab-item} PyTorch {{ pytorch_icon }}
+
+You can set the environment variable `export DP_ENABLE_PYTORCH=1` to enable customized C++ OPs in the PyTorch backend.
+Note that PyTorch may have specific requirements for the compiler version to support the C++ standard version and [`_GLIBCXX_USE_CXX11_ABI`](https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_dual_abi.html) used by PyTorch.
+
+The customized C++ OPs are not enabled by default because TensorFlow and PyTorch packages from PyPI use different `_GLIBCXX_USE_CXX11_ABI` flags.
+We recommend conda-forge packages in this case.
+
+:::
+
 ::::
 
 Execute
@@ -135,6 +145,7 @@ One may set the following environment variables before executing `pip`:
 | CUDAToolkit_ROOT | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is required. |
 | ROCM_ROOT | Path | Detected automatically | The path to the ROCM toolkit directory. |
 | DP_ENABLE_TENSORFLOW | 0, 1 | 1 | {{ tensorflow_icon }} Enable the TensorFlow backend. |
+| DP_ENABLE_PYTORCH | 0, 1 | 0 | {{ pytorch_icon }} Enable customized C++ OPs for the PyTorch backend. PyTorch can still run without customized C++ OPs, but features will be limited. |
 | TENSORFLOW_ROOT | Path | Detected automatically | {{ tensorflow_icon }} The path to TensorFlow Python library. By default the installer only finds TensorFlow under user site-package directory (`site.getusersitepackages()`) or system site-package directory (`sysconfig.get_path("purelib")`) due to limitation of [PEP-517](https://peps.python.org/pep-0517/). If not found, the latest TensorFlow (or the environment variable `TENSORFLOW_VERSION` if given) from PyPI will be built against. |
 | DP_ENABLE_NATIVE_OPTIMIZATION | 0, 1 | 0 | Enable compilation optimization for the native machine's CPU type. Do not enable it if generated code will run on different CPUs. |
 | CMAKE_ARGS | str | - | Additional CMake arguments |
diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt
index 9560b69a70..bbbbac9578 100644
--- a/source/CMakeLists.txt
+++ b/source/CMakeLists.txt
@@ -305,6 +305,9 @@ if(NOT DEEPMD_C_ROOT)
   if(ENABLE_TENSORFLOW)
     add_subdirectory(op/)
   endif()
+  if(ENABLE_PYTORCH)
+    add_subdirectory(op/pt/)
+  endif()
   add_subdirectory(lib/)
 endif()
 if(BUILD_PY_IF)
diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc
index 4c188280f2..b4631b5e46 100644
--- a/source/api_cc/src/DeepPotPT.cc
+++ b/source/api_cc/src/DeepPotPT.cc
@@ -37,6 +37,7 @@ void DeepPotPT::init(const std::string& model,
               << std::endl;
     return;
   }
+  deepmd::load_op_library();
   int gpu_num = torch::cuda::device_count();
   if (gpu_num > 0) {
     gpu_id = gpu_rank % gpu_num;
diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc
index aa1e27ace1..07b6a10220 100644
--- a/source/api_cc/src/common.cc
+++ b/source/api_cc/src/common.cc
@@ -379,12 +379,23 @@ void deepmd::get_env_nthreads(int& num_intra_nthreads,
 
+// Forward declaration: the helper is defined right after load_op_library.
+static inline void _load_single_op_library(const std::string& library_name);
+
 void deepmd::load_op_library() {
 #ifdef BUILD_TENSORFLOW
-  tensorflow::Env* env = tensorflow::Env::Default();
+  _load_single_op_library("deepmd_op");
+#endif
+#ifdef BUILD_PYTORCH
+  _load_single_op_library("deepmd_op_pt");
+#endif
+}
+
+// Open lib<library_name>.so (<library_name>.dll on Windows) via dlopen/LoadLibrary.
+static inline void _load_single_op_library(const std::string& library_name) {
 #if defined(_WIN32)
-  std::string dso_path = "deepmd_op.dll";
+  std::string dso_path = library_name + ".dll";
   void* dso_handle = LoadLibrary(dso_path.c_str());
 #else
-  std::string dso_path = "libdeepmd_op.so";
+  std::string dso_path = "lib" + library_name + ".so";
   void* dso_handle = dlopen(dso_path.c_str(), RTLD_NOW | RTLD_LOCAL);
 #endif
   if (!dso_handle) {
@@ -392,7 +403,6 @@ void deepmd::load_op_library() {
         dso_path +
         " is not found! You can add the library directory to LD_LIBRARY_PATH");
   }
-#endif
 }
 
 std::string deepmd::name_prefix(const std::string& scope) {
diff --git a/source/op/pt/CMakeLists.txt b/source/op/pt/CMakeLists.txt
new file mode 100644
index 0000000000..46ea38c193
--- /dev/null
+++ b/source/op/pt/CMakeLists.txt
@@ -0,0 +1,26 @@
+file(GLOB OP_SRC print_summary.cc)
+
+add_library(deepmd_op_pt MODULE ${OP_SRC})
+# link: libdeepmd libtorch
+target_link_libraries(deepmd_op_pt PRIVATE ${TORCH_LIBRARIES} ${LIB_DEEPMD})
+if(APPLE)
+  set_target_properties(deepmd_op_pt PROPERTIES INSTALL_RPATH "@loader_path")
+else()
+  set_target_properties(deepmd_op_pt PROPERTIES INSTALL_RPATH "$ORIGIN")
+endif()
+
+find_package(MPI)
+if(MPI_FOUND)
+  target_link_libraries(deepmd_op_pt PRIVATE MPI::MPI_CXX)
+  target_compile_definitions(deepmd_op_pt PRIVATE USE_MPI)
+endif()
+
+if(CMAKE_TESTING_ENABLED)
+  target_link_libraries(deepmd_op_pt PRIVATE coverage_config)
+endif()
+
+if(BUILD_PY_IF)
+  install(TARGETS deepmd_op_pt DESTINATION deepmd/lib/)
+else(BUILD_PY_IF)
+  install(TARGETS deepmd_op_pt DESTINATION lib/)
+endif(BUILD_PY_IF)
diff --git a/source/op/pt/print_summary.cc b/source/op/pt/print_summary.cc
new file mode 100644
index 0000000000..83209aab31
--- /dev/null
+++ b/source/op/pt/print_summary.cc
@@ -0,0 +1,16 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// NOTE(review): both header names below were missing and are restored by inference; confirm.
+#include <torch/torch.h>
+
+#include <string>
+
+// Return a one-element bool tensor: true if compiled with USE_MPI, else false.
+torch::Tensor enable_mpi() {
+#ifdef USE_MPI
+  return torch::ones({1}, torch::kBool);
+#else
+  return torch::zeros({1}, torch::kBool);
+#endif
+}
+
+TORCH_LIBRARY(deepmd, m) { m.def("enable_mpi", enable_mpi); }