From 9db6ea5a19d19f79976d3c12768db13b8e1a990d Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 8 Apr 2024 06:02:42 +0000
Subject: [PATCH] update

---
 .github/workflows/building.yml         |   2 +
 .github/workflows/install.yml          |   2 +-
 .github/workflows/nightly.yml          |   2 +
 pyg_lib/csrc/ops/cpu/matmul_kernel.cpp | 307 ++++++++++++-------------
 setup.py                               |  13 --
 5 files changed, 152 insertions(+), 174 deletions(-)

diff --git a/.github/workflows/building.yml b/.github/workflows/building.yml
index d0398ac7..1c19e9f2 100644
--- a/.github/workflows/building.yml
+++ b/.github/workflows/building.yml
@@ -112,6 +112,8 @@ jobs:
           source ./.github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }}
           python setup.py bdist_wheel --dist-dir=dist
         shell: bash
+        env:
+          TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"
 
       - name: Test wheel
         run: |
diff --git a/.github/workflows/install.yml b/.github/workflows/install.yml
index 6ce26502..5200ad26 100644
--- a/.github/workflows/install.yml
+++ b/.github/workflows/install.yml
@@ -33,7 +33,7 @@ jobs:
           pip install --verbose -e .
         shell: bash
         env:
-          TORCH_CUDA_ARCH_LIST: "3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
+          TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"
 
       - name: Test imports
         run: |
diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index 96835042..8aeeba1b 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -118,6 +118,8 @@ jobs:
           source ./.github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }}
           python setup.py bdist_wheel --dist-dir=dist
         shell: bash
+        env:
+          TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"
 
       - name: Test wheel
         run: |
