Skip to content

Commit

Permalink
solver/preconditioner/stop sycl type
Browse files: browse the repository at this point in the history
  • Loading branch information
yhmtsai committed Oct 25, 2024
1 parent 0a155d0 commit e9b24cb
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 71 deletions.
32 changes: 18 additions & 14 deletions dpcpp/preconditioner/isai_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -626,16 +626,20 @@ void generate_tri_inverse(std::shared_ptr<const DefaultExecutor> exec,
kernel::generate_l_inverse<subwarp_size, subwarps_per_block>(
grid, block, 0, exec->get_queue(),
static_cast<IndexType>(num_rows), input->get_const_row_ptrs(),
input->get_const_col_idxs(), input->get_const_values(),
input->get_const_col_idxs(),
as_device_type(input->get_const_values()),
inverse->get_row_ptrs(), inverse->get_col_idxs(),
inverse->get_values(), excess_rhs_ptrs, excess_nz_ptrs);
as_device_type(inverse->get_values()), excess_rhs_ptrs,
excess_nz_ptrs);
} else {
kernel::generate_u_inverse<subwarp_size, subwarps_per_block>(
grid, block, 0, exec->get_queue(),
static_cast<IndexType>(num_rows), input->get_const_row_ptrs(),
input->get_const_col_idxs(), input->get_const_values(),
input->get_const_col_idxs(),
as_device_type(input->get_const_values()),
inverse->get_row_ptrs(), inverse->get_col_idxs(),
inverse->get_values(), excess_rhs_ptrs, excess_nz_ptrs);
as_device_type(inverse->get_values()), excess_rhs_ptrs,
excess_nz_ptrs);
}
}
components::prefix_sum_nonnegative(exec, excess_rhs_ptrs, num_rows + 1);
Expand All @@ -661,9 +665,9 @@ void generate_general_inverse(std::shared_ptr<const DefaultExecutor> exec,
kernel::generate_general_inverse<subwarp_size, subwarps_per_block>(
grid, block, 0, exec->get_queue(), static_cast<IndexType>(num_rows),
input->get_const_row_ptrs(), input->get_const_col_idxs(),
input->get_const_values(), inverse->get_row_ptrs(),
inverse->get_col_idxs(), inverse->get_values(), excess_rhs_ptrs,
excess_nz_ptrs, spd);
as_device_type(input->get_const_values()), inverse->get_row_ptrs(),
inverse->get_col_idxs(), as_device_type(inverse->get_values()),
excess_rhs_ptrs, excess_nz_ptrs, spd);
}
components::prefix_sum_nonnegative(exec, excess_rhs_ptrs, num_rows + 1);
components::prefix_sum_nonnegative(exec, excess_nz_ptrs, num_rows + 1);
Expand Down Expand Up @@ -691,11 +695,11 @@ void generate_excess_system(std::shared_ptr<const DefaultExecutor> exec,
kernel::generate_excess_system<subwarp_size>(
grid, block, 0, exec->get_queue(), static_cast<IndexType>(num_rows),
input->get_const_row_ptrs(), input->get_const_col_idxs(),
input->get_const_values(), inverse->get_const_row_ptrs(),
inverse->get_const_col_idxs(), excess_rhs_ptrs, excess_nz_ptrs,
excess_system->get_row_ptrs(), excess_system->get_col_idxs(),
excess_system->get_values(), excess_rhs->get_values(), e_start,
e_end);
as_device_type(input->get_const_values()),
inverse->get_const_row_ptrs(), inverse->get_const_col_idxs(),
excess_rhs_ptrs, excess_nz_ptrs, excess_system->get_row_ptrs(),
excess_system->get_col_idxs(), excess_system->get_values(),
excess_rhs->get_values(), e_start, e_end);
}
}

Expand Down Expand Up @@ -737,8 +741,8 @@ void scatter_excess_solution(std::shared_ptr<const DefaultExecutor> exec,
kernel::copy_excess_solution<subwarp_size>(
grid, block, 0, exec->get_queue(), static_cast<IndexType>(num_rows),
inverse->get_const_row_ptrs(), excess_rhs_ptrs,
excess_solution->get_const_values(), inverse->get_values(), e_start,
e_end);
excess_solution->get_const_values(),
as_device_type(inverse->get_values()), e_start, e_end);
}
}

Expand Down
7 changes: 4 additions & 3 deletions dpcpp/preconditioner/jacobi_advanced_apply_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ void apply(std::shared_ptr<const DpcppExecutor> exec, size_type num_blocks,
syn::value_list<int, config::min_warps_per_block>(),
syn::type_list<>(), exec, num_blocks,
block_precisions.get_const_data(), block_pointers.get_const_data(),
blocks.get_const_data(), storage_scheme, alpha->get_const_values(),
b->get_const_values() + col, b->get_stride(), x->get_values() + col,
x->get_stride());
blocks.get_const_data(), storage_scheme,
as_device_type(alpha->get_const_values()),
as_device_type(b->get_const_values()) + col, b->get_stride(),
as_device_type(x->get_values()) + col, x->get_stride());
}
}

