Skip to content

Commit

Permalink
factorization sycl type
Browse files (browse the repository at this point in the history)
  • Loading branch information
yhmtsai committed Oct 25, 2024
1 parent 1f870ba commit 0a155d0
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dpcpp/factorization/factorization_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ void add_diagonal_elements(std::shared_ptr<const DpcppExecutor> exec,
array<bool> needs_change_device{exec, 1};
needs_change_device = needs_change_host;

- auto dpcpp_old_values = mtx->get_const_values();
+ auto dpcpp_old_values = as_device_type(mtx->get_const_values());
auto dpcpp_old_col_idxs = mtx->get_const_col_idxs();
auto dpcpp_old_row_ptrs = mtx->get_row_ptrs();
auto dpcpp_row_ptrs_add = row_ptrs_addition.get_data();
Expand Down
14 changes: 7 additions & 7 deletions dpcpp/factorization/par_ic_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ void init_factor(std::shared_ptr<const DefaultExecutor> exec,
auto num_rows = l->get_size()[0];
auto num_blocks = ceildiv(num_rows, default_block_size);
auto l_row_ptrs = l->get_const_row_ptrs();
- auto l_vals = l->get_values();
+ auto l_vals = as_device_type(l->get_values());
kernel::ic_init(num_blocks, default_block_size, 0, exec->get_queue(),
l_row_ptrs, l_vals, num_rows);
}
Expand All @@ -143,12 +143,12 @@ void compute_factor(std::shared_ptr<const DefaultExecutor> exec,
auto nnz = l->get_num_stored_elements();
auto num_blocks = ceildiv(nnz, default_block_size);
for (size_type i = 0; i < iterations; ++i) {
- kernel::ic_sweep(num_blocks, default_block_size, 0, exec->get_queue(),
- a_lower->get_const_row_idxs(),
- a_lower->get_const_col_idxs(),
- a_lower->get_const_values(), l->get_const_row_ptrs(),
- l->get_const_col_idxs(), l->get_values(),
- static_cast<IndexType>(l->get_num_stored_elements()));
+ kernel::ic_sweep(
+ num_blocks, default_block_size, 0, exec->get_queue(),
+ a_lower->get_const_row_idxs(), a_lower->get_const_col_idxs(),
+ a_lower->get_const_values(), l->get_const_row_ptrs(),
+ l->get_const_col_idxs(), as_device_type(l->get_values()),
+ static_cast<IndexType>(l->get_num_stored_elements()));
}
}

Expand Down
13 changes: 7 additions & 6 deletions dpcpp/factorization/par_ict_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,13 @@ void add_candidates(syn::value_list<int, subgroup_size>,
matrix::CsrBuilder<ValueType, IndexType> l_new_builder(l_new);
auto llh_row_ptrs = llh->get_const_row_ptrs();
auto llh_col_idxs = llh->get_const_col_idxs();
- auto llh_vals = llh->get_const_values();
+ auto llh_vals = as_device_type(llh->get_const_values());
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
- auto a_vals = a->get_const_values();
+ auto a_vals = as_device_type(a->get_const_values());
auto l_row_ptrs = l->get_const_row_ptrs();
auto l_col_idxs = l->get_const_col_idxs();
- auto l_vals = l->get_const_values();
+ auto l_vals = as_device_type(l->get_const_values());
auto l_new_row_ptrs = l_new->get_row_ptrs();
// count non-zeros per row
kernel::ict_tri_spgeam_nnz<subgroup_size>(
Expand Down Expand Up @@ -450,9 +450,10 @@ void compute_factor(syn::value_list<int, subgroup_size>,
auto num_blocks = ceildiv(total_nnz, block_size);
kernel::ict_sweep<subgroup_size>(
num_blocks, default_block_size, 0, exec->get_queue(),
- a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(),
- l->get_const_row_ptrs(), l_coo->get_const_row_idxs(),
- l->get_const_col_idxs(), l->get_values(),
+ a->get_const_row_ptrs(), a->get_const_col_idxs(),
+ as_device_type(a->get_const_values()), l->get_const_row_ptrs(),
+ l_coo->get_const_row_idxs(), l->get_const_col_idxs(),
+ as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()));
}

Expand Down
4 changes: 2 additions & 2 deletions dpcpp/factorization/par_ilut_approx_filter_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void threshold_filter_approx(syn::value_list<int, subgroup_size>,
matrix::Csr<ValueType, IndexType>* m_out,
matrix::Coo<ValueType, IndexType>* m_out_coo)
{
- auto values = m->get_const_values();
+ auto values = as_device_type(m->get_const_values());
IndexType size = m->get_num_stored_elements();
using AbsType = remove_complex<ValueType>;
constexpr auto bucket_count = kernel::searchtree_width;
Expand Down Expand Up @@ -102,7 +102,7 @@ void threshold_filter_approx(syn::value_list<int, subgroup_size>,
// filter the elements
auto old_row_ptrs = m->get_const_row_ptrs();
auto old_col_idxs = m->get_const_col_idxs();
- auto old_vals = m->get_const_values();
+ auto old_vals = as_device_type(m->get_const_values());
// compute nnz for each row
auto num_rows = static_cast<IndexType>(m->get_size()[0]);
auto block_size = default_block_size / subgroup_size;
Expand Down
2 changes: 1 addition & 1 deletion dpcpp/factorization/par_ilut_filter_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void threshold_filter(syn::value_list<int, subgroup_size>,
{
auto old_row_ptrs = a->get_const_row_ptrs();
auto old_col_idxs = a->get_const_col_idxs();
- auto old_vals = a->get_const_values();
+ auto old_vals = as_device_type(a->get_const_values());
// compute nnz for each row
auto num_rows = static_cast<IndexType>(a->get_size()[0]);
auto block_size = default_block_size / subgroup_size;
Expand Down
2 changes: 1 addition & 1 deletion dpcpp/factorization/par_ilut_select_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void threshold_select(std::shared_ptr<const DefaultExecutor> exec,
array<remove_complex<ValueType>>& tmp2,
remove_complex<ValueType>& threshold)
{
- auto values = m->get_const_values();
+ auto values = as_device_type(m->get_const_values());
IndexType size = m->get_num_stored_elements();
using AbsType = remove_complex<ValueType>;
constexpr auto bucket_count = kernel::searchtree_width;
Expand Down
8 changes: 4 additions & 4 deletions dpcpp/factorization/par_ilut_spgeam_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,16 +356,16 @@ void add_candidates(syn::value_list<int, subgroup_size>,
matrix::CsrBuilder<ValueType, IndexType> u_new_builder(u_new);
auto lu_row_ptrs = lu->get_const_row_ptrs();
auto lu_col_idxs = lu->get_const_col_idxs();
- auto lu_vals = lu->get_const_values();
+ auto lu_vals = as_device_type(lu->get_const_values());
auto a_row_ptrs = a->get_const_row_ptrs();
auto a_col_idxs = a->get_const_col_idxs();
- auto a_vals = a->get_const_values();
+ auto a_vals = as_device_type(a->get_const_values());
auto l_row_ptrs = l->get_const_row_ptrs();
auto l_col_idxs = l->get_const_col_idxs();
- auto l_vals = l->get_const_values();
+ auto l_vals = as_device_type(l->get_const_values());
auto u_row_ptrs = u->get_const_row_ptrs();
auto u_col_idxs = u->get_const_col_idxs();
- auto u_vals = u->get_const_values();
+ auto u_vals = as_device_type(u->get_const_values());
auto l_new_row_ptrs = l_new->get_row_ptrs();
auto u_new_row_ptrs = u_new->get_row_ptrs();
// count non-zeros per row
Expand Down
9 changes: 5 additions & 4 deletions dpcpp/factorization/par_ilut_sweep_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,13 @@ void compute_l_u_factors(syn::value_list<int, subgroup_size>,
auto num_blocks = ceildiv(total_nnz, block_size);
kernel::sweep<subgroup_size>(
num_blocks, default_block_size, 0, exec->get_queue(),
- a->get_const_row_ptrs(), a->get_const_col_idxs(), a->get_const_values(),
- l->get_const_row_ptrs(), l_coo->get_const_row_idxs(),
- l->get_const_col_idxs(), l->get_values(),
+ a->get_const_row_ptrs(), a->get_const_col_idxs(),
+ as_device_type(a->get_const_values()), l->get_const_row_ptrs(),
+ l_coo->get_const_row_idxs(), l->get_const_col_idxs(),
+ as_device_type(l->get_values()),
static_cast<IndexType>(l->get_num_stored_elements()),
u_coo->get_const_row_idxs(), u_coo->get_const_col_idxs(),
- u->get_values(), u_csc->get_const_row_ptrs(),
+ as_device_type(u->get_values()), u_csc->get_const_row_ptrs(),
u_csc->get_const_col_idxs(), u_csc->get_values(),
static_cast<IndexType>(u->get_num_stored_elements()));
}
Expand Down

0 comments on commit 0a155d0

Please sign in to comment.