merge cuda and rocm files (#2844)
Merge `source/lib/src/cuda` and `source/lib/src/rocm` into
`source/lib/src/gpu`.

- Define macros `gpuGetLastError`, `gpuDeviceSynchronize`, `gpuMemcpy`,
  `gpuMemcpyDeviceToHost`, `gpuMemcpyHostToDevice`, and `gpuMemset` to
  make them available for both CUDA and ROCm.
- Use the `<<< >>>` syntax for both CUDA and ROCm. Per
  ROCm/HIP@cf78d85, it has been supported in HIP since 2018.
- Fix several integer constants that should be double or float.
- For tabulate:
  - Fix `WARP_SIZE` for ROCm. Per pytorch/pytorch#64302, `WARP_SIZE` can be
    32 or 64, so it should not be hardcoded to 64.
  - Add `GpuShuffleSync`. Per ROCm/HIP#1491, `__shfl_sync` is not
    supported by HIP.
  - After merging the code, #1274 should also work for ROCm.
  - Use the same `ii` for #830 and #2357. Both of them work, but `ii`
    had different meanings in these two PRs; now it should be the same.
  - However, `ii` in `tabulate_fusion_se_a_fifth_order_polynomial` (rocm)
    added by #2532 is wrong. After merging the code, it should be
    corrected.
  - The optimization in #830 was not applied to ROCm.
  - `__syncwarp` is not supported by ROCm.
- After merging the code, #2661 will be applied to ROCm. Although the TF
  ROCm stream is still blocking
  (https://github.com/tensorflow/tensorflow/blob/9d1262082e761cd85d6726bcbdfdef331d6d72c6/tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc#L566),
  we don't know whether it will change to non-blocking.
- There are several other differences between CUDA and ROCm.
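
The `GpuShuffleSync` and `WARP_SIZE` items above could look roughly like the following. This is a minimal sketch, not the code from this commit: the macro bodies and their four-argument form are assumptions, the `GOOGLE_CUDA` / `TENSORFLOW_USE_ROCM` guards are borrowed from `source/lib/CMakeLists.txt`, and `emulate_shuffle` is a hypothetical host-side model added only so the lane arithmetic can be checked without a GPU.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed wrapper shape; the actual definition in the repository may differ.
#if defined(GOOGLE_CUDA)
#define GpuShuffleSync(mask, var, src_lane, width) \
  __shfl_sync((mask), (var), (src_lane), (width))
#elif defined(TENSORFLOW_USE_ROCM)
// The HIP intrinsic used here takes no mask argument, so the mask is dropped.
#define GpuShuffleSync(mask, var, src_lane, width) \
  __shfl((var), (src_lane), (width))
#endif

// Hypothetical host-side model of a shuffle: every lane reads the value held
// by lane (src_lane % width) of its own width-sized segment. The number of
// lanes is taken from the vector, so nothing is hardcoded to 32 or 64.
template <typename T>
std::vector<T> emulate_shuffle(const std::vector<T>& lanes,
                               int src_lane,
                               int width) {
  assert(src_lane >= 0);
  assert(width > 0 && (width & (width - 1)) == 0);  // width: power of two
  const std::size_t w = static_cast<std::size_t>(width);
  assert(lanes.size() % w == 0);
  std::vector<T> out(lanes.size());
  for (std::size_t lane = 0; lane < lanes.size(); ++lane) {
    const std::size_t segment_start = lane - lane % w;
    out[lane] = lanes[segment_start + static_cast<std::size_t>(src_lane) % w];
  }
  return out;
}
```

With a 64-lane input and `width = 64`, every lane receives the value of lane 3; with `width = 32`, lanes 32 to 63 read from lane 35 instead. That difference is why a kernel written for one warp size can silently misbehave on hardware with the other.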

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz authored Sep 22, 2023
1 parent 544875e commit 0f07afa
Showing 41 changed files with 490 additions and 3,878 deletions.
4 changes: 2 additions & 2 deletions .github/labeler.yml
@@ -5,8 +5,8 @@ Python:
Docs: doc/**/*
Examples: examples/**/*
Core: source/lib/**/*
-CUDA: source/lib/src/cuda/**/*
-ROCM: source/lib/src/rocm/**/*
+CUDA: source/lib/src/gpu/**/*
+ROCM: source/lib/src/gpu/**/*
OP: source/op/**/*
C++: source/api_cc/**/*
C: source/api_c/**/*
4 changes: 2 additions & 2 deletions .gitmodules
@@ -1,3 +1,3 @@
-[submodule "source/lib/src/cuda/cub"]
-path = source/lib/src/cuda/cub
+[submodule "source/lib/src/gpu/cub"]
+path = source/lib/src/gpu/cub
url = https://github.com/NVIDIA/cub.git
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -53,7 +53,7 @@ repos:
rev: v16.0.6
hooks:
- id: clang-format
-exclude: ^source/3rdparty|source/lib/src/cuda/cudart/.+\.inc
+exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
# CSS
- repo: https://github.com/pre-commit/mirrors-csslint
rev: v1.0.5
@@ -83,7 +83,7 @@ repos:
- --comment-style
- //
- --no-extra-eol
-exclude: ^source/3rdparty|source/lib/src/cuda/cudart/.+\.inc
+exclude: ^source/3rdparty|source/lib/src/gpu/cudart/.+\.inc
# CSS
- id: insert-license
files: \.(css|scss)$
4 changes: 2 additions & 2 deletions doc/install/install-from-source.md
@@ -74,7 +74,7 @@ One may set the following environment variables before executing `pip`:
| Environment variables | Allowed value | Default value | Usage |
| --------------------- | ---------------------- | ------------- | -------------------------- |
| DP_VARIANT | `cpu`, `cuda`, `rocm` | `cpu` | Build CPU variant or GPU variant with CUDA or ROCM support. |
-| CUDAToolkit_ROOT | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 7.0 or later is supported. NVCC is required. |
+| CUDAToolkit_ROOT | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is required. |
| ROCM_ROOT | Path | Detected automatically | The path to the ROCM toolkit directory. |
| TENSORFLOW_ROOT | Path | Detected automatically | The path to TensorFlow Python library. By default the installer only finds TensorFlow under user site-package directory (`site.getusersitepackages()`) or system site-package directory (`sysconfig.get_path("purelib")`) due to limitation of [PEP-517](https://peps.python.org/pep-0517/). If not found, the latest TensorFlow (or the environment variable `TENSORFLOW_VERSION` if given) from PyPI will be built against.|
| DP_ENABLE_NATIVE_OPTIMIZATION | 0, 1 | 0 | Enable compilation optimization for the native machine's CPU type. Do not enable it if generated code will run on different CPUs. |
@@ -188,7 +188,7 @@ One may add the following arguments to `cmake`:
| -DTENSORFLOW_ROOT=&lt;value&gt; | Path | - | The Path to TensorFlow's C++ interface. |
| -DCMAKE_INSTALL_PREFIX=&lt;value&gt; | Path | - | The Path where DeePMD-kit will be installed. |
| -DUSE_CUDA_TOOLKIT=&lt;value&gt; | `TRUE` or `FALSE` | `FALSE` | If `TRUE`, Build GPU support with CUDA toolkit. |
-| -DCUDAToolkit_ROOT=&lt;value&gt; | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 7.0 or later is supported. NVCC is required. |
+| -DCUDAToolkit_ROOT=&lt;value&gt; | Path | Detected automatically | The path to the CUDA toolkit directory. CUDA 9.0 or later is supported. NVCC is required. |
| -DUSE_ROCM_TOOLKIT=&lt;value&gt; | `TRUE` or `FALSE` | `FALSE` | If `TRUE`, Build GPU support with ROCM toolkit. |
| -DCMAKE_HIP_COMPILER_ROCM_ROOT=&lt;value&gt; | Path | Detected automatically | The path to the ROCM toolkit directory. |
| -DLAMMPS_SOURCE_ROOT=&lt;value&gt; | Path | - | Only neccessary for LAMMPS plugin mode. The path to the [LAMMPS source code](install-lammps.md). LAMMPS 8Apr2021 or later is supported. If not assigned, the plugin mode will not be enabled. |
4 changes: 2 additions & 2 deletions source/lib/CMakeLists.txt
@@ -11,7 +11,7 @@ target_include_directories(

if(USE_CUDA_TOOLKIT)
add_definitions("-DGOOGLE_CUDA")
-add_subdirectory(src/cuda)
+add_subdirectory(src/gpu)
set(EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_cuda)
target_link_libraries(${libname} INTERFACE deepmd_dyn_cudart ${EXTRA_LIBS})
# gpu_cuda.h
@@ -22,7 +22,7 @@ endif()

if(USE_ROCM_TOOLKIT)
add_definitions("-DTENSORFLOW_USE_ROCM")
-add_subdirectory(src/rocm)
+add_subdirectory(src/gpu)
set(EXTRA_LIBS ${EXTRA_LIBS} deepmd_op_rocm)
target_link_libraries(${libname} INTERFACE ${ROCM_LIBRARIES} ${EXTRA_LIBS})
# gpu_rocm.h
7 changes: 7 additions & 0 deletions source/lib/include/gpu_cuda.h
@@ -8,6 +8,13 @@

#include "errors.h"

+#define gpuGetLastError cudaGetLastError
+#define gpuDeviceSynchronize cudaDeviceSynchronize
+#define gpuMemcpy cudaMemcpy
+#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
+#define gpuMemset cudaMemset
+
#define GPU_MAX_NBOR_SIZE 4096
#define DPErrcheck(res) \
{ DPAssert((res), __FILE__, __LINE__); }
7 changes: 7 additions & 0 deletions source/lib/include/gpu_rocm.h
@@ -11,6 +11,13 @@

#define GPU_MAX_NBOR_SIZE 4096

+#define gpuGetLastError hipGetLastError
+#define gpuDeviceSynchronize hipDeviceSynchronize
+#define gpuMemcpy hipMemcpy
+#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
+#define gpuMemset hipMemset
+
#define DPErrcheck(res) \
{ DPAssert((res), __FILE__, __LINE__); }
inline void DPAssert(hipError_t code,
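
The two header hunks above apply the same idea: each backend header maps a neutral `gpu*` name onto its vendor call, so shared sources only ever spell the neutral name. The sketch below illustrates that pattern under stated assumptions. The `fake*` runtime is a hypothetical stand-in (not a real API) so the example builds with a plain C++ compiler, and `gpuSuccess` is an extra alias introduced for this sketch only; it is not one of the macros added by this commit.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical stand-in for a vendor runtime (cuda* / hip*). These names are
// invented for illustration and simply operate on host memory.
enum fakeError_t { fakeSuccess = 0, fakeErrorInvalidValue = 1 };
enum fakeMemcpyKind { fakeMemcpyHostToDevice, fakeMemcpyDeviceToHost };

inline fakeError_t fakeMemcpy(void* dst, const void* src, std::size_t bytes,
                              fakeMemcpyKind /*kind*/) {
  if (dst == nullptr || src == nullptr) return fakeErrorInvalidValue;
  std::memcpy(dst, src, bytes);
  return fakeSuccess;
}
inline fakeError_t fakeMemset(void* dst, int value, std::size_t bytes) {
  if (dst == nullptr) return fakeErrorInvalidValue;
  std::memset(dst, value, bytes);
  return fakeSuccess;
}
inline fakeError_t fakeGetLastError() { return fakeSuccess; }
inline fakeError_t fakeDeviceSynchronize() { return fakeSuccess; }

// Backend header: the only place that knows which runtime is in use.
#define gpuGetLastError fakeGetLastError
#define gpuDeviceSynchronize fakeDeviceSynchronize
#define gpuMemcpy fakeMemcpy
#define gpuMemcpyDeviceToHost fakeMemcpyDeviceToHost
#define gpuMemcpyHostToDevice fakeMemcpyHostToDevice
#define gpuMemset fakeMemset
#define gpuSuccess fakeSuccess  // sketch-only alias, not part of the commit

// Shared source: written once against the neutral names.
inline bool zero_then_round_trip(const int* host_in, int* device_buf,
                                 int* host_out, std::size_t count) {
  const std::size_t bytes = count * sizeof(int);
  if (gpuMemset(device_buf, 0, bytes) != gpuSuccess) return false;
  if (gpuMemcpy(device_buf, host_in, bytes, gpuMemcpyHostToDevice) !=
      gpuSuccess) {
    return false;
  }
  if (gpuDeviceSynchronize() != gpuSuccess) return false;
  if (gpuMemcpy(host_out, device_buf, bytes, gpuMemcpyDeviceToHost) !=
      gpuSuccess) {
    return false;
  }
  return gpuGetLastError() == gpuSuccess;
}
```

Swapping the block of `#define` lines for the CUDA or HIP set shown in the diff would retarget `zero_then_round_trip` without touching its body, which is what allows the duplicated `cuda/` and `rocm/` sources to collapse into one directory.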
60 changes: 0 additions & 60 deletions source/lib/src/cuda/CMakeLists.txt

This file was deleted.

95 changes: 95 additions & 0 deletions source/lib/src/gpu/CMakeLists.txt
@@ -0,0 +1,95 @@
if(USE_CUDA_TOOLKIT)
# required cmake version 3.23: CMAKE_CUDA_ARCHITECTURES all
cmake_minimum_required(VERSION 3.23)
# project name
project(deepmd_op_cuda)
set(GPU_LIB_NAME deepmd_op_cuda)

set(CMAKE_CUDA_ARCHITECTURES all)
enable_language(CUDA)
set(CMAKE_CUDA_STANDARD 11)
add_compile_definitions(
"$<$<COMPILE_LANGUAGE:CUDA>:_GLIBCXX_USE_CXX11_ABI=${OP_CXX_ABI}>")

find_package(CUDAToolkit REQUIRED)

# take dynamic open cudart library replace of static one so it's not required
# when using CPUs
add_subdirectory(cudart)

# nvcc -o libdeepmd_op_cuda.so -I/usr/local/cub-1.8.0 -rdc=true
# -DHIGH_PREC=true -gencode arch=compute_61,code=sm_61 -shared -Xcompiler
# -fPIC deepmd_op.cu -L/usr/local/cuda/lib64 -lcudadevrt very important here!
# Include path to cub. for searching device compute capability,
# https://developer.nvidia.com/cuda-gpus

# cub has been included in CUDA Toolkit 11, we do not need to include it any
# more see https://github.com/NVIDIA/cub
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS "11")
include_directories(cub)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS "9")
message(FATAL_ERROR "CUDA version must be >= 9.0")
endif()

message(STATUS "NVCC version is " ${CMAKE_CUDA_COMPILER_VERSION})

# arch will be configured by CMAKE_CUDA_ARCHITECTURES
set(CMAKE_CUDA_FLAGS
"${CMAKE_CUDA_FLAGS} -DCUB_IGNORE_DEPRECATED_CPP_DIALECT -DCUB_IGNORE_DEPRECATED_CPP_DIALECT"
)

file(GLOB SOURCE_FILES "*.cu")

add_library(${GPU_LIB_NAME} SHARED ${SOURCE_FILES})
target_link_libraries(${GPU_LIB_NAME} PRIVATE deepmd_dyn_cudart)

elseif(USE_ROCM_TOOLKIT)

# required cmake version
cmake_minimum_required(VERSION 3.21)
# project name
project(deepmd_op_rocm)
set(GPU_LIB_NAME deepmd_op_rocm)
set(CMAKE_LINK_WHAT_YOU_USE TRUE)

# set c++ version c++11
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_HIP_STANDARD 14)
add_definitions("-DCUB_IGNORE_DEPRECATED_CPP_DIALECT")
add_definitions("-DCUB_IGNORE_DEPRECATED_CPP_DIALECT")

message(STATUS "HIP major version is " ${HIP_VERSION_MAJOR})

set(HIP_HIPCC_FLAGS -fno-gpu-rdc; -fPIC --std=c++14 ${HIP_HIPCC_FLAGS}
)# --amdgpu-target=gfx906
if(HIP_VERSION VERSION_LESS 3.5.1)
set(HIP_HIPCC_FLAGS -hc; ${HIP_HIPCC_FLAGS})
endif()

file(GLOB SOURCE_FILES "*.cu")

hip_add_library(${GPU_LIB_NAME} SHARED ${SOURCE_FILES})

endif()

target_include_directories(
${GPU_LIB_NAME}
PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../../include/>
$<INSTALL_INTERFACE:include>)
target_precompile_headers(${GPU_LIB_NAME} PUBLIC [["device.h"]])
if(APPLE)
set_target_properties(${GPU_LIB_NAME} PROPERTIES INSTALL_RPATH @loader_path)
else()
set_target_properties(${GPU_LIB_NAME} PROPERTIES INSTALL_RPATH "$ORIGIN")
endif()

if(BUILD_CPP_IF AND NOT BUILD_PY_IF)
install(
TARGETS ${GPU_LIB_NAME}
EXPORT ${CMAKE_PROJECT_NAME}Targets
DESTINATION lib/)
endif(BUILD_CPP_IF AND NOT BUILD_PY_IF)
if(BUILD_PY_IF)
install(TARGETS ${GPU_LIB_NAME} DESTINATION deepmd/lib/)
endif(BUILD_PY_IF)
50 changes: 26 additions & 24 deletions source/lib/src/cuda/coord.cu → source/lib/src/gpu/coord.cu
@@ -266,21 +266,21 @@ void compute_int_data(int *int_data,
_fill_idx_cellmap<<<nblock_loc, TPB>>>(idx_cellmap, idx_cellmap_noshift, in_c,
rec_boxt, nat_stt, nat_end, ext_stt,
ext_end, nloc);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());

const int nblock_loc_cellnum = (loc_cellnum + TPB - 1) / TPB;
_fill_loc_cellnum_map<<<nblock_loc_cellnum, TPB>>>(
temp_idx_order, loc_cellnum_map, idx_cellmap_noshift, nloc, loc_cellnum);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());

const int nblock_total_cellnum = (total_cellnum + TPB - 1) / TPB;
_fill_total_cellnum_map<<<nblock_total_cellnum, TPB>>>(
total_cellnum_map, mask_cellnum_map, cell_map, cell_shift_map, nat_stt,
nat_end, ext_stt, ext_end, loc_cellnum_map, total_cellnum);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
}

void build_loc_clist(int *int_data,
@@ -297,8 +297,8 @@ void build_loc_clist(int *int_data,
total_cellnum * 3 + loc_cellnum + 1 + total_cellnum + 1;
_build_loc_clist<<<nblock, TPB>>>(loc_clist, idx_cellmap_noshift,
temp_idx_order, sec_loc_cellnum_map, nloc);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
}

template <typename FPTYPE>
@@ -326,23 +326,23 @@ void copy_coord(FPTYPE *out_c,
cell_shift_map, sec_loc_cellnum_map,
sec_total_cellnum_map, loc_clist, nloc, nall,
total_cellnum, boxt, rec_boxt);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
}

namespace deepmd {
template <typename FPTYPE>
void normalize_coord_gpu(FPTYPE *coord,
const int natom,
const Region<FPTYPE> &region) {
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
const FPTYPE *boxt = region.boxt;
const FPTYPE *rec_boxt = region.rec_boxt;
const int nblock = (natom + TPB - 1) / TPB;
normalize_one<<<nblock, TPB>>>(coord, boxt, rec_boxt, natom);
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
}

// int_data(temp cuda
@@ -362,16 +362,17 @@ int copy_coord_gpu(FPTYPE *out_c,
const int &total_cellnum,
const int *cell_info,
const Region<FPTYPE> &region) {
-DPErrcheck(cudaGetLastError());
-DPErrcheck(cudaDeviceSynchronize());
+DPErrcheck(gpuGetLastError());
+DPErrcheck(gpuDeviceSynchronize());
compute_int_data(int_data, in_c, cell_info, region, nloc, loc_cellnum,
total_cellnum);
int *int_data_cpu = new int
[loc_cellnum + 2 * total_cellnum + loc_cellnum + 1 + total_cellnum +
1]; // loc_cellnum_map,total_cellnum_map,mask_cellnum_map,sec_loc_cellnum_map,sec_total_cellnum_map
-DPErrcheck(cudaMemcpy(int_data_cpu, int_data + 3 * nloc,
-sizeof(int) * (loc_cellnum + 2 * total_cellnum),
-cudaMemcpyDeviceToHost));
+DPErrcheck(gpuMemcpy(int_data_cpu, int_data + 3 * nloc,
+sizeof(int) * (loc_cellnum + 2 * total_cellnum),
+gpuMemcpyDeviceToHost));
+DPErrcheck(gpuGetLastError());
int *loc_cellnum_map = int_data_cpu;
int *total_cellnum_map = loc_cellnum_map + loc_cellnum;
int *mask_cellnum_map = total_cellnum_map + total_cellnum;
@@ -397,11 +398,12 @@ int copy_coord_gpu(FPTYPE *out_c,
// size of the output arrays is not large enough
return 1;
} else {
-DPErrcheck(cudaMemcpy(int_data + nloc * 3 + loc_cellnum +
-total_cellnum * 3 + total_cellnum * 3,
-sec_loc_cellnum_map,
-sizeof(int) * (loc_cellnum + 1 + total_cellnum + 1),
-cudaMemcpyHostToDevice));
+DPErrcheck(gpuMemcpy(int_data + nloc * 3 + loc_cellnum + total_cellnum * 3 +
+total_cellnum * 3,
+sec_loc_cellnum_map,
+sizeof(int) * (loc_cellnum + 1 + total_cellnum + 1),
+gpuMemcpyHostToDevice));
+DPErrcheck(gpuGetLastError());
delete[] int_data_cpu;
build_loc_clist(int_data, nloc, loc_cellnum, total_cellnum);
copy_coord(out_c, out_t, mapping, int_data, in_c, in_t, nloc, *nall,
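
Two idioms recur in every hunk of `coord.cu` above: the block count is computed as `(n + TPB - 1) / TPB`, and each kernel launch is followed by a pair of checks. In CUDA-style runtimes the first call reports errors from the launch itself, while the synchronize call surfaces errors raised while the kernel ran. The sketch below shows both idioms on the host. `blocks_for`, `sketch_assert`, and `SKETCH_ERRCHECK` are hypothetical names, plain `int` stands in for the runtime's error type, and throwing an exception is an assumption, since the body of the real `DPAssert` is not shown in this diff.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Ceiling division: the smallest number of blocks of tpb threads that covers
// n items, matching the (n + TPB - 1) / TPB expression in the diff.
inline int blocks_for(int n, int tpb) {
  assert(n >= 0 && tpb > 0);
  return (n + tpb - 1) / tpb;
}

// Hypothetical check helper in the spirit of DPErrcheck / DPAssert.
// Assumption: a nonzero code means failure and is reported by throwing.
inline void sketch_assert(int code, const char* file, int line) {
  if (code != 0) {
    throw std::runtime_error(std::string("GPU error code ") +
                             std::to_string(code) + " at " + file + ":" +
                             std::to_string(line));
  }
}

// The macro captures the call site so the message points at the failing line.
#define SKETCH_ERRCHECK(res)                       \
  do {                                             \
    sketch_assert((res), __FILE__, __LINE__);      \
  } while (0)
```

The `do { ... } while (0)` wrapper is a design choice for this sketch: it makes the macro behave as a single statement after an unbraced `if`, which the brace-only form used in the headers does not guarantee.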
File renamed without changes.
File renamed without changes.