Expand Down
9 changes: 5 additions & 4 deletions dpcpp/preconditioner/jacobi_generate_instantiate.inc.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,14 +365,15 @@ void generate(syn::value_list<int, max_block_size>,
warps_per_block>(
grid_size, block_size, 0, exec->get_queue(), mtx->get_size()[0],
mtx->get_const_row_ptrs(), mtx->get_const_col_idxs(),
mtx->get_const_values(), accuracy, block_data, storage_scheme,
conditioning, block_precisions, block_ptrs, num_blocks);
as_device_type(mtx->get_const_values()), accuracy, block_data,
storage_scheme, conditioning, block_precisions, block_ptrs,
num_blocks);
} else {
kernel::generate<max_block_size, subwarp_size, warps_per_block>(
grid_size, block_size, 0, exec->get_queue(), mtx->get_size()[0],
mtx->get_const_row_ptrs(), mtx->get_const_col_idxs(),
mtx->get_const_values(), block_data, storage_scheme, block_ptrs,
num_blocks);
as_device_type(mtx->get_const_values()), block_data, storage_scheme,
block_ptrs, num_blocks);
}
}

Expand Down
4 changes: 2 additions & 2 deletions dpcpp/preconditioner/jacobi_simple_apply_kernel.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ void simple_apply(
syn::type_list<>(), exec, num_blocks,
block_precisions.get_const_data(), block_pointers.get_const_data(),
blocks.get_const_data(), storage_scheme,
b->get_const_values() + col, b->get_stride(), x->get_values() + col,
x->get_stride());
as_device_type(b->get_const_values()) + col, b->get_stride(),
as_device_type(x->get_values()) + col, x->get_stride());
}
}

Expand Down
24 changes: 13 additions & 11 deletions dpcpp/solver/cb_gmres_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,11 @@ void initialize(std::shared_ptr<const DpcppExecutor> exec,

initialize_kernel<block_size>(
grid_dim, block_dim, 0, exec->get_queue(), b->get_size()[0],
b->get_size()[1], krylov_dim, b->get_const_values(), b->get_stride(),
residual->get_values(), residual->get_stride(),
givens_sin->get_values(), givens_sin->get_stride(),
givens_cos->get_values(), givens_cos->get_stride(),
stop_status->get_data());
b->get_size()[1], krylov_dim, as_device_type(b->get_const_values()),
b->get_stride(), as_device_type(residual->get_values()),
residual->get_stride(), givens_sin->get_values(),
givens_sin->get_stride(), givens_cos->get_values(),
givens_cos->get_stride(), stop_status->get_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_CB_GMRES_INITIALIZE_KERNEL);
Expand Down Expand Up @@ -990,7 +990,8 @@ void restart(std::shared_ptr<const DpcppExecutor> exec,
const dim3 block_size_nrm(default_dot_dim, default_dot_dim);
multinorminf_without_stop_kernel(
grid_size_nrm, block_size_nrm, 0, exec->get_queue(), num_rows,
num_rhs, residual->get_const_values(), residual->get_stride(),
num_rhs, as_device_type(residual->get_const_values()),
residual->get_stride(),
arnoldi_norm->get_values() + 2 * stride_arnoldi, 0);
}

Expand All @@ -1009,7 +1010,7 @@ void restart(std::shared_ptr<const DpcppExecutor> exec,
1, 1);
restart_2_kernel<block_size>(
grid_dim_2, block_dim, 0, exec->get_queue(), residual->get_size()[0],
residual->get_size()[1], residual->get_const_values(),
residual->get_size()[1], as_device_type(residual->get_const_values()),
residual->get_stride(), residual_norm->get_const_values(),
residual_norm_collection->get_values(), krylov_bases,
next_krylov_basis->get_values(), next_krylov_basis->get_stride(),
Expand Down Expand Up @@ -1255,9 +1256,10 @@ void solve_upper_triangular(
solve_upper_triangular_kernel<block_size>(
grid_dim, block_dim, 0, exec->get_queue(), hessenberg->get_size()[1],
num_rhs, residual_norm_collection->get_const_values(),
residual_norm_collection->get_stride(), hessenberg->get_const_values(),
hessenberg->get_stride(), y->get_values(), y->get_stride(),
final_iter_nums->get_const_data());
residual_norm_collection->get_stride(),
as_device_type(hessenberg->get_const_values()),
hessenberg->get_stride(), as_device_type(y->get_values()),
y->get_stride(), final_iter_nums->get_const_data());
}


