Skip to content

Commit

Permalink
fix correctness test
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 16, 2024
1 parent 096dd4a commit a98f691
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
8 changes: 7 additions & 1 deletion csrc/cutlass_extensions/torch_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else {
// Extra strides are assumed to be 0 or 1
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/machete/machete_mm_launcher.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
args.group_size);
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");

Expand Down
3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def machete_gemm_fake(
@torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight)
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)

@torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
Expand Down

0 comments on commit a98f691

Please sign in to comment.