From 7b7f61e9f03ebd875f68bf9db12633b28757b6af Mon Sep 17 00:00:00 2001
From: Khalil
Date: Thu, 18 Jul 2024 15:13:54 +0200
Subject: [PATCH] feature: Linear Regression online spmd support (#2846)

---
 .../gpu/finalize_train_kernel_norm_eq_dpc.cpp | 121 ++---------------
 .../finalize_train_kernel_norm_eq_impl.hpp    |  51 +++++++
 ...finalize_train_kernel_norm_eq_impl_dpc.cpp | 127 ++++++++++++++++++
 .../linear_regression/backend/gpu/misc.hpp    |   4 +-
 .../backend/gpu/train_kernel_norm_eq_dpc.cpp  |  17 ++-
 .../detail/finalize_train_ops_dpc.cpp         |  12 +-
 .../algo/linear_regression/test/fixture.hpp   |   3 +
 .../linear_regression/test/online_spmd.cpp    | 126 +++++++++++++++++
 8 files changed, 336 insertions(+), 125 deletions(-)
 create mode 100644 cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp
 create mode 100644 cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp
 create mode 100644 cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp

diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp
index d3431663249..a74723e1b00 100644
--- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp
+++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_dpc.cpp
@@ -14,129 +14,32 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include "oneapi/dal/detail/common.hpp"
-#include "oneapi/dal/backend/dispatcher.hpp"
-#include "oneapi/dal/backend/primitives/ndarray.hpp"
-#include "oneapi/dal/backend/primitives/lapack.hpp"
-#include "oneapi/dal/backend/primitives/utils.hpp"
-
-#include "oneapi/dal/table/row_accessor.hpp"
-
-#include "oneapi/dal/algo/linear_regression/common.hpp"
-#include "oneapi/dal/algo/linear_regression/train_types.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"
 #include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/gpu/update_kernel.hpp"
-#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp"
-
-namespace oneapi::dal::linear_regression::backend {
-
-using dal::backend::context_gpu;
-
-namespace be = dal::backend;
-namespace pr = be::primitives;
-
-template <typename Float, typename Task>
-static train_result<Task> call_dal_kernel(const context_gpu& ctx,
-                                          const detail::descriptor_base<Task>& desc,
-                                          const detail::train_parameters<Float>& params,
-                                          const partial_train_result<Task>& input) {
-    using dal::detail::check_mul_overflow;
-
-    using model_t = model<Task>;
-    using model_impl_t = detail::model_impl<Task>;
-
-    auto& queue = ctx.get_queue();
-
-    const bool compute_intercept = desc.get_compute_intercept();
-
-    constexpr auto uplo = pr::mkl::uplo::upper;
-    constexpr auto alloc = sycl::usm::alloc::device;
-
-    const auto response_count = input.get_partial_xty().get_row_count();
-    const auto ext_feature_count = input.get_partial_xty().get_column_count();
-    const auto feature_count = ext_feature_count - compute_intercept;
-
-    const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };
-
-    const auto xtx_nd =
-        pr::table2ndarray<Float>(queue, input.get_partial_xtx(), sycl::usm::alloc::device);
-    const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(queue,
-                                                                 input.get_partial_xty(),
-                                                                 sycl::usm::alloc::device);
-
-    const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };
+#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp"
 
-    const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
-    auto betas_arr = array<Float>::zeros(queue, betas_size, alloc);
-
-    double alpha = desc.get_alpha();
-    sycl::event ridge_event;
-    if (alpha != 0.0) {
-        ridge_event = add_ridge_penalty<Float>(queue, xtx_nd, compute_intercept, alpha);
-    }
-
-    auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
-    auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
-    auto solve_event = pr::solve_system<uplo>(queue,
-                                              compute_intercept,
-                                              xtx_nd,
-                                              xty_nd,
-                                              nxtx,
-                                              nxty,
-                                              { ridge_event });
-    sycl::event::wait_and_throw({ solve_event });
-
-    auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
-
-    const auto model_impl = std::make_shared<model_impl_t>(betas);
-    const auto model = dal::detail::make_private<model_t>(model_impl);
-
-    const auto options = desc.get_result_options();
-    auto result = train_result<Task>().set_model(model).set_result_options(options);
-
-    if (options.test(result_options::intercept)) {
-        auto arr = array<Float>::zeros(queue, response_count, alloc);
-        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { 1l, response_count });
-        const auto src = nxty.get_col_slice(0l, 1l).t();
-
-        pr::copy(queue, dst, src).wait_and_throw();
-
-        auto intercept = homogen_table::wrap(arr, 1l, response_count);
-        result.set_intercept(intercept);
-    }
-
-    if (options.test(result_options::coefficients)) {
-        const auto size = check_mul_overflow(response_count, feature_count);
-
-        auto arr = array<Float>::zeros(queue, size, alloc);
-        const auto src = nxty.get_col_slice(1l, feature_count + 1);
-        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { response_count, feature_count });
+#include "oneapi/dal/detail/common.hpp"
 
-        pr::copy(queue, dst, src).wait_and_throw();
+#include "oneapi/dal/backend/dispatcher.hpp"
 
-        auto coefficients = homogen_table::wrap(arr, response_count, feature_count);
-        result.set_coefficients(coefficients);
-    }
 
+namespace oneapi::dal::linear_regression::backend {
 
-    return result;
-}
+namespace bk = dal::backend;
 
 template <typename Float, typename Task>
-static train_result<Task> train(const context_gpu& ctx,
-                                const detail::descriptor_base<Task>& desc,
-                                const detail::train_parameters<Float>& params,
-                                const partial_train_result<Task>& input) {
-    return call_dal_kernel(ctx, desc, params, input);
+static train_result<Task> finalize_train(const bk::context_gpu& ctx,
+                                         const detail::descriptor_base<Task>& desc,
+                                         const detail::train_parameters<Float>& params,
+                                         const partial_train_result<Task>& input) {
+    return finalize_train_kernel_norm_eq_impl<Float, Task>(ctx)(desc, params, input);
 }
 
 template <typename Float, typename Task>
 struct finalize_train_kernel_gpu<Float, method::norm_eq, Task> {
-    train_result<Task> operator()(const context_gpu& ctx,
+    train_result<Task> operator()(const bk::context_gpu& ctx,
                                   const detail::descriptor_base<Task>& desc,
                                   const detail::train_parameters<Float>& params,
                                   const partial_train_result<Task>& input) const {
-        return train(ctx, desc, params, input);
+        return finalize_train(ctx, desc, params, input);
     }
 };
 
diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp
new file mode 100644
index 00000000000..6eeaf17c0da
--- /dev/null
+++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp
@@ -0,0 +1,51 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel.hpp"
+#include "oneapi/dal/backend/primitives/utils.hpp"
+
+#ifdef ONEDAL_DATA_PARALLEL
+
+namespace oneapi::dal::linear_regression::backend {
+
+namespace bk = dal::backend;
+
+template <typename Float, typename Task>
+class finalize_train_kernel_norm_eq_impl {
+    using comm_t = bk::communicator<spmd::device_memory_access::usm>;
+    using input_t = partial_train_result<Task>;
+    using result_t = train_result<Task>;
+    using descriptor_t = detail::descriptor_base<Task>;
+    using train_parameters_t = detail::train_parameters<Float>;
+
+public:
+    finalize_train_kernel_norm_eq_impl(const bk::context_gpu& ctx)
+            : q(ctx.get_queue()),
+              comm_(ctx.get_communicator()) {}
+    result_t operator()(const descriptor_t& desc,
+                        const train_parameters_t& params,
+                        const input_t& input);
+
+private:
+    sycl::queue q;
+    comm_t comm_;
+};
+
+} // namespace oneapi::dal::linear_regression::backend
+
+#endif // ONEDAL_DATA_PARALLEL
diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp
new file mode 100644
index 00000000000..c470f45403e
--- /dev/null
+++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl_dpc.cpp
@@ -0,0 +1,127 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/linear_regression/backend/gpu/finalize_train_kernel_norm_eq_impl.hpp"
+#include "oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp"
+#include "oneapi/dal/algo/linear_regression/backend/model_impl.hpp"
+
+#include "oneapi/dal/backend/primitives/lapack.hpp"
+
+namespace oneapi::dal::linear_regression::backend {
+
+namespace be = dal::backend;
+namespace pr = be::primitives;
+
+using be::context_gpu;
+
+template <typename Float, typename Task>
+train_result<Task> finalize_train_kernel_norm_eq_impl<Float, Task>::operator()(
+    const detail::descriptor_base<Task>& desc,
+    const detail::train_parameters<Float>& params,
+    const partial_train_result<Task>& input) {
+    using dal::detail::check_mul_overflow;
+
+    using model_t = model<Task>;
+    using model_impl_t = detail::model_impl<Task>;
+
+    const bool compute_intercept = desc.get_compute_intercept();
+
+    constexpr auto uplo = pr::mkl::uplo::upper;
+    constexpr auto alloc = sycl::usm::alloc::device;
+
+    const auto response_count = input.get_partial_xty().get_row_count();
+    const auto ext_feature_count = input.get_partial_xty().get_column_count();
+    const auto feature_count = ext_feature_count - compute_intercept;
+
+    const pr::ndshape<2> xtx_shape{ ext_feature_count, ext_feature_count };
+
+    const auto xtx_nd =
+        pr::table2ndarray<Float>(q, input.get_partial_xtx(), sycl::usm::alloc::device);
+    const auto xty_nd = pr::table2ndarray<Float, pr::ndorder::f>(q,
+                                                                 input.get_partial_xty(),
+                                                                 sycl::usm::alloc::device);
+
+    const pr::ndshape<2> betas_shape{ response_count, feature_count + 1 };
+
+    const auto betas_size = check_mul_overflow(response_count, feature_count + 1);
+    auto betas_arr = array<Float>::zeros(q, betas_size, alloc);
+
+    if (comm_.get_rank_count() > 1) {
+        {
+            ONEDAL_PROFILER_TASK(xtx_allreduce);
+            auto xtx_arr =
+                dal::array<Float>::wrap(q, xtx_nd.get_mutable_data(), xtx_nd.get_count());
+            comm_.allreduce(xtx_arr).wait();
+        }
+        {
+            ONEDAL_PROFILER_TASK(xty_allreduce);
+            auto xty_arr =
+                dal::array<Float>::wrap(q, xty_nd.get_mutable_data(), xty_nd.get_count());
+            comm_.allreduce(xty_arr).wait();
+        }
+    }
+
+    double alpha = desc.get_alpha();
+    sycl::event ridge_event;
+    if (alpha != 0.0) {
+        ridge_event = add_ridge_penalty<Float>(q, xtx_nd, compute_intercept, alpha);
+    }
+
+    auto nxtx = pr::ndarray<Float, 2>::empty(q, xtx_shape, alloc);
+    auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
+    auto solve_event =
+        pr::solve_system<uplo>(q, compute_intercept, xtx_nd, xty_nd, nxtx, nxty, { ridge_event });
+    sycl::event::wait_and_throw({ solve_event });
+
+    auto betas = homogen_table::wrap(betas_arr, response_count, feature_count + 1);
+
+    const auto model_impl = std::make_shared<model_impl_t>(betas);
+    const auto model = dal::detail::make_private<model_t>(model_impl);
+
+    const auto options = desc.get_result_options();
+    auto result = train_result<Task>().set_model(model).set_result_options(options);
+
+    if (options.test(result_options::intercept)) {
+        auto arr = array<Float>::zeros(q, response_count, alloc);
+        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { 1l, response_count });
+        const auto src = nxty.get_col_slice(0l, 1l).t();
+
+        pr::copy(q, dst, src).wait_and_throw();
+
+        auto intercept = homogen_table::wrap(arr, 1l, response_count);
+        result.set_intercept(intercept);
+    }
+
+    if (options.test(result_options::coefficients)) {
+        const auto size = check_mul_overflow(response_count, feature_count);
+
+        auto arr = array<Float>::zeros(q, size, alloc);
+        const auto src = nxty.get_col_slice(1l, feature_count + 1);
+        auto dst = pr::ndview<Float, 2>::wrap_mutable(arr, { response_count, feature_count });
+
+        pr::copy(q, dst, src).wait_and_throw();
+
+        auto coefficients = homogen_table::wrap(arr, response_count, feature_count);
+        result.set_coefficients(coefficients);
+    }
+
+    return result;
+}
+
+template class finalize_train_kernel_norm_eq_impl<float, task::regression>;
+template class finalize_train_kernel_norm_eq_impl<double, task::regression>;
+
+} // namespace oneapi::dal::linear_regression::backend
diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp
index 5ad5ba647ec..723fde68fb9 100644
--- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp
+++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/misc.hpp
@@ -44,7 +44,7 @@ sycl::event add_ridge_penalty(sycl::queue& q,
                               Float alpha,
                               const bk::event_vector& deps = {}) {
     ONEDAL_ASSERT(xtx.has_mutable_data());
-    ONEDAL_ASSERT(be::is_known_usm(q, xtx.get_mutable_data()));
+    ONEDAL_ASSERT(bk::is_known_usm(q, xtx.get_mutable_data()));
     ONEDAL_ASSERT(xtx.get_dimension(0) == xtx.get_dimension(1));
 
     Float* xtx_ptr = xtx.get_mutable_data();
@@ -52,7 +52,7 @@ sycl::event add_ridge_penalty(sycl::queue& q,
     std::int64_t original_feature_count = feature_count - compute_intercept;
 
     return q.submit([&](sycl::handler& cgh) {
-        const auto range = be::make_range_1d(original_feature_count);
+        const auto range = bk::make_range_1d(original_feature_count);
         cgh.depends_on(deps);
         std::int64_t step = feature_count + 1;
         cgh.parallel_for(range, [=](sycl::id<1> idx) {
diff --git a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp
index 25b08aa7710..04d76fe86b7 100644
--- a/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp
+++ b/cpp/oneapi/dal/algo/linear_regression/backend/gpu/train_kernel_norm_eq_dpc.cpp
@@ -104,17 +104,9 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
         old_x_arr = std::move(x_arr), old_y_arr = std::move(y_arr);
     }
 
-    const be::event_vector solve_deps{ last_xty_event, last_xtx_event };
-
-    double alpha = desc.get_alpha();
-    if (alpha != 0.0) {
-        last_xtx_event =
-            add_ridge_penalty<Float>(queue, xtx, compute_intercept, alpha, { last_xtx_event });
-    }
-
     auto& comm = ctx.get_communicator();
     if (comm.get_rank_count() > 1) {
-        sycl::event::wait_and_throw(solve_deps);
+        sycl::event::wait_and_throw({ last_xty_event, last_xtx_event });
         {
             ONEDAL_PROFILER_TASK(xtx_allreduce);
             auto xtx_arr = dal::array<Float>::wrap(queue, xtx.get_mutable_data(), xtx.get_count());
@@ -127,6 +119,13 @@ static train_result<Task> call_dal_kernel(const context_gpu& ctx,
         }
     }
 
+    double alpha = desc.get_alpha();
+    if (alpha != 0.0) {
+        last_xtx_event =
+            add_ridge_penalty<Float>(queue, xtx, compute_intercept, alpha, { last_xtx_event });
+    }
+    const be::event_vector solve_deps{ last_xty_event, last_xtx_event };
+
     auto nxtx = pr::ndarray<Float, 2>::empty(queue, xtx_shape, alloc);
     auto nxty = pr::ndview<Float, 2>::wrap_mutable(betas_arr, betas_shape);
     auto solve_event =
diff --git a/cpp/oneapi/dal/algo/linear_regression/detail/finalize_train_ops_dpc.cpp b/cpp/oneapi/dal/algo/linear_regression/detail/finalize_train_ops_dpc.cpp
index 3592aeefccb..21a5ce8108d 100644
--- a/cpp/oneapi/dal/algo/linear_regression/detail/finalize_train_ops_dpc.cpp
+++ b/cpp/oneapi/dal/algo/linear_regression/detail/finalize_train_ops_dpc.cpp
@@ -38,7 +38,7 @@ struct finalize_train_ops_dispatcher<Float, Method, Task, dal::detail::data_parallel_policy> {
                                   const partial_train_result<Task>& input) const {
         using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
             KERNEL_SINGLE_NODE_CPU(parameters::train_parameters_cpu<Float, Method, Task>),
-            KERNEL_SINGLE_NODE_GPU(parameters::train_parameters_gpu<Float, Method, Task>)>;
+            KERNEL_UNIVERSAL_SPMD_GPU(parameters::train_parameters_gpu<Float, Method, Task>)>;
         return kernel_dispatcher_t{}(ctx, desc, input);
     }
 
@@ -56,14 +56,16 @@ struct finalize_train_ops_dispatcher<Float, Method, Task, dal::detail::data_parallel_policy> {
                                   const partial_train_result<Task>& input) const {
         using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
             KERNEL_SINGLE_NODE_CPU(backend::finalize_train_kernel_cpu<Float, Method, Task>),
-            KERNEL_SINGLE_NODE_GPU(backend::finalize_train_kernel_gpu<Float, Method, Task>)>;
+            KERNEL_UNIVERSAL_SPMD_GPU(backend::finalize_train_kernel_gpu<Float, Method, Task>)>;
         return kernel_dispatcher_t{}(ctx, desc, params, input);
     }
 };
 
-#define INSTANTIATE(F, M, T) \
-    template struct ONEDAL_EXPORT \
-        finalize_train_ops_dispatcher<F, M, T, dal::detail::data_parallel_policy>;
+#define INSTANTIATE(F, M, T)                                                            \
+    template struct ONEDAL_EXPORT                                                       \
+        finalize_train_ops_dispatcher<F, M, T, dal::detail::data_parallel_policy>;      \
+    template struct ONEDAL_EXPORT                                                       \
+        finalize_train_ops_dispatcher<F, M, T, dal::detail::spmd_data_parallel_policy>;
 
 INSTANTIATE(float, method::norm_eq, task::regression)
 INSTANTIATE(double, method::norm_eq, task::regression)
diff --git a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp
index aedf0165454..fb935174cfe 100644
--- a/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp
+++ b/cpp/oneapi/dal/algo/linear_regression/test/fixture.hpp
@@ -54,6 +54,9 @@ class lr_test : public te::crtp_algo_fixture<TestType, Derived> {
     using test_input_t = infer_input<task_t>;
     using test_result_t = infer_result<task_t>;
 
+    using partial_input_t = partial_train_input<>;
+    using partial_result_t = partial_train_result<>;
+
     te::table_id get_homogen_table_id() const {
         return te::table_id::homogen<float_t>();
     }
diff --git a/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp b/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp
new file mode 100644
index 00000000000..c0f7968adfc
--- /dev/null
+++ b/cpp/oneapi/dal/algo/linear_regression/test/online_spmd.cpp
@@ -0,0 +1,126 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/linear_regression/test/fixture.hpp"
+#include "oneapi/dal/test/engine/tables.hpp"
+#include "oneapi/dal/test/engine/io.hpp"
+
+namespace oneapi::dal::linear_regression::test {
+
+namespace te = dal::test::engine;
+namespace la = te::linalg;
+namespace linear_regression = oneapi::dal::linear_regression;
+
+template <typename TestType>
+class lr_online_spmd_test : public lr_test<TestType, lr_online_spmd_test<TestType>> {
+public:
+    using base_t = lr_test<TestType, lr_online_spmd_test<TestType>>;
+    using float_t = typename base_t::float_t;
+    using input_t = typename base_t::train_input_t;
+    using partial_input_t = typename base_t::partial_input_t;
+    using partial_result_t = typename base_t::partial_result_t;
+    using result_t = typename base_t::train_result_t;
+
+    void set_rank_count(std::int64_t rank_count) {
+        n_rank = rank_count;
+    }
+
+    std::int64_t get_rank_count() {
+        return n_rank;
+    }
+
+    void generate_dimensions() {
+        this->t_count_ = GENERATE(307, 12999);
+        this->s_count_ = GENERATE(10000);
+        this->f_count_ = GENERATE(2, 17);
+        this->r_count_ = GENERATE(2, 15);
+        this->intercept_ = GENERATE(0, 1);
+    }
+
+    template <typename... Args>
+    result_t finalize_train_override(Args&&... args) {
+        return this->finalize_train_via_spmd_threads_and_merge(n_rank, std::forward<Args>(args)...);
+    }
+
+    result_t merge_finalize_train_result_override(const std::vector<result_t>& results) {
+        return results[0];
+    }
+
+    template <typename... Args>
+    std::vector<partial_result_t> split_finalize_train_input_override(std::int64_t split_count,
+                                                                      Args&&... args) {
+        ONEDAL_ASSERT(split_count == n_rank);
+        const std::vector<partial_result_t> input{ std::forward<Args>(args)... };
+
+        return input;
+    }
+
+    void run_and_check_linear_online_spmd(std::int64_t n_rank,
+                                          std::int64_t n_blocks,
+                                          std::int64_t seed = 888,
+                                          double tol = 1e-2) {
+        table x_train, y_train, x_test, y_test;
+        std::tie(x_train, y_train, x_test, y_test) = this->prepare_inputs(seed, tol);
+
+        const auto desc = this->get_descriptor();
+        std::vector<partial_result_t> partial_results;
+        auto input_table_x = base_t::template split_table_by_rows<double>(x_train, n_rank);
+        auto input_table_y = base_t::template split_table_by_rows<double>(y_train, n_rank);
+        for (std::int64_t i = 0; i < n_rank; i++) {
+            partial_result_t partial_result;
+            auto input_table_x_blocks =
+                base_t::template split_table_by_rows<double>(input_table_x[i], n_blocks);
+            auto input_table_y_blocks =
+                base_t::template split_table_by_rows<double>(input_table_y[i], n_blocks);
+            for (std::int64_t j = 0; j < n_blocks; j++) {
+                partial_result = this->partial_train(desc,
+                                                     partial_result,
+                                                     input_table_x_blocks[j],
+                                                     input_table_y_blocks[j]);
+            }
+            partial_results.push_back(partial_result);
+        }
+
+        const auto train_result = this->finalize_train_override(desc, partial_results);
+
+        SECTION("Checking intercept values") {
+            if (desc.get_result_options().test(result_options::intercept))
+                base_t::check_if_close(train_result.get_intercept(), base_t::bias_, tol);
+        }
+
+        SECTION("Checking coefficient values") {
+            if (desc.get_result_options().test(result_options::coefficients))
+                base_t::check_if_close(train_result.get_coefficients(), base_t::beta_, tol);
+        }
+    }
+
+private:
+    std::int64_t n_rank;
+};
+
+TEMPLATE_LIST_TEST_M(lr_online_spmd_test, "lr common flow", "[lr][integration][spmd]", lr_types) {
+    SKIP_IF(this->get_policy().is_cpu());
+    SKIP_IF(this->not_float64_friendly());
+
+    this->generate(777);
+
+    this->set_rank_count(GENERATE(1, 2, 4));
+    std::int64_t n_blocks = GENERATE(1, 3, 10);
+
+    this->run_and_check_linear_online_spmd(this->get_rank_count(), n_blocks);
+}
+
+} // namespace oneapi::dal::linear_regression::test
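Usage sketch (illustrative only, not part of the diff above): the flow this patch enables is the standard oneDAL online pattern — each rank folds its data blocks into a partial result with `partial_train`, and `finalize_train` then solves the normal equations, with the new GPU kernel allreducing the partial XᵗX/Xᵗy across ranks when an SPMD policy is used. The sketch below shows the single-rank, queue-based variant; the data-loading helpers `load_x_block`/`load_y_block`, the header set, and the exact `partial_train`/`finalize_train` signatures are assumptions rather than something taken from this patch.

```cpp
#include <cstdint>

#include <sycl/sycl.hpp>

// Assumed header set: the algorithm header plus the generic train ops headers.
#include "oneapi/dal/algo/linear_regression.hpp"

namespace dal = oneapi::dal;
namespace lr = dal::linear_regression;

// Hypothetical helpers: each rank streams its own slice of X and y as dal::table blocks.
dal::table load_x_block(std::int64_t block); // placeholder
dal::table load_y_block(std::int64_t block); // placeholder

lr::model<> train_online(sycl::queue& q, std::int64_t block_count) {
    const auto desc = lr::descriptor<float>{}.set_result_options(
        lr::result_options::coefficients | lr::result_options::intercept);

    // Online pass: fold each data block into the partial result (accumulates XtX and Xty).
    lr::partial_train_result<> partial;
    for (std::int64_t b = 0; b < block_count; ++b) {
        partial = dal::partial_train(q, desc, partial, load_x_block(b), load_y_block(b));
    }

    // Finalize on this rank. Under an SPMD policy the same call dispatches (via
    // KERNEL_UNIVERSAL_SPMD_GPU above) to the new kernel, which allreduces the
    // partial XtX/Xty across ranks before the normal-equations solve.
    const auto result = dal::finalize_train(q, desc, partial);
    return result.get_model();
}
```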