Expand All @@ -1283,7 +1285,7 @@ void calculate_qy(std::shared_ptr<const DpcppExecutor> exec,

calculate_Qy_kernel<block_size>(
grid_dim, block_dim, 0, exec->get_queue(), num_rows, num_cols,
krylov_bases, y->get_const_values(), y->get_stride(),
krylov_bases, as_device_type(y->get_const_values()), y->get_stride(),
before_preconditioner->get_values(), stride_before_preconditioner,
final_iter_nums->get_const_data());
// Calculate qy
Expand Down
82 changes: 47 additions & 35 deletions dpcpp/solver/idr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ void initialize_m(std::shared_ptr<const DpcppExecutor> exec,

const auto grid_dim = ceildiv(m_stride * subspace_dim, default_block_size);
initialize_m_kernel(grid_dim, default_block_size, 0, exec->get_queue(),
subspace_dim, nrhs, m->get_values(), m_stride,
stop_status->get_data());
subspace_dim, nrhs, as_device_type(m->get_values()),
m_stride, stop_status->get_data());
}


Expand Down Expand Up @@ -638,8 +638,9 @@ void solve_lower_triangular(std::shared_ptr<const DpcppExecutor> exec,
const auto grid_dim = ceildiv(nrhs, default_block_size);
solve_lower_triangular_kernel(
grid_dim, default_block_size, 0, exec->get_queue(), subspace_dim, nrhs,
m->get_const_values(), m->get_stride(), f->get_const_values(),
f->get_stride(), c->get_values(), c->get_stride(),
as_device_type(m->get_const_values()), m->get_stride(),
as_device_type(f->get_const_values()), f->get_stride(),
as_device_type(c->get_values()), c->get_stride(),
stop_status->get_const_data());
}

Expand All @@ -662,30 +663,34 @@ void update_g_and_u(std::shared_ptr<const DpcppExecutor> exec,
const dim3 block_dim(default_dot_dim, default_dot_dim);

for (size_type i = 0; i < k; i++) {
const auto p_i = p->get_const_values() + i * p_stride;
const auto p_i = as_device_type(p->get_const_values()) + i * p_stride;
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, alpha->get_values(), nrhs,
zero<ValueType>());
components::fill_array(exec, as_device_type(alpha->get_values()),
nrhs, zero<ValueType>());
multidot_kernel(grid_dim, block_dim, 0, exec->get_queue(), size,
nrhs, p_i, g_k->get_values(), g_k->get_stride(),
alpha->get_values(), stop_status->get_const_data());
as_device_type(alpha->get_values()),
stop_status->get_const_data());
} else {
onemkl::dot(*exec->get_queue(), size, p_i, 1, g_k->get_values(),
g_k->get_stride(), alpha->get_values());
g_k->get_stride(), as_device_type(alpha->get_values()));
}
update_g_k_and_u_kernel<default_block_size>(
ceildiv(size * g_k->get_stride(), default_block_size),
default_block_size, 0, exec->get_queue(), k, i, size, nrhs,
alpha->get_const_values(), m->get_const_values(), m->get_stride(),
g->get_const_values(), g->get_stride(), g_k->get_values(),
g_k->get_stride(), u->get_values(), u->get_stride(),
as_device_type(alpha->get_const_values()),
as_device_type(m->get_const_values()), m->get_stride(),
as_device_type(g->get_const_values()), g->get_stride(),
g_k->get_values(), g_k->get_stride(),
as_device_type(u->get_values()), u->get_stride(),
stop_status->get_const_data());
}
update_g_kernel<default_block_size>(
ceildiv(size * g_k->get_stride(), default_block_size),
default_block_size, 0, exec->get_queue(), k, size, nrhs,
g_k->get_const_values(), g_k->get_stride(), g->get_values(),
g->get_stride(), stop_status->get_const_data());
g_k->get_const_values(), g_k->get_stride(),
as_device_type(g->get_values()), g->get_stride(),
stop_status->get_const_data());
}


