From 3c5199d89a41305c5bf037dad785cd413e84005f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 13 Nov 2024 17:49:41 -0500 Subject: [PATCH] feat(jax): add options to use TensorFlow C library to build the JAX backend Signed-off-by: Jinzhe Zeng --- doc/install/easy-install.md | 4 +++ doc/install/install-from-c-library.md | 4 +-- doc/install/install-from-source.md | 31 ++++++++++++++++++++ source/CMakeLists.txt | 22 ++++++++++++++ source/api_cc/CMakeLists.txt | 4 +++ source/api_cc/src/DeepPot.cc | 6 ++-- source/api_cc/src/DeepPotJAX.cc | 2 +- source/cmake/FindTensorFlowC.cmake | 42 +++++++++++++++++++++++++++ 8 files changed, 110 insertions(+), 5 deletions(-) create mode 100644 source/cmake/FindTensorFlowC.cmake diff --git a/doc/install/easy-install.md b/doc/install/easy-install.md index c2260b58b6..5241b3d0a0 100644 --- a/doc/install/easy-install.md +++ b/doc/install/easy-install.md @@ -204,6 +204,10 @@ pip install deepmd-kit[jax] :::: +To generate a SavedModel and use [the LAMMPS module](../third-party/lammps-command.md) and [the i-PI driver](../third-party/ipi.md), +you need to install the TensorFlow. +Switch to the TensorFlow {{ tensorflow_icon }} tab for more information. + ::::: :::::: diff --git a/doc/install/install-from-c-library.md b/doc/install/install-from-c-library.md index cb7808bfdf..d408fb1b67 100644 --- a/doc/install/install-from-c-library.md +++ b/doc/install/install-from-c-library.md @@ -1,7 +1,7 @@ -# Install from pre-compiled C library {{ tensorflow_icon }} +# Install from pre-compiled C library {{ tensorflow_icon }}, JAX {{ jax_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, JAX {{ jax_icon }} ::: DeePMD-kit provides pre-compiled C library package (`libdeepmd_c.tar.gz`) in each [release](https://github.com/deepmodeling/deepmd-kit/releases). It can be used to build the [LAMMPS plugin](./install-lammps.md) and [GROMACS patch](./install-gromacs.md), as well as many [third-party software packages](../third-party/out-of-deepmd-kit.md), without building TensorFlow and DeePMD-kit on one's own. diff --git a/doc/install/install-from-source.md b/doc/install/install-from-source.md index 0bf6fa5ee3..63060f692a 100644 --- a/doc/install/install-from-source.md +++ b/doc/install/install-from-source.md @@ -316,6 +316,15 @@ You can also download libtorch prebuilt library from the [PyTorch website](https ::: +:::{tab-item} JAX {{ jax_icon }} + +The JAX backend only depends on the TensorFlow C API, which is included in both TensorFlow C++ library and [TensorFlow C library](https://www.tensorflow.org/install/lang_c). +If you want to use the TensorFlow C++ library, just enable the TensorFlow backend (which depends on the TensorFlow C++ library) and nothing else needs to do. +If you want to use the TensorFlow C library and disable the TensorFlow backend, +download the TensorFlow C library from [this page](https://www.tensorflow.org/install/lang_c#download_and_extract). + +::: + :::: ### Install DeePMD-kit's C++ interface @@ -369,6 +378,17 @@ cmake -DENABLE_PYTORCH=TRUE -DUSE_PT_PYTHON_LIBS=TRUE -DCMAKE_INSTALL_PREFIX=$de ::: +:::{tab-item} JAX {{ jax_icon }} + +If you want to use the TensorFlow C++ library, just enable the TensorFlow backend and nothing else needs to do. +If you want to use the TensorFlow C library and disable the TensorFlow backend, set {cmake:variable}`ENABLE_JAX` to `ON` and `CMAKE_PREFIX_PATH` to the root directory of the [TensorFlow C library](https://www.tensorflow.org/install/lang_c). + +```bash +cmake -DENABLE_JAX=ON -D CMAKE_PREFIX_PATH=${tensorflow_c_root} .. +``` + +::: + :::: One may add the following CMake variables to `cmake` using the [`-D =` option](https://cmake.org/cmake/help/latest/manual/cmake.1.html#cmdoption-cmake-D): @@ -378,6 +398,7 @@ One may add the following CMake variables to `cmake` using the [`-D ===5 + set(OP_CXX_ABI 1) +endif() # log enabled backends if(NOT DEEPMD_C_ROOT) message(STATUS "Enabled backends:") @@ -255,8 +273,12 @@ if(NOT DEEPMD_C_ROOT) if(ENABLE_PYTORCH) message(STATUS "- PyTorch") endif() + if(ENABLE_JAX) + message(STATUS "- JAX") + endif() if(NOT ENABLE_TENSORFLOW AND NOT ENABLE_PYTORCH + AND NOT ENABLE_JAX AND NOT BUILD_PY_IF) message(FATAL_ERROR "No backend is enabled.") endif() diff --git a/source/api_cc/CMakeLists.txt b/source/api_cc/CMakeLists.txt index 228a6657d3..32d2d7e18a 100644 --- a/source/api_cc/CMakeLists.txt +++ b/source/api_cc/CMakeLists.txt @@ -23,6 +23,10 @@ if(ENABLE_PYTORCH target_link_libraries(${libname} PRIVATE "${TORCH_LIBRARIES}") target_compile_definitions(${libname} PRIVATE BUILD_PYTORCH) endif() +if(ENABLE_JAX) + target_link_libraries(${libname} PRIVATE TensorFlow::tensorflow_c) + target_compile_definitions(${libname} PRIVATE BUILD_JAX) +endif() target_include_directories( ${libname} diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 6f8724f78e..8769f5b211 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -7,12 +7,14 @@ #include "AtomMap.h" #include "common.h" #ifdef BUILD_TENSORFLOW -#include "DeepPotJAX.h" #include "DeepPotTF.h" #endif #ifdef BUILD_PYTORCH #include "DeepPotPT.h" #endif +#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX) +#include "DeepPotJAX.h" +#endif #include "device.h" using namespace deepmd; @@ -63,7 +65,7 @@ void DeepPot::init(const std::string& model, } else if (deepmd::DPBackend::Paddle == backend) { throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet"); } else if (deepmd::DPBackend::JAX == backend) { -#ifdef BUILD_TENSORFLOW +#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX) dp = std::make_shared(model, gpu_rank, file_content); #else throw deepmd::deepmd_exception( diff --git a/source/api_cc/src/DeepPotJAX.cc b/source/api_cc/src/DeepPotJAX.cc index be1a5542b4..908c36322c 100644 --- a/source/api_cc/src/DeepPotJAX.cc +++ b/source/api_cc/src/DeepPotJAX.cc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: LGPL-3.0-or-later -#ifdef BUILD_TENSORFLOW +#if defined(BUILD_TENSORFLOW) || defined(BUILD_JAX) #include "DeepPotJAX.h" diff --git a/source/cmake/FindTensorFlowC.cmake b/source/cmake/FindTensorFlowC.cmake new file mode 100644 index 0000000000..7319e810f7 --- /dev/null +++ b/source/cmake/FindTensorFlowC.cmake @@ -0,0 +1,42 @@ +# Find TensorFlow C library (libtensorflow) Define target +# TensorFlow::tensorflow_c If TensorFlow::tensorflow_cc is not found, also +# define: - TENSORFLOWC_INCLUDE_DIR - TENSORFLOWC_LIBRARY + +if(TARGET TensorFlow::tensorflow_cc) + # since tensorflow_cc contain tensorflow_c, just use it + add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL) + target_link_libraries(TensorFlow::tensorflow_c + INTERFACE TensorFlow::tensorflow_cc) + set(TensorFlowC_FOUND TRUE) +endif() + +if(NOT TensorFlowC_FOUND) + find_path( + TENSORFLOWC_INCLUDE_DIR + NAMES tensorflow/c/c_api.h + PATH_SUFFIXES include + DOC "Path to TensorFlow C include directory") + + find_library( + TENSORFLOWC_LIBRARY + NAMES tensorflow + PATH_SUFFIXES lib + DOC "Path to TensorFlow C library") + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args( + TensorFlowC REQUIRED_VARS TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR) + + if(TensorFlowC_FOUND) + set(TensorFlowC_INCLUDE_DIRS ${TENSORFLOWC_INCLUDE_DIR}) + set(TensorFlowC_LIBRARIES ${TENSORFLOWC_LIBRARY}) + endif() + + add_library(TensorFlow::tensorflow_c SHARED IMPORTED GLOBAL) + set_property(TARGET TensorFlow::tensorflow_c PROPERTY IMPORTED_LOCATION + ${TENSORFLOWC_LIBRARY}) + target_include_directories(TensorFlow::tensorflow_c + INTERFACE ${TENSORFLOWC_INCLUDE_DIR}) + + mark_as_advanced(TENSORFLOWC_LIBRARY TENSORFLOWC_INCLUDE_DIR =) +endif()