diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/test_permute_cols.py new file mode 100644 index 0000000000000..dbc0fd0b3f8d6 --- /dev/null +++ b/tests/kernels/test_permute_cols.py @@ -0,0 +1,14 @@ +import torch +import pytest +from vllm._custom_ops import permute_cols +from tests.kernels.utils import opcheck + + +@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +def test_permute_cols(shape, dtype): + x = torch.randn(shape, dtype=dtype).cuda() + perm = torch.randperm(x.shape[1]).to(torch.int).cuda() + opcheck(torch.ops._C.permute_cols, (x, perm)) + y = permute_cols(x, perm) + torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index a896a6ce3b729..8ef3c49140272 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -576,6 +576,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) +# TODO: has to be a better way to do this +try: + torch.ops._C.permute_cols # noqa B018 + + @torch.library.register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) +except Exception: + pass + + def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.ops._C.permute_cols(a, perm)