From 2ec4e38d14831234c271f0c8f81c0a7bb4ac55ce Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Sat, 2 Nov 2024 04:09:18 +0000 Subject: [PATCH] Simplify obvious choices in gen_cmake_config.py (#3006) --- cmake/gen_cmake_config.py | 100 ++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/cmake/gen_cmake_config.py b/cmake/gen_cmake_config.py index 31972862dc..b03f686c4f 100644 --- a/cmake/gen_cmake_config.py +++ b/cmake/gen_cmake_config.py @@ -1,6 +1,6 @@ from collections import namedtuple -Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str"]) +Backend = namedtuple("Backend", ["name", "cmake_config_name", "prompt_str", "parent"]) if __name__ == "__main__": tvm_home = "" # pylint: disable=invalid-name @@ -13,65 +13,73 @@ cmake_config_str = f"set(TVM_SOURCE_DIR {tvm_home})\n" cmake_config_str += "set(CMAKE_BUILD_TYPE RelWithDebInfo)\n" + cuda_backend = Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): ", None) + opencl_backend = Backend("OpenCL", "USE_OPENCL", "Use OpenCL? (y/n) ", None) backends = [ - Backend("CUDA", "USE_CUDA", "Use CUDA? (y/n): "), - Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): "), - Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): "), - Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): "), - Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): "), + cuda_backend, + Backend("CUTLASS", "USE_CUTLASS", "Use CUTLASS? (y/n): ", cuda_backend), + Backend("CUBLAS", "USE_CUBLAS", "Use CUBLAS? (y/n): ", cuda_backend), + Backend("ROCm", "USE_ROCM", "Use ROCm? (y/n): ", None), + Backend("Vulkan", "USE_VULKAN", "Use Vulkan? (y/n): ", None), + Backend("Metal", "USE_METAL", "Use Metal (Apple M1/M2 GPU) ? (y/n): ", None), + opencl_backend, Backend( - "Metal", - "USE_METAL", - "Use Metal (Apple M1/M2 GPU) ? (y/n): ", + "OpenCLHostPtr", + "USE_OPENCL_ENABLE_HOST_PTR", + "Use OpenCLHostPtr? (y/n): ", + opencl_backend, ), - Backend( - "OpenCL", - "USE_OPENCL", - "Use OpenCL? (y/n) ", - ), - Backend("OpenCLHostPtr", "USE_OPENCL_ENABLE_HOST_PTR", "Use OpenCLHostPtr? (y/n): "), ] enabled_backends = set() for backend in backends: - while True: - use_backend = input(backend.prompt_str) - if use_backend in ["yes", "Y", "y"]: - cmake_config_str += f"set({backend.cmake_config_name} ON)\n" - enabled_backends.add(backend.name) - break - elif use_backend in ["no", "N", "n"]: - cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" - break - else: - print(f"Invalid input: {use_backend}. Please input again.") + if backend.parent is not None and backend.parent.name not in enabled_backends: + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" + else: + while True: + use_backend = input(backend.prompt_str) + if use_backend in ["yes", "Y", "y"]: + cmake_config_str += f"set({backend.cmake_config_name} ON)\n" + enabled_backends.add(backend.name) + break + elif use_backend in ["no", "N", "n"]: + cmake_config_str += f"set({backend.cmake_config_name} OFF)\n" + break + else: + print(f"Invalid input: {use_backend}. Please input again.") if "CUDA" in enabled_backends: cmake_config_str += f"set(USE_THRUST ON)\n" # FlashInfer related use_flashInfer = False # pylint: disable=invalid-name - while True: - user_input = input("Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): ") - if user_input in ["yes", "Y", "y"]: - cmake_config_str += "set(USE_FLASHINFER ON)\n" - cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n" - cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n" - cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n" - cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n" - cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n" - cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n" - cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n" - cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n' - cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n' - use_flashInfer = True # pylint: disable=invalid-name - break - elif user_input in ["no", "N", "n"]: - cmake_config_str += "set(USE_FLASHINFER OFF)\n" - break - else: - print(f"Invalid input: {use_flashInfer}. Please input again.") + if "CUDA" in enabled_backends: + while True: + user_input = input( + "Use FlashInfer? (need CUDA w/ compute capability 80;86;89;90) (y/n): " + ) + if user_input in ["yes", "Y", "y"]: + cmake_config_str += "set(USE_FLASHINFER ON)\n" + cmake_config_str += "set(FLASHINFER_ENABLE_FP8 OFF)\n" + cmake_config_str += "set(FLASHINFER_ENABLE_BF16 OFF)\n" + cmake_config_str += "set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)\n" + cmake_config_str += "set(FLASHINFER_GEN_PAGE_SIZES 16)\n" + cmake_config_str += "set(FLASHINFER_GEN_HEAD_DIMS 128)\n" + cmake_config_str += "set(FLASHINFER_GEN_KV_LAYOUTS 0 1)\n" + cmake_config_str += "set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)\n" + cmake_config_str += 'set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")\n' + cmake_config_str += 'set(FLASHINFER_GEN_CASUALS "false" "true")\n' + use_flashInfer = True # pylint: disable=invalid-name + break + elif user_input in ["no", "N", "n"]: + cmake_config_str += "set(USE_FLASHINFER OFF)\n" + break + else: + print(f"Invalid input: {use_flashInfer}. Please input again.") + else: + cmake_config_str += "set(USE_FLASHINFER OFF)\n" + if use_flashInfer: while True: user_input = input("Enter your CUDA compute capability: ")