diff --git a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
index 73f8631d..d75f7a2b 100644
--- a/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
+++ b/pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
@@ -86,8 +86,7 @@ void mkl_blas_gemm_batched(const int* m_array,
                            const int* ldc_array,
                            const int group_count,
                            const int* group_size) {
-  TORCH_INTERNAL_ASSERT(false,
-                        "mkl_blas_gemm_batched: MKL BLAS is not supported");
+  TORCH_INTERNAL_ASSERT(false, "MKL BLAS is not supported");
 }
 
 void mkl_blas_gemm_batched(const int* m_array,
@@ -103,8 +102,7 @@ void mkl_blas_gemm_batched(const int* m_array,
                            const int* ldc_array,
                            const int group_count,
                            const int* group_size) {
-  TORCH_INTERNAL_ASSERT(false,
-                        "mkl_blas_gemm_batched: MKL BLAS is not supported");
+  TORCH_INTERNAL_ASSERT(false, "MKL BLAS is not supported");
 }
 
 #endif
@@ -206,82 +204,76 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
                                         const std::vector<at::Tensor> other,
                                         std::vector<at::Tensor> out) {
   // matrix_params
-  /* using matrix_params = std::tuple<int, int, int>; */
-  /* phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups; */
-  /* for (size_t i = 0; i < input.size(); ++i) { */
-  /*   const matrix_params mp = {input[i].size(0), other[i].size(-1), */
-  /*                             input[i].size(-1)}; */
-  /*   if (groups.count(mp)) { */
-  /*     groups[mp].push_back(i); */
-  /*   } else { */
-  /*     groups.insert({mp, {i}}); */
-  /*   } */
-  /* } */
-
-  /* AT_DISPATCH_FLOATING_TYPES( */
-  /*     input.front().scalar_type(), "grouped_matmul_out_kernel_mkl_impl", [&]
-   * { */
-  /*       const auto group_count = static_cast<int>(groups.size()); */
-  /*       std::vector<scalar_t> alpha(group_count, 1); */
-  /*       std::vector<scalar_t> beta(group_count, 0); */
-
-  /*       std::vector<int> ms(group_count); */
-  /*       std::vector<int> ns(group_count); */
-  /*       std::vector<int> ks(group_count); */
-  /*       std::vector<int> ld_src0(group_count); */
-  /*       std::vector<int> ld_src1(group_count); */
-  /*       std::vector<int> ld_dst(group_count); */
-  /*       std::vector<int> group_sizes(group_count); */
-  /*       std::vector<scalar_t*> src0; */
-  /*       std::vector<scalar_t*> src1; */
-  /*       std::vector<scalar_t*> dst; */
-
-  /*       size_t group_idx = 0; */
-  /*       for (const auto& group_kv : groups) { */
-  /*         int m; */
-  /*         int n; */
-  /*         int k; */
-  /*         std::tie(m, n, k) = group_kv.first; */
-  /*         const auto& indices = group_kv.second; */
-
-  /*         ms[group_idx] = m; */
-  /*         ns[group_idx] = n; */
-  /*         ks[group_idx] = k; */
-  /*         ld_src0[group_idx] = k; */
-  /*         ld_src1[group_idx] = n; */
-  /*         ld_dst[group_idx] = n; */
-  /*         group_sizes[group_idx] = indices.size(); */
-  /*         ++group_idx; */
-
-  /*         for (const auto tensor_idx : indices) { */
-  /*           src0.push_back(input[tensor_idx].data_ptr<scalar_t>()); */
-  /*           src1.push_back(other[tensor_idx].data_ptr<scalar_t>()); */
-  /*           dst.push_back(out[tensor_idx].data_ptr<scalar_t>()); */
-  /*         } */
-  /*       } */
-
-  /*       auto src0_ptrs = const_cast<const scalar_t**>(src0.data()); */
-  /*       auto src1_ptrs = const_cast<const scalar_t**>(src1.data()); */
-  /*       auto dst_ptrs = dst.data(); */
-
-  /* #if AT_MKL_SEQUENTIAL() */
-  /*       // unlikely to happen - requires Torch to be built from source with
-   */
-  /*       // explicit flag denoting MKL sequential version */
-  /*       parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
-   */
-  /*                                      src1_ptrs, ld_src1, beta, dst_ptrs, */
-  /*                                      ld_dst, group_count, group_sizes); */
-  /* #else */
-  /*       mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
-   */
-  /*                             src0_ptrs, ld_src0.data(), src1_ptrs,
-   * ld_src1.data(), */
-  /*                             beta.data(), dst_ptrs, ld_dst.data(),
-   * group_count, */
-  /*                             group_sizes.data()); */
-  /* #endif */
-  /*     }); */
+  using matrix_params = std::tuple<int, int, int>;
+  phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups;
+  for (size_t i = 0; i < input.size(); ++i) {
+    const matrix_params mp = {input[i].size(0), other[i].size(-1),
+                              input[i].size(-1)};
+    if (groups.count(mp)) {
+      groups[mp].push_back(i);
+    } else {
+      groups.insert({mp, {i}});
+    }
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(
+      input.front().scalar_type(), "grouped_matmul_out_kernel_mkl_impl", [&] {
+        const auto group_count = static_cast<int>(groups.size());
+        std::vector<scalar_t> alpha(group_count, 1);
+        std::vector<scalar_t> beta(group_count, 0);
+
+        std::vector<int> ms(group_count);
+        std::vector<int> ns(group_count);
+        std::vector<int> ks(group_count);
+        std::vector<int> ld_src0(group_count);
+        std::vector<int> ld_src1(group_count);
+        std::vector<int> ld_dst(group_count);
+        std::vector<int> group_sizes(group_count);
+        std::vector<scalar_t*> src0;
+        std::vector<scalar_t*> src1;
+        std::vector<scalar_t*> dst;
+
+        size_t group_idx = 0;
+        for (const auto& group_kv : groups) {
+          int m;
+          int n;
+          int k;
+          std::tie(m, n, k) = group_kv.first;
+          const auto& indices = group_kv.second;
+
+          ms[group_idx] = m;
+          ns[group_idx] = n;
+          ks[group_idx] = k;
+          ld_src0[group_idx] = k;
+          ld_src1[group_idx] = n;
+          ld_dst[group_idx] = n;
+          group_sizes[group_idx] = indices.size();
+          ++group_idx;
+
+          for (const auto tensor_idx : indices) {
+            src0.push_back(input[tensor_idx].data_ptr<scalar_t>());
+            src1.push_back(other[tensor_idx].data_ptr<scalar_t>());
+            dst.push_back(out[tensor_idx].data_ptr<scalar_t>());
+          }
+        }
+
+        auto src0_ptrs = const_cast<const scalar_t**>(src0.data());
+        auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
+        auto dst_ptrs = dst.data();
+
+#if AT_MKL_SEQUENTIAL()
+        // unlikely to happen - requires Torch to be built from source with
+        // explicit flag denoting MKL sequential version
+        parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
+                                       src1_ptrs, ld_src1, beta, dst_ptrs,
+                                       ld_dst, group_count, group_sizes);
+#else
+        mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
+                              src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
+                              beta.data(), dst_ptrs, ld_dst.data(), group_count,
+                              group_sizes.data());
+#endif
+      });
 }
 
 std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
