Skip to content

Commit

Permalink
Update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Faraz9877 committed Dec 12, 2024
1 parent f222aae commit c7a3a7d
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 32 deletions.
2 changes: 0 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);

bool cutlass_scaled_sparse_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& e, torch::Tensor const& b,
torch::Tensor const& a_scales,
Expand Down
16 changes: 0 additions & 16 deletions csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,6 @@ void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
c10::optional<torch::Tensor> const& bias);
#endif

// Report whether the CUTLASS sparse scaled-mm FP8 kernels can run on a
// device of the given compute capability under the CUDA toolkit this
// extension was built with:
//   * SM90 (Hopper)   requires CUDA 12.0 or newer
//   * SM89 (Lovelace) requires CUDA 12.4 or newer
// When CUDA_VERSION is not defined (non-CUDA build), FP8 is unsupported.
bool cutlass_scaled_sparse_mm_supports_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  }
  if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif
  // Pre-SM89 hardware (or a build without CUDA) never supports FP8 here.
  return false;
}

int32_t test_get_sm_version_num() {
int32_t major_capability, minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
Expand Down
12 changes: 3 additions & 9 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,22 +313,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

// Test
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor e,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);

// Test
ops.def(
"cutlass_scaled_sparse_mm_supports_fp8(int cuda_device_capability) -> "
"bool");
ops.impl("cutlass_scaled_sparse_mm_supports_fp8",
&cutlass_scaled_sparse_mm_supports_fp8);

// Test
// CUTLASS sparse matrix compressor
ops.def(
"cutlass_compress_entry(Tensor! a_compressed, Tensor! e,"
" Tensor a) -> bool");
Expand Down
5 changes: 0 additions & 5 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,6 @@ def cutlass_scaled_mm_azp(a: torch.Tensor,
return out


def cutlass_scaled_sparse_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return whether the CUTLASS sparse scaled-mm FP8 kernels support the
    given CUDA compute capability (e.g. ``90`` for SM90/Hopper).

    Thin wrapper that delegates to the compiled C++ extension op
    ``_C.cutlass_scaled_sparse_mm_supports_fp8``; the support decision
    (capability + CUDA toolkit version) lives on the C++ side.
    """
    return torch.ops._C.cutlass_scaled_sparse_mm_supports_fp8(
        cuda_device_capability)


def cutlass_compress_entry(a: torch.Tensor) \
-> Tuple[torch.Tensor, torch.Tensor]:
assert (a.dtype is torch.int8 or a.dtype is torch.float8_e4m3fn or \
Expand Down

0 comments on commit c7a3a7d

Please sign in to comment.