-
Notifications
You must be signed in to change notification settings - Fork 13
/
CMakeLists.txt
executable file
·99 lines (82 loc) · 3.5 KB
/
CMakeLists.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# This file calls find_package(CUDAToolkit); the FindCUDAToolkit module ships
# with CMake 3.17 and later, so require that version up front. Older CMake
# releases then stop here with a clear version error instead of failing later
# on a missing find module.
cmake_minimum_required(VERSION 3.17)

# Enable both the C++ and CUDA toolchains for this project.
project(flash-attention LANGUAGES CXX CUDA)
# Locate nvidia-smi; configuration stops if it cannot be found.
find_program(GPU_dev_info
  NAMES nvidia-smi
)
if(NOT GPU_dev_info)
  message(FATAL_ERROR "GPU driver not found.")
endif()

# Locate the nvcc compiler driver; configuration stops if it cannot be found.
find_program(NVCC
  NAMES nvcc
)
if(NOT NVCC)
  message(FATAL_ERROR
    "NVCC not found. Please make sure CUDA Toolkit is installed.")
endif()
# Probe the build machine to choose the CUDA architecture:
#   - "90a" (plus the HOPPER compile definition) when the nvidia-smi output
#     mentions "H100" and nvcc reports version 12 or newer;
#   - "80" in every other case.
execute_process(
  COMMAND "${GPU_dev_info}"
  OUTPUT_VARIABLE GPU_dev_version
)
# The input is quoted on purpose: an empty output would otherwise leave
# string(REGEX MATCH) without its required input argument and abort the
# configure step with an arity error.
string(REGEX MATCH "H100" GPU_H100 "${GPU_dev_version}")

execute_process(
  COMMAND "${NVCC}" --version
  OUTPUT_VARIABLE NVCC_VERSION_OUTPUT
)
# Take the first "<digits>.<digits>" token found in the `nvcc --version` text.
string(REGEX MATCH "([0-9]+\\.[0-9]+)" NVCC_VERSION "${NVCC_VERSION_OUTPUT}")
if(NOT NVCC_VERSION)
  message(FATAL_ERROR "Failed to determine NVCC version.")
endif()

# Compare component-wise as a version string, not as a floating-point number.
if(GPU_H100 AND "${NVCC_VERSION}" VERSION_GREATER_EQUAL "12")
  set(compute_capability "90a")
  add_definitions(-DHOPPER)
else()
  set(compute_capability "80")
endif()
message("compute_capability: ${compute_capability}")
# Append the project's nvcc options to whatever CMAKE_CUDA_FLAGS already holds.
# The -gencode entry is built from compute_capability; the escaped quotes leave
# a literal \" on each side of the code= list in the stored flag string.
string(APPEND CMAKE_CUDA_FLAGS
  " --expt-relaxed-constexpr --use_fast_math -t 8"
  " -gencode=arch=compute_${compute_capability},code=\\\"sm_${compute_capability},compute_${compute_capability}\\\" "
)
# ################ CUDA ################
# Locate the CUDA Toolkit and register the header/library search paths used by
# the targets defined in this directory.
find_package(CUDAToolkit REQUIRED)
include_directories("./cutlass/include")
include_directories("./include")
include_directories(${CUDAToolkit_INCLUDE_DIRS})

# CUDA_TOOLKIT_ROOT_DIR comes from the legacy FindCUDA module and is not set by
# FindCUDAToolkit. Honour it only when it has been supplied (e.g. with -D on
# the command line); expanding it while empty would add the bogus paths
# "/include" and "/lib64".
if(CUDA_TOOLKIT_ROOT_DIR)
  include_directories("${CUDA_TOOLKIT_ROOT_DIR}/include")
  link_directories("${CUDA_TOOLKIT_ROOT_DIR}/lib64")
endif()

# Library directory reported by FindCUDAToolkit for the toolkit found above.
if(CUDAToolkit_LIBRARY_DIR)
  link_directories("${CUDAToolkit_LIBRARY_DIR}")
endif()
# Sources that are part of the flash_attn library for every architecture.
set(flash_attn_sources
  src/flash.cu
  src/flash_fwd_hdim128_fp16_sm80.cu
  src/flash_fwd_hdim160_fp16_sm80.cu
  src/flash_fwd_hdim192_fp16_sm80.cu
  src/flash_fwd_hdim224_fp16_sm80.cu
  src/flash_fwd_hdim256_fp16_sm80.cu
  src/flash_fwd_hdim32_fp16_sm80.cu
  src/flash_fwd_hdim64_fp16_sm80.cu
  src/flash_fwd_hdim96_fp16_sm80.cu
  src/flash_fwd_split_hdim128_fp16_sm80.cu
  src/flash_fwd_split_hdim160_fp16_sm80.cu
  src/flash_fwd_split_hdim192_fp16_sm80.cu
  src/flash_fwd_split_hdim224_fp16_sm80.cu
  src/flash_fwd_split_hdim256_fp16_sm80.cu
  src/flash_fwd_split_hdim32_fp16_sm80.cu
  src/flash_fwd_split_hdim64_fp16_sm80.cu
  src/flash_fwd_split_hdim96_fp16_sm80.cu
)

# The hopper/ sources are added only when compute_capability is "90a".
if(compute_capability STREQUAL "90a")
  list(APPEND flash_attn_sources
    hopper/flash_fwd_hdim64_fp16_sm90.cu
    hopper/flash_fwd_hdim128_fp16_sm90.cu
    hopper/flash_fwd_hdim256_fp16_sm90.cu
    hopper/flash_fwd_hdim64_e4m3_sm90.cu
    hopper/flash_fwd_hdim128_e4m3_sm90.cu
    hopper/flash_fwd_hdim256_e4m3_sm90.cu
  )
endif()

# Build the shared library and record the selected architecture on the target.
add_library(flash_attn SHARED ${flash_attn_sources})
set_property(TARGET flash_attn
  PROPERTY CUDA_ARCHITECTURES "${compute_capability}"
)