From 0a155d028895a0cfb894f76b8f6d7606718d4034 Mon Sep 17 00:00:00 2001 From: "Yu-Hsiang M. Tsai" Date: Fri, 25 Oct 2024 13:31:14 +0200 Subject: [PATCH] factorization sycl type --- dpcpp/factorization/factorization_kernels.dp.cpp | 2 +- dpcpp/factorization/par_ic_kernels.dp.cpp | 14 +++++++------- dpcpp/factorization/par_ict_kernels.dp.cpp | 13 +++++++------ .../par_ilut_approx_filter_kernel.dp.cpp | 4 ++-- dpcpp/factorization/par_ilut_filter_kernel.dp.cpp | 2 +- dpcpp/factorization/par_ilut_select_kernel.dp.cpp | 2 +- dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp | 8 ++++---- dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp | 9 +++++---- 8 files changed, 28 insertions(+), 26 deletions(-) diff --git a/dpcpp/factorization/factorization_kernels.dp.cpp b/dpcpp/factorization/factorization_kernels.dp.cpp index 1d9912b4f12..04bd49c2c9a 100644 --- a/dpcpp/factorization/factorization_kernels.dp.cpp +++ b/dpcpp/factorization/factorization_kernels.dp.cpp @@ -496,7 +496,7 @@ void add_diagonal_elements(std::shared_ptr exec, array needs_change_device{exec, 1}; needs_change_device = needs_change_host; - auto dpcpp_old_values = mtx->get_const_values(); + auto dpcpp_old_values = as_device_type(mtx->get_const_values()); auto dpcpp_old_col_idxs = mtx->get_const_col_idxs(); auto dpcpp_old_row_ptrs = mtx->get_row_ptrs(); auto dpcpp_row_ptrs_add = row_ptrs_addition.get_data(); diff --git a/dpcpp/factorization/par_ic_kernels.dp.cpp b/dpcpp/factorization/par_ic_kernels.dp.cpp index 5428460fac5..0ae155a4f82 100644 --- a/dpcpp/factorization/par_ic_kernels.dp.cpp +++ b/dpcpp/factorization/par_ic_kernels.dp.cpp @@ -125,7 +125,7 @@ void init_factor(std::shared_ptr exec, auto num_rows = l->get_size()[0]; auto num_blocks = ceildiv(num_rows, default_block_size); auto l_row_ptrs = l->get_const_row_ptrs(); - auto l_vals = l->get_values(); + auto l_vals = as_device_type(l->get_values()); kernel::ic_init(num_blocks, default_block_size, 0, exec->get_queue(), l_row_ptrs, l_vals, num_rows); } @@ -143,12 +143,12 @@ void compute_factor(std::shared_ptr exec, auto nnz = l->get_num_stored_elements(); auto num_blocks = ceildiv(nnz, default_block_size); for (size_type i = 0; i < iterations; ++i) { - kernel::ic_sweep(num_blocks, default_block_size, 0, exec->get_queue(), - a_lower->get_const_row_idxs(), - a_lower->get_const_col_idxs(), - a_lower->get_const_values(), l->get_const_row_ptrs(), - l->get_const_col_idxs(), l->get_values(), - static_cast(l->get_num_stored_elements())); + kernel::ic_sweep( + num_blocks, default_block_size, 0, exec->get_queue(), + a_lower->get_const_row_idxs(), a_lower->get_const_col_idxs(), + a_lower->get_const_values(), l->get_const_row_ptrs(), + l->get_const_col_idxs(), as_device_type(l->get_values()), + static_cast(l->get_num_stored_elements())); } } diff --git a/dpcpp/factorization/par_ict_kernels.dp.cpp b/dpcpp/factorization/par_ict_kernels.dp.cpp index fb99b662dec..4f11bf7b7b1 100644 --- a/dpcpp/factorization/par_ict_kernels.dp.cpp +++ b/dpcpp/factorization/par_ict_kernels.dp.cpp @@ -402,13 +402,13 @@ void add_candidates(syn::value_list, matrix::CsrBuilder l_new_builder(l_new); auto llh_row_ptrs = llh->get_const_row_ptrs(); auto llh_col_idxs = llh->get_const_col_idxs(); - auto llh_vals = llh->get_const_values(); + auto llh_vals = as_device_type(llh->get_const_values()); auto a_row_ptrs = a->get_const_row_ptrs(); auto a_col_idxs = a->get_const_col_idxs(); - auto a_vals = a->get_const_values(); + auto a_vals = as_device_type(a->get_const_values()); auto l_row_ptrs = l->get_const_row_ptrs(); auto l_col_idxs = l->get_const_col_idxs(); - auto l_vals = l->get_const_values(); + auto l_vals = as_device_type(l->get_const_values()); auto l_new_row_ptrs = l_new->get_row_ptrs(); // count non-zeros per row kernel::ict_tri_spgeam_nnz( @@ -450,9 +450,10 @@ void compute_factor(syn::value_list, auto num_blocks = ceildiv(total_nnz, block_size); kernel::ict_sweep( num_blocks, default_block_size, 0, exec->get_queue(), - a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(), - l->get_const_row_ptrs(), l_coo->get_const_row_idxs(), - l->get_const_col_idxs(), l->get_values(), + a->get_const_row_ptrs(), a->get_const_col_idxs(), + as_device_type(a->get_const_values()), l->get_const_row_ptrs(), + l_coo->get_const_row_idxs(), l->get_const_col_idxs(), + as_device_type(l->get_values()), static_cast(l->get_num_stored_elements())); } diff --git a/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp b/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp index 776ffba3fb1..c808f7e0ae8 100644 --- a/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp @@ -58,7 +58,7 @@ void threshold_filter_approx(syn::value_list, matrix::Csr* m_out, matrix::Coo* m_out_coo) { - auto values = m->get_const_values(); + auto values = as_device_type(m->get_const_values()); IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; @@ -102,7 +102,7 @@ void threshold_filter_approx(syn::value_list, // filter the elements auto old_row_ptrs = m->get_const_row_ptrs(); auto old_col_idxs = m->get_const_col_idxs(); - auto old_vals = m->get_const_values(); + auto old_vals = as_device_type(m->get_const_values()); // compute nnz for each row auto num_rows = static_cast(m->get_size()[0]); auto block_size = default_block_size / subgroup_size; diff --git a/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp b/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp index 5ce9df8a0a9..732a8dc6135 100644 --- a/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_filter_kernel.dp.cpp @@ -57,7 +57,7 @@ void threshold_filter(syn::value_list, { auto old_row_ptrs = a->get_const_row_ptrs(); auto old_col_idxs = a->get_const_col_idxs(); - auto old_vals = a->get_const_values(); + auto old_vals = as_device_type(a->get_const_values()); // compute nnz for each row auto num_rows = static_cast(a->get_size()[0]); auto block_size = default_block_size / subgroup_size; diff --git a/dpcpp/factorization/par_ilut_select_kernel.dp.cpp b/dpcpp/factorization/par_ilut_select_kernel.dp.cpp index 589f8267f21..43c13fc730b 100644 --- a/dpcpp/factorization/par_ilut_select_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_select_kernel.dp.cpp @@ -61,7 +61,7 @@ void threshold_select(std::shared_ptr exec, array>& tmp2, remove_complex& threshold) { - auto values = m->get_const_values(); + auto values = as_device_type(m->get_const_values()); IndexType size = m->get_num_stored_elements(); using AbsType = remove_complex; constexpr auto bucket_count = kernel::searchtree_width; diff --git a/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp b/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp index 246228763bf..f9643fbe66b 100644 --- a/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp @@ -356,16 +356,16 @@ void add_candidates(syn::value_list, matrix::CsrBuilder u_new_builder(u_new); auto lu_row_ptrs = lu->get_const_row_ptrs(); auto lu_col_idxs = lu->get_const_col_idxs(); - auto lu_vals = lu->get_const_values(); + auto lu_vals = as_device_type(lu->get_const_values()); auto a_row_ptrs = a->get_const_row_ptrs(); auto a_col_idxs = a->get_const_col_idxs(); - auto a_vals = a->get_const_values(); + auto a_vals = as_device_type(a->get_const_values()); auto l_row_ptrs = l->get_const_row_ptrs(); auto l_col_idxs = l->get_const_col_idxs(); - auto l_vals = l->get_const_values(); + auto l_vals = as_device_type(l->get_const_values()); auto u_row_ptrs = u->get_const_row_ptrs(); auto u_col_idxs = u->get_const_col_idxs(); - auto u_vals = u->get_const_values(); + auto u_vals = as_device_type(u->get_const_values()); auto l_new_row_ptrs = l_new->get_row_ptrs(); auto u_new_row_ptrs = u_new->get_row_ptrs(); // count non-zeros per row diff --git a/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp b/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp index 601e5dc12d3..4644bb155d2 100644 --- a/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp +++ b/dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp @@ -176,12 +176,13 @@ void compute_l_u_factors(syn::value_list, auto num_blocks = ceildiv(total_nnz, block_size); kernel::sweep( num_blocks, default_block_size, 0, exec->get_queue(), - a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(), - l->get_const_row_ptrs(), l_coo->get_const_row_idxs(), - l->get_const_col_idxs(), l->get_values(), + a->get_const_row_ptrs(), a->get_const_col_idxs(), + as_device_type(a->get_const_values()), l->get_const_row_ptrs(), + l_coo->get_const_row_idxs(), l->get_const_col_idxs(), + as_device_type(l->get_values()), static_cast(l->get_num_stored_elements()), u_coo->get_const_row_idxs(), u_coo->get_const_col_idxs(), - u->get_values(), u_csc->get_const_row_ptrs(), + as_device_type(u->get_values()), u_csc->get_const_row_ptrs(), u_csc->get_const_col_idxs(), u_csc->get_values(), static_cast(u->get_num_stored_elements())); }