@@ -334,86 +326,81 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
                                         const at::Tensor& other,
                                         at::Tensor& out,
                                         const at::IntArrayRef& sizes) {
-  /* const int n = other.size(-1); */
-  /* const int k = input.size(-1); */
-  /* const int nk = n * k; */
-  /* phmap::flat_hash_map<int, std::vector<size_t>> groups; */
-  /* std::vector<offset_params> offsets = {{0, 0, 0}}; */
-  /* offsets.reserve(sizes.size() + 1); */
-  /* for (size_t i = 0; i < sizes.size(); ++i) { */
-  /*   const int m = sizes[i]; */
-  /*   if (groups.count(m)) { */
-  /*     groups[m].push_back(i); */
-  /*   } else { */
-  /*     groups.insert({m, {i}}); */
-  /*   } */
-
-  /*   offset_params offset = {m * k, nk, m * n}; */
-  /*   offset += offsets.back(); */
-  /*   offsets.push_back(offset); */
-  /* } */
-  /* offsets.pop_back(); */
-
-  /* AT_DISPATCH_FLOATING_TYPES( */
-  /*     input.scalar_type(), "segment_matmul_out_kernel_mkl_impl", [&] { */
-  /*       const auto group_count = static_cast<int>(groups.size()); */
-  /*       std::vector<scalar_t> alpha(group_count, 1); */
-  /*       std::vector<scalar_t> beta(group_count, 0); */
-  /*       std::vector<int> ns(group_count, n); */
-  /*       std::vector<int> ks(group_count, k); */
-  /*       std::vector<int> ld_src0(group_count, k); */
-  /*       std::vector<int> ld_src1(group_count, n); */
-  /*       std::vector<int> ld_dst(group_count, n); */
-
-  /*       std::vector<int> ms(group_count); */
-  /*       std::vector<int> group_sizes(group_count); */
-  /*       std::vector<scalar_t*> src0; */
-  /*       std::vector<scalar_t*> src1; */
-  /*       std::vector<scalar_t*> dst; */
-
-  /*       const auto src0_base_ptr = input.data_ptr<scalar_t>(); */
-  /*       const auto src1_base_ptr = other.data_ptr<scalar_t>(); */
-  /*       const auto dst_base_ptr = out.data_ptr<scalar_t>(); */
-
-  /*       size_t group_idx = 0; */
-  /*       for (const auto& group_kv : groups) { */
-  /*         int m = group_kv.first; */
-  /*         const auto& indices = group_kv.second; */
-
-  /*         ms[group_idx] = m; */
-  /*         group_sizes[group_idx] = indices.size(); */
-  /*         ++group_idx; */
-
-  /*         for (const auto offset_idx : indices) { */
-  /*           const auto offset = offsets[offset_idx]; */
-  /*           src0.push_back(src0_base_ptr + offset.src0_offset); */
-  /*           src1.push_back(src1_base_ptr + offset.src1_offset); */
-  /*           dst.push_back(dst_base_ptr + offset.dst_offset); */
-  /*         } */
-  /*       } */
-
-  /*       auto src0_ptrs = const_cast<const scalar_t**>(src0.data()); */
-  /*       auto src1_ptrs = const_cast<const scalar_t**>(src1.data()); */
-  /*       auto dst_ptrs = dst.data(); */
-
-  /* #if AT_MKL_SEQUENTIAL() */
-  /*       // unlikely to happen - requires Torch to be built from source with
-   */
-  /*       // explicit flag denoting MKL sequential version */
-  /*       parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
-   */
-  /*                                      src1_ptrs, ld_src1, beta, dst_ptrs, */
-  /*                                      ld_dst, group_count, group_sizes); */
-  /* #else */
-  /*       mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
-   */
-  /*                             src0_ptrs, ld_src0.data(), src1_ptrs,
-   * ld_src1.data(), */
-  /*                             beta.data(), dst_ptrs, ld_dst.data(),
-   * group_count, */
-  /*                             group_sizes.data()); */
-  /* #endif */
-  /*     }); */
+  const int n = other.size(-1);
+  const int k = input.size(-1);
+  const int nk = n * k;
+  phmap::flat_hash_map<int, std::vector<size_t>> groups;
+  std::vector<offset_params> offsets = {{0, 0, 0}};
+  offsets.reserve(sizes.size() + 1);
+  for (size_t i = 0; i < sizes.size(); ++i) {
+    const int m = sizes[i];
+    if (groups.count(m)) {
+      groups[m].push_back(i);
+    } else {
+      groups.insert({m, {i}});
+    }
+
+    offset_params offset = {m * k, nk, m * n};
+    offset += offsets.back();
+    offsets.push_back(offset);
+  }
+  offsets.pop_back();
+
+  AT_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "segment_matmul_out_kernel_mkl_impl", [&] {
+        const auto group_count = static_cast<int>(groups.size());
+        std::vector<scalar_t> alpha(group_count, 1);
+        std::vector<scalar_t> beta(group_count, 0);
+        std::vector<int> ns(group_count, n);
+        std::vector<int> ks(group_count, k);
+        std::vector<int> ld_src0(group_count, k);
+        std::vector<int> ld_src1(group_count, n);
+        std::vector<int> ld_dst(group_count, n);
+
+        std::vector<int> ms(group_count);
+        std::vector<int> group_sizes(group_count);
+        std::vector<scalar_t*> src0;
+        std::vector<scalar_t*> src1;
+        std::vector<scalar_t*> dst;
+
+        const auto src0_base_ptr = input.data_ptr<scalar_t>();
+        const auto src1_base_ptr = other.data_ptr<scalar_t>();
+        const auto dst_base_ptr = out.data_ptr<scalar_t>();
+
+        size_t group_idx = 0;
+        for (const auto& group_kv : groups) {
+          int m = group_kv.first;
+          const auto& indices = group_kv.second;
+
+          ms[group_idx] = m;
+          group_sizes[group_idx] = indices.size();
+          ++group_idx;
+
+          for (const auto offset_idx : indices) {
+            const auto offset = offsets[offset_idx];
+            src0.push_back(src0_base_ptr + offset.src0_offset);
+            src1.push_back(src1_base_ptr + offset.src1_offset);
+            dst.push_back(dst_base_ptr + offset.dst_offset);
+          }
+        }
+
+        auto src0_ptrs = const_cast<const scalar_t**>(src0.data());
+        auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
+        auto dst_ptrs = dst.data();
+
+#if AT_MKL_SEQUENTIAL()
+        // unlikely to happen - requires Torch to be built from source with
+        // explicit flag denoting MKL sequential version
+        parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
+                                       src1_ptrs, ld_src1, beta, dst_ptrs,
+                                       ld_dst, group_count, group_sizes);
+#else
+        mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
+                              src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
+                              beta.data(), dst_ptrs, ld_dst.data(), group_count,
+                              group_sizes.data());
+#endif
+      });
 }
 
 at::Tensor segment_matmul_kernel(const at::Tensor& input,
diff --git a/setup.py b/setup.py
index 7020c9b1..cbb4ecdd 100644
--- a/setup.py
+++ b/setup.py
@@ -66,19 +66,6 @@ def build_extension(self, ext):
             f'-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}',
         ]
 
-        # os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0 8.6 9.0'
-
-        # cuda_arch_list = os.getenv('TORCH_CUDA_ARCH_LIST')
-        # print("ARCH LIST")
-        # print("-----------")
-        # print(cuda_arch_list)
-        # cmake_args.append('-DCUDA_ARCH_PTX=5.0+PTX')
-        # if WITH_CUDA and cuda_arch_list is not None:
-        #     cmake_args.append(f'-DCMAKE_CUDA_ARCHITECTURES={cuda_arch_list}')
-        # else:
-        #     cuda_arch_list = "50;60;70;75;80;86"
-        #     cmake_args.append(f'-DCMAKE_CUDA_ARCHITECTURES={cuda_arch_list}')
-
         if CMakeBuild.check_env_flag('USE_MKL_BLAS'):
             include_dir = f"{sysconfig.get_path('data')}{os.sep}include"
             cmake_args.append(f'-DBLAS_INCLUDE_DIR={include_dir}')