
Commit

update
rusty1s committed Apr 8, 2024
1 parent e4cee39 commit 9db6ea5
Showing 5 changed files with 152 additions and 174 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/building.yml
@@ -112,6 +112,8 @@ jobs:
source ./.github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }}
python setup.py bdist_wheel --dist-dir=dist
shell: bash
env:
TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"

- name: Test wheel
run: |
2 changes: 1 addition & 1 deletion .github/workflows/install.yml
@@ -33,7 +33,7 @@ jobs:
pip install --verbose -e .
shell: bash
env:
TORCH_CUDA_ARCH_LIST: "3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"

- name: Test imports
run: |
2 changes: 2 additions & 0 deletions .github/workflows/nightly.yml
@@ -118,6 +118,8 @@ jobs:
source ./.github/workflows/cuda/${{ runner.os }}-env.sh ${{ matrix.cuda-version }}
python setup.py bdist_wheel --dist-dir=dist
shell: bash
env:
TORCH_CUDA_ARCH_LIST: "5.0+PTX;6.0;7.0;7.5;8.0;8.6"

- name: Test wheel
run: |
307 changes: 147 additions & 160 deletions pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
@@ -86,8 +86,7 @@ void mkl_blas_gemm_batched(const int* m_array,
const int* ldc_array,
const int group_count,
const int* group_size) {
TORCH_INTERNAL_ASSERT(false,
"mkl_blas_gemm_batched: MKL BLAS is not supported");
TORCH_INTERNAL_ASSERT(false, "MKL BLAS is not supported");
}

void mkl_blas_gemm_batched(const int* m_array,
@@ -103,8 +102,7 @@ void mkl_blas_gemm_batched(const int* m_array,
const int* ldc_array,
const int group_count,
const int* group_size) {
TORCH_INTERNAL_ASSERT(false,
"mkl_blas_gemm_batched: MKL BLAS is not supported");
TORCH_INTERNAL_ASSERT(false, "MKL BLAS is not supported");
}

#endif
@@ -206,82 +204,76 @@ void grouped_matmul_out_kernel_mkl_impl(const std::vector<at::Tensor> input,
const std::vector<at::Tensor> other,
std::vector<at::Tensor> out) {
// matrix_params<M, N, K>
/* using matrix_params = std::tuple<int, int, int>; */
/* phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups; */
/* for (size_t i = 0; i < input.size(); ++i) { */
/* const matrix_params mp = {input[i].size(0), other[i].size(-1), */
/* input[i].size(-1)}; */
/* if (groups.count(mp)) { */
/* groups[mp].push_back(i); */
/* } else { */
/* groups.insert({mp, {i}}); */
/* } */
/* } */

/* AT_DISPATCH_FLOATING_TYPES( */
/* input.front().scalar_type(), "grouped_matmul_out_kernel_mkl_impl", [&] { */
/* const auto group_count = static_cast<int>(groups.size()); */
/* std::vector<scalar_t> alpha(group_count, 1); */
/* std::vector<scalar_t> beta(group_count, 0); */

/* std::vector<int> ms(group_count); */
/* std::vector<int> ns(group_count); */
/* std::vector<int> ks(group_count); */
/* std::vector<int> ld_src0(group_count); */
/* std::vector<int> ld_src1(group_count); */
/* std::vector<int> ld_dst(group_count); */
/* std::vector<int> group_sizes(group_count); */
/* std::vector<scalar_t*> src0; */
/* std::vector<scalar_t*> src1; */
/* std::vector<scalar_t*> dst; */

/* size_t group_idx = 0; */
/* for (const auto& group_kv : groups) { */
/* int m; */
/* int n; */
/* int k; */
/* std::tie(m, n, k) = group_kv.first; */
/* const auto& indices = group_kv.second; */

/* ms[group_idx] = m; */
/* ns[group_idx] = n; */
/* ks[group_idx] = k; */
/* ld_src0[group_idx] = k; */
/* ld_src1[group_idx] = n; */
/* ld_dst[group_idx] = n; */
/* group_sizes[group_idx] = indices.size(); */
/* ++group_idx; */

/* for (const auto tensor_idx : indices) { */
/* src0.push_back(input[tensor_idx].data_ptr<scalar_t>()); */
/* src1.push_back(other[tensor_idx].data_ptr<scalar_t>()); */
/* dst.push_back(out[tensor_idx].data_ptr<scalar_t>()); */
/* } */
/* } */

/* auto src0_ptrs = const_cast<const scalar_t**>(src0.data()); */
/* auto src1_ptrs = const_cast<const scalar_t**>(src1.data()); */
/* auto dst_ptrs = dst.data(); */

/* #if AT_MKL_SEQUENTIAL() */
/* // unlikely to happen - requires Torch to be built from source with */
/* // explicit flag denoting MKL sequential version */
/* parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0, */
/* src1_ptrs, ld_src1, beta, dst_ptrs, */
/* ld_dst, group_count, group_sizes); */
/* #else */
/* mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(), */
/* src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(), */
/* beta.data(), dst_ptrs, ld_dst.data(), group_count, */
/* group_sizes.data()); */
/* #endif */
/* }); */
using matrix_params = std::tuple<int, int, int>;
phmap::flat_hash_map<matrix_params, std::vector<size_t>> groups;
for (size_t i = 0; i < input.size(); ++i) {
const matrix_params mp = {input[i].size(0), other[i].size(-1),
input[i].size(-1)};
if (groups.count(mp)) {
groups[mp].push_back(i);
} else {
groups.insert({mp, {i}});
}
}

AT_DISPATCH_FLOATING_TYPES(
input.front().scalar_type(), "grouped_matmul_out_kernel_mkl_impl", [&] {
const auto group_count = static_cast<int>(groups.size());
std::vector<scalar_t> alpha(group_count, 1);
std::vector<scalar_t> beta(group_count, 0);

std::vector<int> ms(group_count);
std::vector<int> ns(group_count);
std::vector<int> ks(group_count);
std::vector<int> ld_src0(group_count);
std::vector<int> ld_src1(group_count);
std::vector<int> ld_dst(group_count);
std::vector<int> group_sizes(group_count);
std::vector<scalar_t*> src0;
std::vector<scalar_t*> src1;
std::vector<scalar_t*> dst;

size_t group_idx = 0;
for (const auto& group_kv : groups) {
int m;
int n;
int k;
std::tie(m, n, k) = group_kv.first;
const auto& indices = group_kv.second;

ms[group_idx] = m;
ns[group_idx] = n;
ks[group_idx] = k;
ld_src0[group_idx] = k;
ld_src1[group_idx] = n;
ld_dst[group_idx] = n;
group_sizes[group_idx] = indices.size();
++group_idx;

for (const auto tensor_idx : indices) {
src0.push_back(input[tensor_idx].data_ptr<scalar_t>());
src1.push_back(other[tensor_idx].data_ptr<scalar_t>());
dst.push_back(out[tensor_idx].data_ptr<scalar_t>());
}
}

auto src0_ptrs = const_cast<const scalar_t**>(src0.data());
auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
auto dst_ptrs = dst.data();

#if AT_MKL_SEQUENTIAL()
// unlikely to happen - requires Torch to be built from source with
// explicit flag denoting MKL sequential version
parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
src1_ptrs, ld_src1, beta, dst_ptrs,
ld_dst, group_count, group_sizes);
#else
mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
beta.data(), dst_ptrs, ld_dst.data(), group_count,
group_sizes.data());
#endif
});
}
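For context, here is a minimal sketch — not part of this commit — of how the float overload of `mkl_blas_gemm_batched` can forward to MKL's batched CBLAS GEMM when MKL is available. It assumes `<mkl.h>`, the LP64 interface (where `MKL_INT` is `int`), and row-major, non-transposed operands, which matches the leading dimensions (k, n, n) set up in the re-enabled path above; the function name is suffixed to mark it as hypothetical.

```cpp
#include <mkl.h>
#include <vector>

// Hypothetical MKL-backed overload (sketch only): every group performs a
// plain row-major, non-transposed GEMM, so a single transpose flag per
// group suffices.
void mkl_blas_gemm_batched_sketch(const int* m_array, const int* n_array,
                                  const int* k_array, const float* alpha_array,
                                  const float** a_array, const int* lda_array,
                                  const float** b_array, const int* ldb_array,
                                  const float* beta_array, float** c_array,
                                  const int* ldc_array, const int group_count,
                                  const int* group_size) {
  const std::vector<CBLAS_TRANSPOSE> trans(group_count, CblasNoTrans);
  // Per group g and batch entry i, MKL computes:
  //   C_i = alpha[g] * A_i * B_i + beta[g] * C_i
  cblas_sgemm_batch(CblasRowMajor, trans.data(), trans.data(), m_array,
                    n_array, k_array, alpha_array, a_array, lda_array, b_array,
                    ldb_array, beta_array, c_array, ldc_array, group_count,
                    group_size);
}
```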

std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
@@ -334,86 +326,81 @@ void segment_matmul_out_kernel_mkl_impl(const at::Tensor& input,
const at::Tensor& other,
at::Tensor& out,
const at::IntArrayRef& sizes) {
/* const int n = other.size(-1); */
/* const int k = input.size(-1); */
/* const int nk = n * k; */
/* phmap::flat_hash_map<int, std::vector<size_t>> groups; */
/* std::vector<offset_params> offsets = {{0, 0, 0}}; */
/* offsets.reserve(sizes.size() + 1); */
/* for (size_t i = 0; i < sizes.size(); ++i) { */
/* const int m = sizes[i]; */
/* if (groups.count(m)) { */
/* groups[m].push_back(i); */
/* } else { */
/* groups.insert({m, {i}}); */
/* } */

/* offset_params offset = {m * k, nk, m * n}; */
/* offset += offsets.back(); */
/* offsets.push_back(offset); */
/* } */
/* offsets.pop_back(); */

/* AT_DISPATCH_FLOATING_TYPES( */
/* input.scalar_type(), "segment_matmul_out_kernel_mkl_impl", [&] { */
/* const auto group_count = static_cast<int>(groups.size()); */
/* std::vector<scalar_t> alpha(group_count, 1); */
/* std::vector<scalar_t> beta(group_count, 0); */
/* std::vector<int> ns(group_count, n); */
/* std::vector<int> ks(group_count, k); */
/* std::vector<int> ld_src0(group_count, k); */
/* std::vector<int> ld_src1(group_count, n); */
/* std::vector<int> ld_dst(group_count, n); */

/* std::vector<int> ms(group_count); */
/* std::vector<int> group_sizes(group_count); */
/* std::vector<scalar_t*> src0; */
/* std::vector<scalar_t*> src1; */
/* std::vector<scalar_t*> dst; */

/* const auto src0_base_ptr = input.data_ptr<scalar_t>(); */
/* const auto src1_base_ptr = other.data_ptr<scalar_t>(); */
/* const auto dst_base_ptr = out.data_ptr<scalar_t>(); */

/* size_t group_idx = 0; */
/* for (const auto& group_kv : groups) { */
/* int m = group_kv.first; */
/* const auto& indices = group_kv.second; */

/* ms[group_idx] = m; */
/* group_sizes[group_idx] = indices.size(); */
/* ++group_idx; */

/* for (const auto offset_idx : indices) { */
/* const auto offset = offsets[offset_idx]; */
/* src0.push_back(src0_base_ptr + offset.src0_offset); */
/* src1.push_back(src1_base_ptr + offset.src1_offset); */
/* dst.push_back(dst_base_ptr + offset.dst_offset); */
/* } */
/* } */

/* auto src0_ptrs = const_cast<const scalar_t**>(src0.data()); */
/* auto src1_ptrs = const_cast<const scalar_t**>(src1.data()); */
/* auto dst_ptrs = dst.data(); */

/* #if AT_MKL_SEQUENTIAL() */
/* // unlikely to happen - requires Torch to be built from source with */
/* // explicit flag denoting MKL sequential version */
/* parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0, */
/* src1_ptrs, ld_src1, beta, dst_ptrs, */
/* ld_dst, group_count, group_sizes); */
/* #else */
/* mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(), */
/* src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(), */
/* beta.data(), dst_ptrs, ld_dst.data(), group_count, */
/* group_sizes.data()); */
/* #endif */
/* }); */
const int n = other.size(-1);
const int k = input.size(-1);
const int nk = n * k;
phmap::flat_hash_map<int, std::vector<size_t>> groups;
std::vector<offset_params> offsets = {{0, 0, 0}};
offsets.reserve(sizes.size() + 1);
for (size_t i = 0; i < sizes.size(); ++i) {
const int m = sizes[i];
if (groups.count(m)) {
groups[m].push_back(i);
} else {
groups.insert({m, {i}});
}

offset_params offset = {m * k, nk, m * n};
offset += offsets.back();
offsets.push_back(offset);
}
offsets.pop_back();

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "segment_matmul_out_kernel_mkl_impl", [&] {
const auto group_count = static_cast<int>(groups.size());
std::vector<scalar_t> alpha(group_count, 1);
std::vector<scalar_t> beta(group_count, 0);
std::vector<int> ns(group_count, n);
std::vector<int> ks(group_count, k);
std::vector<int> ld_src0(group_count, k);
std::vector<int> ld_src1(group_count, n);
std::vector<int> ld_dst(group_count, n);

std::vector<int> ms(group_count);
std::vector<int> group_sizes(group_count);
std::vector<scalar_t*> src0;
std::vector<scalar_t*> src1;
std::vector<scalar_t*> dst;

const auto src0_base_ptr = input.data_ptr<scalar_t>();
const auto src1_base_ptr = other.data_ptr<scalar_t>();
const auto dst_base_ptr = out.data_ptr<scalar_t>();

size_t group_idx = 0;
for (const auto& group_kv : groups) {
int m = group_kv.first;
const auto& indices = group_kv.second;

ms[group_idx] = m;
group_sizes[group_idx] = indices.size();
++group_idx;

for (const auto offset_idx : indices) {
const auto offset = offsets[offset_idx];
src0.push_back(src0_base_ptr + offset.src0_offset);
src1.push_back(src1_base_ptr + offset.src1_offset);
dst.push_back(dst_base_ptr + offset.dst_offset);
}
}

auto src0_ptrs = const_cast<const scalar_t**>(src0.data());
auto src1_ptrs = const_cast<const scalar_t**>(src1.data());
auto dst_ptrs = dst.data();

#if AT_MKL_SEQUENTIAL()
// unlikely to happen - requires Torch to be built from source with
// explicit flag denoting MKL sequential version
parallel_mkl_blas_gemm_batched(ms, ns, ks, alpha, src0_ptrs, ld_src0,
src1_ptrs, ld_src1, beta, dst_ptrs,
ld_dst, group_count, group_sizes);
#else
mkl_blas_gemm_batched(ms.data(), ns.data(), ks.data(), alpha.data(),
src0_ptrs, ld_src0.data(), src1_ptrs, ld_src1.data(),
beta.data(), dst_ptrs, ld_dst.data(), group_count,
group_sizes.data());
#endif
});
}
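As a side note, the offset bookkeeping in the now-active segment_matmul path can be illustrated with a tiny standalone program (hypothetical sizes and dimensions, not taken from the diff): each segment i consumes sizes[i] * k entries of `input`, one k x n slice of `other`, and produces sizes[i] * n entries of `out`, and segments of equal length fall into the same GEMM group.

```cpp
#include <cstdio>
#include <map>
#include <vector>

int main() {
  const std::vector<int> sizes = {3, 5, 3};  // hypothetical segment lengths
  const int k = 4, n = 2;                    // hypothetical feature dims

  int src0 = 0, src1 = 0, dst = 0;
  std::map<int, std::vector<size_t>> groups;  // m -> segment indices
  for (size_t i = 0; i < sizes.size(); ++i) {
    const int m = sizes[i];
    groups[m].push_back(i);
    std::printf("segment %zu: m=%d, input+%d, other+%d, out+%d\n", i, m, src0,
                src1, dst);
    src0 += m * k;  // advance over m rows of the [sum(sizes), k] input
    src1 += n * k;  // advance over one [k, n] weight matrix
    dst += m * n;   // advance over m rows of the [sum(sizes), n] output
  }
  for (const auto& kv : groups)
    std::printf("GEMM group m=%d covers %zu segment(s)\n", kv.first,
                kv.second.size());
  return 0;
}
```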

at::Tensor segment_matmul_kernel(const at::Tensor& input,
13 changes: 0 additions & 13 deletions setup.py
@@ -66,19 +66,6 @@ def build_extension(self, ext):
f'-DCMAKE_PREFIX_PATH={torch.utils.cmake_prefix_path}',
]

# os.environ['TORCH_CUDA_ARCH_LIST'] = '8.0 8.6 9.0'

# cuda_arch_list = os.getenv('TORCH_CUDA_ARCH_LIST')
# print("ARCH LIST")
# print("-----------")
# print(cuda_arch_list)
# cmake_args.append('-DCUDA_ARCH_PTX=5.0+PTX')
# if WITH_CUDA and cuda_arch_list is not None:
# cmake_args.append(f'-DCMAKE_CUDA_ARCHITECTURES={cuda_arch_list}')
# else:
# cuda_arch_list = "50;60;70;75;80;86"
# cmake_args.append(f'-DCMAKE_CUDA_ARCHITECTURES={cuda_arch_list}')

if CMakeBuild.check_env_flag('USE_MKL_BLAS'):
include_dir = f"{sysconfig.get_path('data')}{os.sep}include"
cmake_args.append(f'-DBLAS_INCLUDE_DIR={include_dir}')