Expand All @@ -705,8 +710,8 @@ void update_m(std::shared_ptr<const DpcppExecutor> exec, const size_type nrhs,
const dim3 block_dim(default_dot_dim, default_dot_dim);

for (size_type i = k; i < subspace_dim; i++) {
const auto p_i = p->get_const_values() + i * p_stride;
auto m_i = m->get_values() + i * m_stride + k * nrhs;
const auto p_i = as_device_type(p->get_const_values()) + i * p_stride;
auto m_i = as_device_type(m->get_values()) + i * m_stride + k * nrhs;
if (nrhs > 1 || is_complex<ValueType>()) {
components::fill_array(exec, m_i, nrhs, zero<ValueType>());
multidot_kernel(grid_dim, block_dim, 0, exec->get_queue(), size,
Expand Down Expand Up @@ -735,15 +740,18 @@ void update_x_r_and_f(std::shared_ptr<const DpcppExecutor> exec,
const auto subspace_dim = m->get_size()[0];

const auto grid_dim = ceildiv(size * x->get_stride(), default_block_size);
update_x_r_and_f_kernel(grid_dim, default_block_size, 0, exec->get_queue(),
k, size, subspace_dim, nrhs, m->get_const_values(),
m->get_stride(), g->get_const_values(),
g->get_stride(), u->get_const_values(),
u->get_stride(), f->get_values(), f->get_stride(),
r->get_values(), r->get_stride(), x->get_values(),
x->get_stride(), stop_status->get_const_data());
components::fill_array(exec, f->get_values() + k * f->get_stride(), nrhs,
zero<ValueType>());
update_x_r_and_f_kernel(
grid_dim, default_block_size, 0, exec->get_queue(), k, size,
subspace_dim, nrhs, as_device_type(m->get_const_values()),
m->get_stride(), as_device_type(g->get_const_values()), g->get_stride(),
as_device_type(u->get_const_values()), u->get_stride(),
as_device_type(f->get_values()), f->get_stride(),
as_device_type(r->get_values()), r->get_stride(),
as_device_type(x->get_values()), x->get_stride(),
stop_status->get_const_data());
components::fill_array(
exec, as_device_type(f->get_values()) + k * f->get_stride(), nrhs,
zero<ValueType>());
}


Expand Down Expand Up @@ -780,11 +788,12 @@ void step_1(std::shared_ptr<const DpcppExecutor> exec, const size_type nrhs,

const auto grid_dim = ceildiv(nrhs * num_rows, default_block_size);
step_1_kernel(grid_dim, default_block_size, 0, exec->get_queue(), k,
num_rows, subspace_dim, nrhs, residual->get_const_values(),
residual->get_stride(), c->get_const_values(),
c->get_stride(), g->get_const_values(), g->get_stride(),
v->get_values(), v->get_stride(),
stop_status->get_const_data());
num_rows, subspace_dim, nrhs,
as_device_type(residual->get_const_values()),
residual->get_stride(), as_device_type(c->get_const_values()),
c->get_stride(), as_device_type(g->get_const_values()),
g->get_stride(), as_device_type(v->get_values()),
v->get_stride(), stop_status->get_const_data());
}

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_IDR_STEP_1_KERNEL);
Expand All @@ -805,10 +814,12 @@ void step_2(std::shared_ptr<const DpcppExecutor> exec, const size_type nrhs,

const auto grid_dim = ceildiv(nrhs * num_rows, default_block_size);
step_2_kernel(grid_dim, default_block_size, 0, exec->get_queue(), k,
num_rows, subspace_dim, nrhs, omega->get_const_values(),
num_rows, subspace_dim, nrhs,
as_device_type(omega->get_const_values()),
preconditioned_vector->get_const_values(),
preconditioned_vector->get_stride(), c->get_const_values(),
c->get_stride(), u->get_values(), u->get_stride(),
preconditioned_vector->get_stride(),
as_device_type(c->get_const_values()), c->get_stride(),
as_device_type(u->get_values()), u->get_stride(),
stop_status->get_const_data());
}

Expand Down Expand Up @@ -841,8 +852,9 @@ void compute_omega(
{
const auto grid_dim = ceildiv(nrhs, config::warp_size);
compute_omega_kernel(grid_dim, config::warp_size, 0, exec->get_queue(),
nrhs, kappa, tht->get_const_values(),
residual_norm->get_const_values(), omega->get_values(),
nrhs, kappa, as_device_type(tht->get_const_values()),
residual_norm->get_const_values(),
as_device_type(omega->get_values()),
stop_status->get_const_data());
}

Expand Down
4 changes: 2 additions & 2 deletions dpcpp/stop/residual_norm_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void residual_norm(std::shared_ptr<const DpcppExecutor> exec,
});

auto orig_tau_val = orig_tau->get_const_values();
auto tau_val = tau->get_const_values();
auto tau_val = as_device_type(tau->get_const_values());
auto stop_status_val = stop_status->get_data();
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
Expand Down Expand Up @@ -102,7 +102,7 @@ void implicit_residual_norm(
});

auto orig_tau_val = orig_tau->get_const_values();
auto tau_val = tau->get_const_values();
auto tau_val = as_device_type(tau->get_const_values());
auto stop_status_val = stop_status->get_data();
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(
Expand Down

0 comments on commit e9b24cb

Please sign in to comment.