Commit 2d03e1d: Fix cmake errors
Faraz9877 committed Nov 14, 2024 (1 parent: b345cc8)
Showing 4 changed files with 40 additions and 145 deletions.
78 changes: 12 additions & 66 deletions CMakeLists.txt
@@ -208,13 +208,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG be692b48b01620eedabeef8325df5d4eeed6c2ae
GIT_TAG 1dbae0329c6d907b72b373667b4d5716bae4415f
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
GIT_SHALLOW TRUE
# GIT_SHALLOW FALSE
)
FetchContent_MakeAvailable(cutlass)

@@ -258,11 +258,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()

#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm, cutlass_scaled_sparse_mm, and cutlass_compressor kernels
# for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/sparse/cutlass/sparse_compressor.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
@@ -271,12 +274,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
message(STATUS "Not building cutlass_c3x kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"later if you intend on running FP8 quantized models or sparse on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
message(STATUS "Not building cutlass_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

@@ -285,63 +288,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SCALED_MM_3X_ARCHS)
endif()

#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_compressor.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
message(STATUS "Building test_util for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building test_util as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building test_util as no compatible archs found "
"in CUDA target architectures")
endif()

# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
endif()

#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
message(STATUS "Building test_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building test_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building test_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()

# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
endif()


#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
@@ -458,8 +404,8 @@ define_gpu_extension_target(
# Setting this variable sidesteps the issue by calling the driver directly.
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)

include(nm_cutlass_c.cmake)
build_nm_cutlass_c()
# include(nm_cutlass_c.cmake)
# build_nm_cutlass_c()

#
# _moe_C extension
45 changes: 19 additions & 26 deletions benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py
@@ -82,8 +82,13 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
a_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)

# Create tensors
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:87:81: E501 Line too long (82 > 80)
aT = a.t()
bT = b.t()
bf16_a = a.to(dtype=torch.bfloat16)

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:90:5: F841 Local variable `bf16_a` is assigned to but never used
bf16_bT = bT.to(dtype=torch.bfloat16)

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:91:5: F841 Local variable `bf16_bT` is assigned to but never used
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:94:5: F841 Local variable `bias` is assigned to but never used
@@ -94,7 +99,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
bT.to(dtype=torch.bfloat16, device="cuda")))

# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
@@ -103,7 +108,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
bT,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
@@ -115,7 +120,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
bT,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
@@ -128,7 +133,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
bT,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
@@ -140,7 +145,7 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
bT,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
@@ -149,24 +154,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b,
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a,

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:157:81: E501 Line too long (85 > 80)
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16))

# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.bfloat16,
bias))

# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm, a_compressed, e, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
ops.cutlass_scaled_sparse_mm, b_compressed, e, aT, scale_b, scale_a, torch.float16))

Check failure (GitHub Actions / ruff (3.12)): benchmarks/cutlass_benchmarks/sparse_mm/bench_v1.py:162:81: E501 Line too long (101 > 80)

return timers
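
For reference, a minimal layout sketch of what these timers rely on (an illustration, not part of the diff; the column-major requirement on the second operand is an assumption based on cuBLASLt's FP8 GEMM constraints, consistent with the switch from b to bT above):

# Standalone sketch (assumed shapes; needs a recent PyTorch with FP8 support
# and an SM89/SM90 GPU). Dimensions are multiples of 16, which the FP8 GEMM
# path generally requires.
import torch

m, k, n = 32, 64, 32
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)  # (m, k), row-major
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)  # (n, k), contiguous
bT = b.t()                                                    # (k, n) view, column-major
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

# On recent PyTorch builds _scaled_mm returns a single tensor (older versions
# returned an (out, amax) tuple). Passing the transposed view keeps the
# second operand column-major.
out = torch._scaled_mm(a, bT, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([32, 32])

The cutlass calls above mirror this convention with the operands swapped: the compressed sparse matrix comes first (b_compressed plus its metadata e), the dense activation is passed transposed (aT), and the scales swap with it (scale_b, scale_a).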

@@ -307,12 +300,12 @@ def bench_bf16(dtype: torch.dtype, m: int, k: int, n: int, label: str,

def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
# if dtype == torch.int8:
# return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
if dtype == torch.float16:
return bench_fp16(dtype, m, k, n, label, sub_label)
if dtype == torch.bfloat16:
return bench_bf16(dtype, m, k, n, label, sub_label)
# if dtype == torch.float16:
# return bench_fp16(dtype, m, k, n, label, sub_label)
# if dtype == torch.bfloat16:
# return bench_bf16(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type")
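
For orientation, a hypothetical sketch of what a 2:4 ("semi-structured") sparse operand looks like. This is not the benchmark's make_rand_sparse_tensors (whose source is not part of this diff), just an illustration using PyTorch's own semi-structured sparse support:

# Hypothetical 2:4 sparsity sketch (assumed shapes and dtype; requires CUDA
# and a GPU with compute capability 8.0+).
import torch
from torch.sparse import to_sparse_semi_structured

w = torch.randn(64, 128, device="cuda", dtype=torch.float16)
# Zero two of every four elements along each row to satisfy the 2:4 pattern.
mask = torch.tensor([1, 1, 0, 0], device="cuda", dtype=torch.float16).tile(64, 32)
w_24 = w * mask
# The compressed form stores the kept values plus index metadata, analogous
# to the (b_compressed, e) pair used by the benchmarks above.
w_sparse = to_sparse_semi_structured(w_24)
print(w_sparse.shape)  # torch.Size([64, 128])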
44 changes: 0 additions & 44 deletions nm_cutlass_c.cmake

This file was deleted.

18 changes: 9 additions & 9 deletions setup.py
@@ -462,15 +462,15 @@ def _read_requirements(filename: str) -> List[str]:
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))

if _is_cuda():
sparse_mm_generated_dir = './csrc/sparse/cutlass/generator/generated/'
sparse_mm_generated_dirs = \
[x for x in Path(sparse_mm_generated_dir).iterdir() if x.is_dir()]
sparse_mm_generated_dir_names = [x.name for x in sparse_mm_generated_dirs]
nm_cutlass_extensions = \
[f"vllm._nm_cutlass_{x}_C" for x in sparse_mm_generated_dir_names]
for x in nm_cutlass_extensions:
ext_modules.append(CMakeExtension(name=x))
# if _is_cuda():
# sparse_mm_generated_dir = './csrc/sparse/cutlass/generator/generated/'
# sparse_mm_generated_dirs = \
# [x for x in Path(sparse_mm_generated_dir).iterdir() if x.is_dir()]
# sparse_mm_generated_dir_names = [x.name for x in sparse_mm_generated_dirs]
# nm_cutlass_extensions = \
# [f"vllm._nm_cutlass_{x}_C" for x in sparse_mm_generated_dir_names]
# for x in nm_cutlass_extensions:
# ext_modules.append(CMakeExtension(name=x))

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))
