From fe0afc58529bdf148bfda580a60e3d778b5b8f1d Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Wed, 22 Mar 2023 14:05:48 +0000 Subject: [PATCH] use accumulate type in BF16 gemm(include dot, mv) ref path (#96074) Fix https://github.com/pytorch/pytorch/issues/95125 and https://github.com/pytorch/pytorch/issues/83863 for bf16 accumulation in gemm ref path Pull Request resolved: https://github.com/pytorch/pytorch/pull/96074 Approved by: https://github.com/lezcano, https://github.com/peterbell10 --- aten/src/ATen/native/BlasKernel.cpp | 49 ++++++++--- aten/src/ATen/native/LinearAlgebra.cpp | 40 +++++---- aten/src/ATen/native/cpu/BlasKernel.cpp | 109 ++++++++++++++++++++---- test/test_linalg.py | 47 ++++++++++ test/test_sparse_csr.py | 1 + 5 files changed, 201 insertions(+), 45 deletions(-) diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 87182b3514dfdc..6a4f3d8168f57b 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -1,13 +1,14 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include -#include #include +#include #include -#include #include #include - +#include +#include +#include +#include +#include #if AT_BUILD_WITH_BLAS() extern "C" double ddot_(int *n, double *x, int *incx, double *y, int *incy); extern "C" void dscal_(int *n, double *a, double *x, int *incx); @@ -180,9 +181,10 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t return; } + using opmath_t = at::opmath_type; if ((trans == 'T') || (trans == 't')) { for (const auto i : c10::irange(n)) { - scalar_t sum = 0; + opmath_t sum = 0; scalar_t *row_ = a + lda * i; for (const auto j : c10::irange(m)) { sum += x[j * incx] * row_[j]; @@ -196,15 +198,37 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, scalar_t *a, int64_t } else { if (beta != scalar_t(1) && beta != scalar_t(0)) scal(m, beta, y, incy); + bool is_low_precision = !std::is_same::value; + std::vector sum; + if (is_low_precision) { + sum.resize(m); + } for (const auto j : c10::irange(n)) { scalar_t *column_ = a + lda * j; - scalar_t z = alpha * x[j * incx]; + opmath_t z = alpha * static_cast(x[j * incx]); for (const auto i : c10::irange(m)) { //output values are ignored if beta is 0, and set to 0, nans and infs are not propagated if (j==0 && beta==scalar_t(0)) { - y[i * incy] = scalar_t(0); + if (!is_low_precision) { + y[i * incy] = 0; + } + } + if (is_low_precision) { + sum[i] += z * column_[i]; + } else { + y[i * incy] += z * column_[i]; + } + } + } + if (is_low_precision) { + if (beta == scalar_t(0)) { + for (const auto i : c10::irange(m)) { + y[i * incy] = sum[i]; + } + } else { + for (const auto i : c10::irange(m)) { + y[i * incy] += sum[i]; } - y[i * incy] += z * column_[i]; } } } @@ -263,11 +287,12 @@ scalar_t dot_naive( Functor op) { // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t i; - scalar_t sum = 0; + using opmath_t = at::opmath_type; + opmath_t sum = 0; for (i = 0; i < n; i++) { - sum += op(x[i * incx], y[i * incy]); + sum += op(static_cast(x[i * incx]), static_cast(y[i * incy])); } - return sum; + return static_cast(sum); } } // namespace blas_impl diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 39998c35731582..5a4a0389a84b59 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1,23 +1,23 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include #include #include #include #include #include -#include +#include +#include +#include +#include +#include +#include +#include #include #include #include #include #include #include -#include -#include -#include -#include -#include -#include +#include #include #include #include @@ -1533,14 +1533,16 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T int64_t js = result.size(2); int64_t ks = self.size(2); - scalar_t alpha = alpha_.to(); - scalar_t beta = beta_.to(); + using opmath_t = at::opmath_type; + opmath_t alpha = alpha_.to(); + opmath_t beta = beta_.to(); auto r0 = result.accessor(); auto s0 = self.accessor(); auto m0 = mat2.accessor(); int64_t grain_size = std::min(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1); + using opmath_t = at::opmath_type; parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) { for (const auto b : c10::irange(b_begin, b_end)) { auto r1 = r0[b]; @@ -1550,17 +1552,19 @@ inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, const T auto r2 = r1[i]; auto s2 = s1[i]; for (const auto j : c10::irange(js)) { - scalar_t &r = r2[j]; + opmath_t acc_value = 0;//is_bmm ? opmath_t(0) : opmath_t(r2[j]); + for (const auto k : c10::irange(ks)) { + acc_value += static_cast(s2[k]) * + static_cast(m1[k][j]); + } if (is_bmm) { - r = 0; - for (const auto k : c10::irange(ks)) { - r += s2[k] * m1[k][j]; - } + r2[j] = acc_value; } else { // For beta == 0, the r's value will be ignored, especially for nan value. - r = beta == scalar_t(0) ? scalar_t(0) : beta * r; - for (const auto k : c10::irange(ks)) { - r += alpha * s2[k] * m1[k][j]; + if (beta == opmath_t{0}) { + r2[j] = alpha * acc_value; + } else { + r2[j] = static_cast(r2[j]) * beta + alpha * acc_value; } } } diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index e143d8a654693a..3114e0b39182eb 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -53,15 +53,20 @@ auto sum(int64_t N, Func f) { return partial_sums[0]; } - template -void gemm_notrans_( - int64_t m, int64_t n, int64_t k, +typename std::enable_if::value, void>::type +gemm_notrans_( + int64_t m, + int64_t n, + int64_t k, opmath_t alpha, - const scalar_t *a, int64_t lda, - const scalar_t *b, int64_t ldb, + const scalar_t* a, + int64_t lda, + const scalar_t* b, + int64_t ldb, opmath_t beta, - scalar_t *c, int64_t ldc) { + scalar_t* c, + int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -83,6 +88,37 @@ void gemm_notrans_( } } +// std::is_same || std::is_same +template +typename std::enable_if::value, void>::type +gemm_notrans_( + int64_t m, + int64_t n, + int64_t k, + opmath_t alpha, + const scalar_t* a, + int64_t lda, + const scalar_t* b, + int64_t ldb, + opmath_t beta, + scalar_t* c, + int64_t ldc) { + // c += alpha * (a @ b) + for (const auto i : c10::irange(m)) { + for (const auto j : c10::irange(n)) { + const auto dot = sum(k, [&](int64_t l) -> opmath_t { + return static_cast(a[l * lda + i]) * + static_cast(b[j * ldb + l]); + }); + if (beta == opmath_t(0)) { + c[j * ldc + i] = alpha * dot; + } else { + c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot; + } + } + } +} + template void gemm_transa_( int64_t m, int64_t n, int64_t k, @@ -111,13 +147,19 @@ void gemm_transa_( } template -void gemm_transb_( - int64_t m, int64_t n, int64_t k, +typename std::enable_if::value, void>::type +gemm_transb_( + int64_t m, + int64_t n, + int64_t k, opmath_t alpha, - const scalar_t *a, int64_t lda, - const scalar_t *b, int64_t ldb, + const scalar_t* a, + int64_t lda, + const scalar_t* b, + int64_t ldb, opmath_t beta, - scalar_t *c, int64_t ldc) { + scalar_t* c, + int64_t ldc) { // c *= beta scale_(m, n, beta, c, ldc); @@ -139,6 +181,37 @@ void gemm_transb_( } } +// std::is_same || std::is_same +template +typename std::enable_if::value, void>::type +gemm_transb_( + int64_t m, + int64_t n, + int64_t k, + opmath_t alpha, + const scalar_t* a, + int64_t lda, + const scalar_t* b, + int64_t ldb, + opmath_t beta, + scalar_t* c, + int64_t ldc) { + // c += alpha * (a @ b.T) + for (const auto i : c10::irange(m)) { + for (const auto j : c10::irange(n)) { + const auto dot = sum(k, [&](int64_t l) -> opmath_t { + return static_cast(a[l * lda + i]) * + static_cast(b[l * ldb + j]); + }); + if (beta == opmath_t(0)) { + c[j * ldc + i] = alpha * dot; + } else { + c[j * ldc + i] = beta * c[j * ldc + i] + alpha * dot; + } + } + } +} + template void gemm_transab_( int64_t m, int64_t n, int64_t k, @@ -173,13 +246,19 @@ void gemm_core_( const scalar_t *b, int64_t ldb, opmath_t beta, scalar_t *c, int64_t ldc) { - if(transa == TransposeType::NoTranspose && transb == TransposeType::NoTranspose) { + if (transa == TransposeType::NoTranspose && + transb == TransposeType::NoTranspose) { return gemm_notrans_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else if(transa == TransposeType::Transpose && transb != TransposeType::Transpose) { + } else if ( + transa == TransposeType::Transpose && + transb != TransposeType::Transpose) { gemm_transa_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else if(transa == TransposeType::NoTranspose && transb == TransposeType::Transpose) { + } else if ( + transa == TransposeType::NoTranspose && + transb == TransposeType::Transpose) { gemm_transb_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); - } else { // transa == TransposeType::Transpose && transb == TransposeType::Transpose + } else { // transa == TransposeType::Transpose && transb == + // TransposeType::Transpose gemm_transab_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } } diff --git a/test/test_linalg.py b/test/test_linalg.py index 1bfe5677d47ad9..b097437e06b81c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -7347,6 +7347,53 @@ def test_permute_matmul(self): c = a.permute(0, 1, 3, 2).matmul(b) self.assertEqual([c.min(), c.max(), c.sum()], [24, 24, 414720]) + def test_bfloat16_accumulation_with_ref_path(self): + # fix https://github.com/pytorch/pytorch/issues/95125 + # and https://github.com/pytorch/pytorch/issues/83863 + # for bf16 accumulation in gemm ref path + def check_correctness(fn, *args): + expected = fn(*args).bfloat16() + with torch.backends.mkldnn.flags(enabled=False): + def test(): + bf16_args = (arg.bfloat16() for arg in args) + tmp_result = fn(*bf16_args) + return tmp_result + c = test() + assert (torch.all(c == expected)), "Incorrect result with\n" \ + f"expected: {expected}\n" \ + f"got: {c}\n" + # test matmul + for transa in [True, False]: + for transb in [True, False]: + a = torch.ones(300, 300) + b = torch.ones(300, 300) + if transa: + a = a.transpose(0, 1).contiguous().transpose(0, 1) + if transb: + b = b.transpose(0, 1).contiguous().transpose(0, 1) + check_correctness(torch.matmul, a, b) + # test bmm + a = torch.ones(1, 1, 300) + b = torch.ones(1, 300, 1) + check_correctness(torch.bmm, a, b) + # test baddbmm + a = torch.ones(1, 1, 300) + b = torch.ones(1, 300, 1) + c = torch.ones(1, 1, 1) + check_correctness(torch.baddbmm, c, a, b) + # test mv/addmv + for trans in [True, False]: + c = torch.ones(300) * -300 + a = torch.ones(300, 300) + if trans: + a = a.transpose(0, 1).contiguous().transpose(0, 1) + b = torch.ones(300) + check_correctness(torch.mv, a, b) + check_correctness(torch.addmv, c, a, b) + # test dot + a = torch.ones(300) + b = torch.ones(300) + check_correctness(torch.dot, a, b) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 54c460e27b7ba8..2e5e93f031e5d0 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -2466,6 +2466,7 @@ def test_sampled_addmm_errors(self, device, dtype): @onlyCPU @dtypes(torch.float32, torch.float64, torch.bfloat16) + @precisionOverride({torch.bfloat16: 0.01}) def test_sparse_mm_reduce_sum(self, device, dtype): def run_test(m, n, k, nnz, train): sparse = self.genSparseCSRTensor((m, k), nnz, dtype=dtype, device=device, index_dtype=torch.int64)