From e93063f11ec1415254d89a173df0cc6a3dfd1e13 Mon Sep 17 00:00:00 2001 From: david-cortes-intel Date: Wed, 16 Oct 2024 12:11:00 +0200 Subject: [PATCH] Use spectral decomposition as more accurate fallback from Cholesky (#2930) * Use spectral decomposition as more accurate fallback from Cholesky * correct buffer size * add optional pragma (has no effect though) * add test with an overdetermined system * Update cpp/daal/src/algorithms/service_kernel_math.h Co-authored-by: Victoriya Fedotova * PR comments * use syevr which is faster than syev * avoid too eager threshold for aborting fp32 cholesky * missing commas * fix declaration types * fix incorrect diagonal indexing * add test with multi-output and overdetermined * missing dispatch * skip recently introduced test on GPU * missing file * rename GPU check and move to correct file --------- Co-authored-by: Victoriya Fedotova --- cpp/daal/src/algorithms/service_kernel_math.h | 178 +++++++++++++++--- cpp/daal/src/externals/service_lapack.h | 56 ++++++ .../src/externals/service_lapack_declar_ref.h | 10 + cpp/daal/src/externals/service_lapack_mkl.h | 68 +++++++ cpp/daal/src/externals/service_lapack_ref.h | 50 +++++ .../dal/algo/linear_regression/test/batch.cpp | 8 + .../algo/linear_regression/test/fixture.hpp | 145 ++++++++++++++ 7 files changed, 486 insertions(+), 29 deletions(-) diff --git a/cpp/daal/src/algorithms/service_kernel_math.h b/cpp/daal/src/algorithms/service_kernel_math.h index e23104ea1c5..6f8fd06b6b4 100644 --- a/cpp/daal/src/algorithms/service_kernel_math.h +++ b/cpp/daal/src/algorithms/service_kernel_math.h @@ -24,6 +24,8 @@ #ifndef __SERVICE_KERNEL_MATH_H__ #define __SERVICE_KERNEL_MATH_H__ +#include + #include "services/daal_defines.h" #include "services/env_detect.h" #include "src/algorithms/service_error_handling.h" @@ -660,6 +662,18 @@ bool solveEquationsSystemWithCholesky(FPType * a, FPType * b, size_t n, size_t n } if (info != 0) return false; + /* Note: there can be cases in which the matrix is 
singular / rank-deficient, but due to numerical + inaccuracies, Cholesky still succeeds. In such cases, it might produce a solution successfully, but + it will not be the minimum-norm solution, and might be prone towards having too large numbers. Thus + it's preferable to fall back to a different type of solver that can work correctly with those. + Note that the thresholds chosen here are just a guess and not based on any properties of floating + points or academic research. */ + const FPType threshold_chol_diag = 1e-6; + for (size_t ix = 0; ix < n; ix++) + { + if (a[ix * (n + 1)] < threshold_chol_diag) return false; + } + /* Solve L*L' * x = b */ if (sequential) { @@ -673,72 +687,178 @@ bool solveEquationsSystemWithCholesky(FPType * a, FPType * b, size_t n, size_t n } template -bool solveEquationsSystemWithPLU(FPType * a, FPType * b, size_t n, size_t nX, bool sequential, bool extendFromSymmetric) +bool solveEquationsSystemWithSpectralDecomposition(FPType * a, FPType * b, size_t n, size_t nX, bool sequential) { - if (extendFromSymmetric) + /* Storage for the eigenvalues. + Note: this allocates more size than they might require when nX > 1, because the same + buffer will get reused later on and needs the extra size. Those additional entries + will not be filled with eigenvalues. 
*/ + TArrayScalable eigenvalues(n * nX); + DAAL_CHECK_MALLOC(eigenvalues.get()); + + TArrayScalable eigenvectors(n * n); + DAAL_CHECK_MALLOC(eigenvectors.get()); + + TArrayScalable buffer_isuppz(2 * n); + DAAL_CHECK_MALLOC(buffer_isuppz.get()); + + /* SYEV parameters */ + const char jobz = 'V'; + const char range = 'A'; + const char uplo = 'U'; + FPType zero = 0; + DAAL_INT info; + DAAL_INT num_eigenvalues; + + /* Query the procedure for size of required buffer */ + DAAL_INT lwork_query_indicator = -1; + FPType buffer_size_work = 0; + DAAL_INT buffer_size_iwork = 0; + if (sequential) { - /* Extend symmetric matrix to generic through filling of upper triangle */ - for (size_t i = 0; i < n; ++i) - { - for (size_t j = 0; j < i; ++j) - { - a[j * n + i] = a[i * n + j]; - } - } + LapackInst::xxsyevr(&jobz, &range, &uplo, (DAAL_INT *)&n, a, (DAAL_INT *)&n, nullptr, nullptr, nullptr, nullptr, &zero, + &num_eigenvalues, eigenvalues.get(), eigenvectors.get(), (DAAL_INT *)&n, buffer_isuppz.get(), + &buffer_size_work, &lwork_query_indicator, &buffer_size_iwork, &lwork_query_indicator, &info); } - /* GETRF and GETRS parameters */ - char trans = 'N'; - DAAL_INT info = 0; + else + { + LapackInst::xsyevr(&jobz, &range, &uplo, (DAAL_INT *)&n, a, (DAAL_INT *)&n, nullptr, nullptr, nullptr, nullptr, &zero, + &num_eigenvalues, eigenvalues.get(), eigenvectors.get(), (DAAL_INT *)&n, buffer_isuppz.get(), + &buffer_size_work, &lwork_query_indicator, &buffer_size_iwork, &lwork_query_indicator, &info); + } + + if (info) return false; + + /* Check that buffer sizes will not overflow when passed to LAPACK */ + if (static_cast(buffer_size_work) > std::numeric_limits::max()) return false; + if (buffer_size_iwork < 0) return false; - TArrayScalable ipiv(n); - DAAL_CHECK_MALLOC(ipiv.get()); + /* Allocate work buffers as needed */ + DAAL_INT work_buffer_size = static_cast(buffer_size_work); + TArrayScalable work_buffer(work_buffer_size); + DAAL_CHECK_MALLOC(work_buffer.get()); + TArrayScalable 
iwork_buffer(buffer_size_iwork); + DAAL_CHECK_MALLOC(iwork_buffer.get()); - /* Perform P*L*U factorization of A */ + /* Perform Q*diag(l)*Q' factorization of A */ if (sequential) { - LapackInst::xxgetrf((DAAL_INT *)&n, (DAAL_INT *)&n, a, (DAAL_INT *)&n, ipiv.get(), &info); + LapackInst::xxsyevr(&jobz, &range, &uplo, (DAAL_INT *)&n, a, (DAAL_INT *)&n, nullptr, nullptr, nullptr, nullptr, &zero, + &num_eigenvalues, eigenvalues.get(), eigenvectors.get(), (DAAL_INT *)&n, buffer_isuppz.get(), + work_buffer.get(), &work_buffer_size, iwork_buffer.get(), &buffer_size_iwork, &info); } else { - LapackInst::xgetrf((DAAL_INT *)&n, (DAAL_INT *)&n, a, (DAAL_INT *)&n, ipiv.get(), &info); + LapackInst::xsyevr(&jobz, &range, &uplo, (DAAL_INT *)&n, a, (DAAL_INT *)&n, nullptr, nullptr, nullptr, nullptr, &zero, + &num_eigenvalues, eigenvalues.get(), eigenvectors.get(), (DAAL_INT *)&n, buffer_isuppz.get(), + work_buffer.get(), &work_buffer_size, iwork_buffer.get(), &buffer_size_iwork, &info); } - if (info != 0) return false; + if (info) return false; + + /* Components with small singular values get eliminated using the exact same logic as 'gelsd' with default parameters + Note: these are hard-coded versions of machine epsilon for single and double precision. They aren't obtained through + 'std::numeric_limits' in order to avoid potential template instantiation errors with some types. */ + const FPType eps = std::is_same::value ? 
1.1920929e-07 : 2.220446049250313e-16; + if (eigenvalues[n - 1] <= eps) return false; + const double component_threshold = eps * eigenvalues[n - 1]; + DAAL_INT num_discarded; + for (num_discarded = 0; num_discarded < static_cast(n) - 1; num_discarded++) + { + if (eigenvalues[num_discarded] > component_threshold) break; + } + + /* Create the square root of the inverse: Qis = Q * diag(1 / sqrt(l)) */ + DAAL_INT num_taken = static_cast(n) - num_discarded; + daal::internal::MathInst::vSqrt(num_taken, eigenvalues.get() + num_discarded, eigenvalues.get() + num_discarded); + DAAL_INT one = 1; + PRAGMA_IVDEP + for (size_t col = num_discarded; col < n; col++) + { + const FPType scale = eigenvalues[col]; + if (sequential) + { + LapackInst::xxrscl((DAAL_INT *)&n, &scale, eigenvectors.get() + col * n, &one); + } - /* Solve P*L*U * x = b */ + else + { + LapackInst::xrscl((DAAL_INT *)&n, &scale, eigenvectors.get() + col * n, &one); + } + } + + /* Now calculate the actual solution: Qis * Qis' * B */ + char trans_yes = 'T'; + char trans_no = 'N'; + FPType one_fp = 1; + const size_t eigenvectors_offset = static_cast(num_discarded) * n; if (sequential) { - LapackInst::xxgetrs(&trans, (DAAL_INT *)&n, (DAAL_INT *)&nX, a, (DAAL_INT *)&n, ipiv.get(), b, (DAAL_INT *)&n, &info); + if (nX == 1) + { + BlasInst::xxgemv(&trans_yes, (DAAL_INT *)&n, &num_taken, &one_fp, eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, + b, &one, &zero, eigenvalues.get(), &one); + BlasInst::xxgemv(&trans_no, (DAAL_INT *)&n, &num_taken, &one_fp, eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, + eigenvalues.get(), &one, &zero, b, &one); + } + + else + { + BlasInst::xxgemm(&trans_yes, &trans_no, &num_taken, (DAAL_INT *)&nX, (DAAL_INT *)&n, &one_fp, + eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, b, (DAAL_INT *)&n, &zero, eigenvalues.get(), + &num_taken); + BlasInst::xxgemm(&trans_no, &trans_no, (DAAL_INT *)&n, (DAAL_INT *)&nX, &num_taken, &one_fp, + eigenvectors.get() + 
eigenvectors_offset, (DAAL_INT *)&n, eigenvalues.get(), &num_taken, &zero, b, + (DAAL_INT *)&n); + } } + else { - LapackInst::xgetrs(&trans, (DAAL_INT *)&n, (DAAL_INT *)&nX, a, (DAAL_INT *)&n, ipiv.get(), b, (DAAL_INT *)&n, &info); + if (nX == 1) + { + BlasInst::xgemv(&trans_yes, (DAAL_INT *)&n, &num_taken, &one_fp, eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, b, + &one, &zero, eigenvalues.get(), &one); + BlasInst::xgemv(&trans_no, (DAAL_INT *)&n, &num_taken, &one_fp, eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, + eigenvalues.get(), &one, &zero, b, &one); + } + + else + { + BlasInst::xgemm(&trans_yes, &trans_no, &num_taken, (DAAL_INT *)&nX, (DAAL_INT *)&n, &one_fp, + eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, b, (DAAL_INT *)&n, &zero, eigenvalues.get(), + &num_taken); + BlasInst::xxgemm(&trans_no, &trans_no, (DAAL_INT *)&n, (DAAL_INT *)&nX, &num_taken, &one_fp, + eigenvectors.get() + eigenvectors_offset, (DAAL_INT *)&n, eigenvalues.get(), &num_taken, &zero, b, + (DAAL_INT *)&n); + } } - return (info == 0); + + return true; } template bool solveSymmetricEquationsSystem(FPType * a, FPType * b, size_t n, size_t nX, bool sequential) { - /* Copy data for fallback from Cholesky to PLU factorization */ + /* Copy data for fallback from Cholesky to spectral decomposition */ TArrayScalable aCopy(n * n); - TArrayScalable bCopy(n); + TArrayScalable bCopy(n * nX); DAAL_CHECK_MALLOC(aCopy.get()); DAAL_CHECK_MALLOC(bCopy.get()); int copy_status = services::internal::daal_memcpy_s(aCopy.get(), n * n * sizeof(FPType), a, n * n * sizeof(FPType)); - copy_status += services::internal::daal_memcpy_s(bCopy.get(), n * sizeof(FPType), b, n * sizeof(FPType)); + copy_status += services::internal::daal_memcpy_s(bCopy.get(), n * nX * sizeof(FPType), b, n * nX * sizeof(FPType)); if (copy_status != 0) return false; /* Try to solve with Cholesky factorization */ if (!solveEquationsSystemWithCholesky(a, b, n, nX, sequential)) { - /* Fallback to PLU 
factorization */ - bool status = solveEquationsSystemWithPLU(aCopy.get(), bCopy.get(), n, nX, sequential, true); + /* Fall back to spectral decomposition */ + bool status = solveEquationsSystemWithSpectralDecomposition(aCopy.get(), bCopy.get(), n, nX, sequential); if (status) { - status = status && (services::internal::daal_memcpy_s(b, n * sizeof(FPType), bCopy.get(), n * sizeof(FPType)) == 0); + status = status && (services::internal::daal_memcpy_s(b, n * nX * sizeof(FPType), bCopy.get(), n * nX * sizeof(FPType)) == 0); } return status; } diff --git a/cpp/daal/src/externals/service_lapack.h b/cpp/daal/src/externals/service_lapack.h index fbf2ebd44f7..130dceb8248 100644 --- a/cpp/daal/src/externals/service_lapack.h +++ b/cpp/daal/src/externals/service_lapack.h @@ -21,6 +21,11 @@ //-- */ +/* Note: this file is not auto-generated. These 'x'/'xx' functions are manually added here on an +as-needed basis, and are only used internally within the library so their signatures might not +match LAPACK's down to every detail, such as passing pointers to scalars versus passing them by value, or +having 'const' qualifiers or not. */ + #ifndef __SERVICE_LAPACK_H__ #define __SERVICE_LAPACK_H__ @@ -181,6 +186,9 @@ struct Lapack _impl::xxgesvd(jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, info); } + /* Note: 'syevd' and 'syevr' both compute eigenvalues of symmetric matrices, but they use different routines. 'syevd' + is slower but more precise and should thus be preferred for situations in which small numerical inaccuracies + have adverse side effects, whereas 'syevr' is faster and more suitable for general usage. 
*/ static void xsyevd(char * jobz, char * uplo, SizeType * n, fpType * a, SizeType * lda, fpType * w, fpType * work, SizeType * lwork, SizeType * iwork, SizeType * liwork, SizeType * info) { @@ -193,6 +201,22 @@ struct Lapack _impl::xxsyevd(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const SizeType * n, fpType * a, const SizeType * lda, + const fpType * vl, const fpType * vu, const SizeType * il, const SizeType * iu, const fpType * abstol, SizeType * m, + fpType * w, fpType * z, const SizeType * ldz, SizeType * isuppz, fpType * work, const SizeType * lwork, SizeType * iwork, + const SizeType * liwork, SizeType * info) + { + _impl::xsyevr(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const SizeType * n, fpType * a, const SizeType * lda, + const fpType * vl, const fpType * vu, const SizeType * il, const SizeType * iu, const fpType * abstol, SizeType * m, + fpType * w, fpType * z, const SizeType * ldz, SizeType * isuppz, fpType * work, const SizeType * lwork, SizeType * iwork, + const SizeType * liwork, SizeType * info) + { + _impl::xxsyevr(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + static void xormqr(char * side, char * trans, SizeType * m, SizeType * n, SizeType * k, fpType * a, SizeType * lda, fpType * tau, fpType * c, SizeType * ldc, fpType * work, SizeType * lwork, SizeType * info) { @@ -204,6 +228,10 @@ struct Lapack { _impl::xxormqr(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info); } + + static void xrscl(const SizeType * n, const fpType * sa, fpType * sx, const SizeType * incx) { _impl::xrscl(n, sa, sx, incx); } + + static void xxrscl(const SizeType * n, const fpType * sa, fpType * sx, const SizeType * incx) { _impl::xxrscl(n, sa, 
sx, incx); } }; template @@ -361,6 +389,24 @@ struct LapackAutoDispatch DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xxsyevd, jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const SizeType * n, fpType * a, const SizeType * lda, + const fpType * vl, const fpType * vu, const SizeType * il, const SizeType * iu, const fpType * abstol, SizeType * m, + fpType * w, fpType * z, const SizeType * ldz, SizeType * isuppz, fpType * work, const SizeType * lwork, SizeType * iwork, + const SizeType * liwork, SizeType * info) + { + DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xsyevr, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, + liwork, info); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const SizeType * n, fpType * a, const SizeType * lda, + const fpType * vl, const fpType * vu, const SizeType * il, const SizeType * iu, const fpType * abstol, SizeType * m, + fpType * w, fpType * z, const SizeType * ldz, SizeType * isuppz, fpType * work, const SizeType * lwork, SizeType * iwork, + const SizeType * liwork, SizeType * info) + { + DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xxsyevr, jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, + liwork, info); + } + static void xormqr(char * side, char * trans, SizeType * m, SizeType * n, SizeType * k, fpType * a, SizeType * lda, fpType * tau, fpType * c, SizeType * ldc, fpType * work, SizeType * lwork, SizeType * info) { @@ -372,6 +418,16 @@ struct LapackAutoDispatch { DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xxormqr, side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info); } + + static void xrscl(SizeType * n, const fpType * sa, fpType * sx, const SizeType * incx) + { + DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xrscl, n, sa, sx, incx); + } + + static void xxrscl(SizeType * n, const fpType * sa, fpType * sx, const SizeType * incx) + { + 
DAAL_DISPATCH_LAPACK_BY_CPU(fpType, xxrscl, n, sa, sx, incx); + } }; } // namespace internal diff --git a/cpp/daal/src/externals/service_lapack_declar_ref.h b/cpp/daal/src/externals/service_lapack_declar_ref.h index 7e6c9c195d1..418c9f05b02 100644 --- a/cpp/daal/src/externals/service_lapack_declar_ref.h +++ b/cpp/daal/src/externals/service_lapack_declar_ref.h @@ -82,10 +82,20 @@ extern "C" extern void ssyevd_(char *, char *, DAAL_INT *, float *, DAAL_INT *, float *, float *, DAAL_INT *, DAAL_INT *, DAAL_INT *, DAAL_INT *); extern void dsyevd_(char *, char *, DAAL_INT *, double *, DAAL_INT *, double *, double *, DAAL_INT *, DAAL_INT *, DAAL_INT *, DAAL_INT *); + extern void ssyevr_(const char *, const char *, const char *, const DAAL_INT *, float *, const DAAL_INT *, const float *, const float *, + const DAAL_INT *, const DAAL_INT *, const float *, DAAL_INT *, float *, float *, const DAAL_INT *, DAAL_INT *, float *, + const DAAL_INT *, DAAL_INT *, const DAAL_INT *, DAAL_INT *); + extern void dsyevr_(const char *, const char *, const char *, const DAAL_INT *, double *, const DAAL_INT *, const double *, const double *, + const DAAL_INT *, const DAAL_INT *, const double *, DAAL_INT *, double *, double *, const DAAL_INT *, DAAL_INT *, double *, + const DAAL_INT *, DAAL_INT *, const DAAL_INT *, DAAL_INT *); + extern void sormqr_(char *, char *, DAAL_INT *, DAAL_INT *, DAAL_INT *, float *, DAAL_INT *, float *, float *, DAAL_INT *, float *, DAAL_INT *, DAAL_INT *); extern void dormqr_(char *, char *, DAAL_INT *, DAAL_INT *, DAAL_INT *, double *, DAAL_INT *, double *, double *, DAAL_INT *, double *, DAAL_INT *, DAAL_INT *); + + extern void drscl_(const DAAL_INT *, const double *, double *, const DAAL_INT *); + extern void srscl_(const DAAL_INT *, const float *, float *, const DAAL_INT *); } } // namespace ref diff --git a/cpp/daal/src/externals/service_lapack_mkl.h b/cpp/daal/src/externals/service_lapack_mkl.h index 37a81c3262f..f75bb1fcf7a 100644 --- 
a/cpp/daal/src/externals/service_lapack_mkl.h +++ b/cpp/daal/src/externals/service_lapack_mkl.h @@ -243,6 +243,28 @@ struct MklLapack mkl_set_num_threads_local(old_nthr); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, double * a, const DAAL_INT * lda, + const double * vl, const double * vu, const DAAL_INT * il, const DAAL_INT * iu, const double * abstol, DAAL_INT * m, + double * w, double * z, const DAAL_INT * ldz, DAAL_INT * isuppz, double * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + __DAAL_MKLFN_CALL_LAPACK(dsyevr, (jobz, range, uplo, (const MKL_INT *)n, a, (const MKL_INT *)lda, vl, vu, (const MKL_INT *)il, + (const MKL_INT *)iu, abstol, (MKL_INT *)m, w, z, (const MKL_INT *)ldz, (MKL_INT *)isuppz, work, + (const MKL_INT *)lwork, (MKL_INT *)iwork, (const MKL_INT *)liwork, (MKL_INT *)info)); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, double * a, const DAAL_INT * lda, + const double * vl, const double * vu, const DAAL_INT * il, const DAAL_INT * iu, const double * abstol, DAAL_INT * m, + double * w, double * z, const DAAL_INT * ldz, DAAL_INT * isuppz, double * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + int old_nthr = mkl_set_num_threads_local(1); + __DAAL_MKLFN_CALL_LAPACK(dsyevr, (jobz, range, uplo, (const MKL_INT *)n, a, (const MKL_INT *)lda, vl, vu, (const MKL_INT *)il, + (const MKL_INT *)iu, abstol, (MKL_INT *)m, w, z, (const MKL_INT *)ldz, (MKL_INT *)isuppz, work, + (const MKL_INT *)lwork, (MKL_INT *)iwork, (const MKL_INT *)liwork, (MKL_INT *)info)); + mkl_set_num_threads_local(old_nthr); + } + static void xormqr(char * side, char * trans, DAAL_INT * m, DAAL_INT * n, DAAL_INT * k, double * a, DAAL_INT * lda, double * tau, double * c, DAAL_INT * ldc, double * work, DAAL_INT * lwork, DAAL_INT * info) { @@ -258,6 +280,18 @@ struct MklLapack 
(MKL_INT *)lwork, (MKL_INT *)info)); mkl_set_num_threads_local(old_nthr); } + + static void xrscl(const DAAL_INT * n, const double * sa, double * sx, const DAAL_INT * incx) + { + __DAAL_MKLFN_CALL_LAPACK(drscl, ((MKL_INT *)n, sa, sx, (MKL_INT *)incx)); + } + + static void xxrscl(const DAAL_INT * n, const double * sa, double * sx, const DAAL_INT * incx) + { + int old_nthr = mkl_set_num_threads_local(1); + __DAAL_MKLFN_CALL_LAPACK(drscl, ((MKL_INT *)n, sa, sx, (MKL_INT *)incx)); + mkl_set_num_threads_local(old_nthr); + } }; /* @@ -461,6 +495,28 @@ struct MklLapack mkl_set_num_threads_local(old_nthr); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, float * a, const DAAL_INT * lda, + const float * vl, const float * vu, const DAAL_INT * il, const DAAL_INT * iu, const float * abstol, DAAL_INT * m, float * w, + float * z, const DAAL_INT * ldz, DAAL_INT * isuppz, float * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + __DAAL_MKLFN_CALL_LAPACK(ssyevr, (jobz, range, uplo, (const MKL_INT *)n, a, (const MKL_INT *)lda, vl, vu, (const MKL_INT *)il, + (const MKL_INT *)iu, abstol, (MKL_INT *)m, w, z, (const MKL_INT *)ldz, (MKL_INT *)isuppz, work, + (const MKL_INT *)lwork, (MKL_INT *)iwork, (const MKL_INT *)liwork, (MKL_INT *)info)); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, float * a, const DAAL_INT * lda, + const float * vl, const float * vu, const DAAL_INT * il, const DAAL_INT * iu, const float * abstol, DAAL_INT * m, float * w, + float * z, const DAAL_INT * ldz, DAAL_INT * isuppz, float * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + int old_nthr = mkl_set_num_threads_local(1); + __DAAL_MKLFN_CALL_LAPACK(ssyevr, (jobz, range, uplo, (const MKL_INT *)n, a, (const MKL_INT *)lda, vl, vu, (const MKL_INT *)il, + (const MKL_INT *)iu, abstol, (MKL_INT *)m, w, z, (const 
MKL_INT *)ldz, (MKL_INT *)isuppz, work, + (const MKL_INT *)lwork, (MKL_INT *)iwork, (const MKL_INT *)liwork, (MKL_INT *)info)); + mkl_set_num_threads_local(old_nthr); + } + static void xormqr(char * side, char * trans, DAAL_INT * m, DAAL_INT * n, DAAL_INT * k, float * a, DAAL_INT * lda, float * tau, float * c, DAAL_INT * ldc, float * work, DAAL_INT * lwork, DAAL_INT * info) { @@ -476,6 +532,18 @@ struct MklLapack (MKL_INT *)lwork, (MKL_INT *)info)); mkl_set_num_threads_local(old_nthr); } + + static void xrscl(const DAAL_INT * n, const float * sa, float * sx, const DAAL_INT * incx) + { + __DAAL_MKLFN_CALL_LAPACK(srscl, ((MKL_INT *)n, sa, sx, (MKL_INT *)incx)); + } + + static void xxrscl(const DAAL_INT * n, const float * sa, float * sx, const DAAL_INT * incx) + { + int old_nthr = mkl_set_num_threads_local(1); + __DAAL_MKLFN_CALL_LAPACK(srscl, ((MKL_INT *)n, sa, sx, (MKL_INT *)incx)); + mkl_set_num_threads_local(old_nthr); + } }; } // namespace mkl diff --git a/cpp/daal/src/externals/service_lapack_ref.h b/cpp/daal/src/externals/service_lapack_ref.h index 4b87d88cac8..6d52326d3e1 100644 --- a/cpp/daal/src/externals/service_lapack_ref.h +++ b/cpp/daal/src/externals/service_lapack_ref.h @@ -204,6 +204,23 @@ struct OpenBlasLapack dsyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, double * a, const DAAL_INT * lda, + const double * vl, const double * vu, const DAAL_INT * il, const DAAL_INT * iu, const double * abstol, DAAL_INT * m, + double * w, double * z, const DAAL_INT * ldz, DAAL_INT * isuppz, double * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + dsyevr_(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, double * a, const DAAL_INT * lda, 
+ const double * vl, const double * vu, const DAAL_INT * il, const DAAL_INT * iu, const double * abstol, DAAL_INT * m, + double * w, double * z, const DAAL_INT * ldz, DAAL_INT * isuppz, double * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + openblas_thread_setter ots(1); + dsyevr_(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + static void xormqr(char * side, char * trans, DAAL_INT * m, DAAL_INT * n, DAAL_INT * k, double * a, DAAL_INT * lda, double * tau, double * c, DAAL_INT * ldc, double * work, DAAL_INT * lwork, DAAL_INT * info) { @@ -216,6 +233,14 @@ struct OpenBlasLapack openblas_thread_setter ots(1); dormqr_(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info); } + + static void xrscl(const DAAL_INT * n, const double * sa, double * sx, const DAAL_INT * incx) { drscl_(n, sa, sx, incx); } + + static void xxrscl(const DAAL_INT * n, const double * sa, double * sx, const DAAL_INT * incx) + { + openblas_thread_setter ots(1); + drscl_(n, sa, sx, incx); + } }; /* @@ -381,6 +406,23 @@ struct OpenBlasLapack ssyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); } + static void xsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, float * a, const DAAL_INT * lda, + const float * vl, const float * vu, const DAAL_INT * il, const DAAL_INT * iu, const float * abstol, DAAL_INT * m, float * w, + float * z, const DAAL_INT * ldz, DAAL_INT * isuppz, float * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + ssyevr_(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + + static void xxsyevr(const char * jobz, const char * range, const char * uplo, const DAAL_INT * n, float * a, const DAAL_INT * lda, + const float * vl, const float * vu, const DAAL_INT * il, const DAAL_INT * iu, const float * abstol, 
DAAL_INT * m, float * w, + float * z, const DAAL_INT * ldz, DAAL_INT * isuppz, float * work, const DAAL_INT * lwork, DAAL_INT * iwork, + const DAAL_INT * liwork, DAAL_INT * info) + { + openblas_thread_setter ots(1); + ssyevr_(jobz, range, uplo, n, a, lda, vl, vu, il, iu, abstol, m, w, z, ldz, isuppz, work, lwork, iwork, liwork, info); + } + static void xormqr(char * side, char * trans, DAAL_INT * m, DAAL_INT * n, DAAL_INT * k, float * a, DAAL_INT * lda, float * tau, float * c, DAAL_INT * ldc, float * work, DAAL_INT * lwork, DAAL_INT * info) { @@ -393,6 +435,14 @@ struct OpenBlasLapack openblas_thread_setter ots(1); sormqr_(side, trans, m, n, k, a, lda, tau, c, ldc, work, lwork, info); } + + static void xrscl(const DAAL_INT * n, const float * sa, float * sx, const DAAL_INT * incx) { srscl_(n, sa, sx, incx); } + + static void xxrscl(const DAAL_INT * n, const float * sa, float * sx, const DAAL_INT * incx) + { + openblas_thread_setter ots(1); + srscl_(n, sa, sx, incx); + } }; } // namespace ref diff --git a/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp b/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp index 00ec7babbb9..a5585202c31 100644 --- a/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/batch.cpp @@ -50,6 +50,14 @@ TEMPLATE_LIST_TEST_M(lr_batch_test, "LR common flow", "[lr][batch]", lr_types) { this->run_and_check_linear(); } +TEMPLATE_LIST_TEST_M(lr_batch_test, "LR with non-PSD matrix", "[lr][batch-nonpsd]", lr_types) { + SKIP_IF(this->non_psd_system_not_supported_on_device()); + + this->generate(777); + this->run_and_check_linear_indefinite(); + this->run_and_check_linear_indefinite_multioutput(); +} + TEMPLATE_LIST_TEST_M(lr_batch_test, "RR common flow", "[rr][batch]", lr_types) { SKIP_IF(this->not_float64_friendly()); diff --git a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp index fb935174cfe..e1e54092552 100644 --- 
a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp +++ b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp @@ -65,6 +65,10 @@ class lr_test : public te::crtp_algo_fixture { return static_cast(this); } + bool non_psd_system_not_supported_on_device() { + return this->get_policy().is_gpu(); + } + table compute_responses(const table& beta, const table& bias, const table& data) const { const auto s_count = data.get_row_count(); @@ -291,6 +295,147 @@ class lr_test : public te::crtp_algo_fixture { } } + /* Note: difference between this test and the above, is that the linear system to solve + here is not positive-definite, thus it has an infinite number of possible solutions. The + solution here is the one with minimum norm, which is typically more desirable. */ + void run_and_check_linear_indefinite(double tol = 1e-3) { + const double X[] = { -0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309, + 0.57710379, -0.63646365, 0.54195222, -0.31659545, -0.32238912, + 0.09716732, -1.52593041, 1.1921661, -0.67108968, 1.00026942 }; + const double y[] = { 0.13632112, 1.53203308, -0.65996941 }; + auto X_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(3, 5) + .copy_data(X, 3, 5) + .build(); + auto y_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(3, 1) + .copy_data(y, 3, 1) + .build(); + + auto desc = this->get_descriptor(); + auto train_res = this->train(desc, X_tbl, y_tbl); + const auto coefs = train_res.get_coefficients(); + + if (desc.get_result_options().test(result_options::intercept)) { + const double expected_beta[] = { 0.27785494, + 0.53011669, + 0.34352259, + 0.40506216, + -1.26026447 }; + const double expected_intercept[] = { 1.24485441 }; + const auto expected_beta_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + 
.set_layout(data_layout::row_major) + .allocate(1, 5) + .copy_data(expected_beta, 1, 5) + .build(); + const auto expected_intercept_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(1, 1) + .copy_data(expected_intercept, 1, 1) + .build(); + + const auto intercept = train_res.get_intercept(); + + SECTION("Checking intercept values") { + check_if_close(intercept, expected_intercept_tbl, tol); + } + SECTION("Checking coefficient values") { + check_if_close(coefs, expected_beta_tbl, tol); + } + } + + else { + const double expected_beta[] = { 0.38217445, + 0.2732197, + 1.87135517, + 0.63458468, + -2.08473134 }; + const auto expected_beta_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(1, 5) + .copy_data(expected_beta, 1, 5) + .build(); + SECTION("Checking coefficient values") { + check_if_close(coefs, expected_beta_tbl, tol); + } + } + } + + void run_and_check_linear_indefinite_multioutput(double tol = 1e-3) { + const double X[] = { -0.98912135, -0.36778665, 1.28792526, 0.19397442, 0.9202309, + 0.57710379, -0.63646365, 0.54195222, -0.31659545, -0.32238912, + 0.09716732, -1.52593041, 1.1921661, -0.67108968, 1.00026942 }; + const double y[] = { 0.13632112, 1.53203308, -0.65996941, + -0.31179486, 0.33776913, -2.2074711 }; + auto X_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(3, 5) + .copy_data(X, 3, 5) + .build(); + auto y_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(3, 2) + .copy_data(y, 3, 2) + .build(); + + auto desc = this->get_descriptor(); + auto train_res = this->train(desc, X_tbl, y_tbl); + const auto coefs = train_res.get_coefficients(); + + if (desc.get_result_options().test(result_options::intercept)) { + 
const double expected_beta[] = { + -0.18692112, -0.20034801, -0.09590892, -0.13672683, 0.56229012, + -0.97006008, 1.39413595, 0.49238012, 1.11041239, -0.79213452, + }; + const double expected_intercept[] = { -0.48964358, 0.96467681 }; + const auto expected_beta_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(2, 5) + .copy_data(expected_beta, 2, 5) + .build(); + const auto expected_intercept_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(1, 2) + .copy_data(expected_intercept, 1, 2) + .build(); + + const auto intercept = train_res.get_intercept(); + + SECTION("Checking intercept values") { + check_if_close(intercept, expected_intercept_tbl, tol); + } + SECTION("Checking coefficient values") { + check_if_close(coefs, expected_beta_tbl, tol); + } + } + + else { + const double expected_beta[] = { -0.22795353, -0.09930168, -0.69685744, -0.22700585, + 0.88658098, -0.88921961, 1.19505839, 1.67634561, + 1.2882766, -1.43103981 }; + const auto expected_beta_tbl = oneapi::dal::detail::homogen_table_builder() + .set_data_type(data_type::float64) + .set_layout(data_layout::row_major) + .allocate(2, 5) + .copy_data(expected_beta, 2, 5) + .build(); + SECTION("Checking coefficient values") { + check_if_close(coefs, expected_beta_tbl, tol); + } + } + } + template std::vector split_table_by_rows(const dal::table& t, std::int64_t split_count) { ONEDAL_ASSERT(0l < split_count);