diff --git a/setup.py b/setup.py index 43b898020..fe5fc2539 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,10 @@ def build_extension(self, ext): f'-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}', ] + arch_list = os.environ.get('TORCH_CUDA_ARCH_LIST') + if WITH_CUDA and arch_list is not None: + cmake_args.append(f'-DCMAKE_CUDA_ARCHITECTURES={arch_list}') + if CMakeBuild.check_env_flag('USE_MKL_BLAS'): include_dir = f"{sysconfig.get_path('data')}{os.sep}include" cmake_args.append(f'-DBLAS_INCLUDE_DIR={include_dir}')