Use PIMPL in VectorHandler class.
pelesh committed Oct 19, 2023
1 parent 678a324 commit cec1aa7
Showing 10 changed files with 680 additions and 156 deletions.
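The core of the change: VectorHandler no longer carries the cuBLAS calls and raw loops itself; it holds pointers to backend implementations (VectorHandlerCpu and VectorHandlerCuda, both behind the VectorHandlerImpl interface) and forwards each operation to the backend selected by the memspace argument. The sketch below is a minimal, self-contained illustration of that delegation pattern, written for this summary only: the names Impl, CpuImpl, ReferenceImpl, and Handler are invented for the example and are not Re::Solve code, and the real backends carry the cuBLAS calls and loops that this diff removes from VectorHandler.

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Abstract backend, analogous in spirit to VectorHandlerImpl.
struct Impl {
  virtual ~Impl() = default;
  virtual double dot(const std::vector<double>& x, const std::vector<double>& y) = 0;
};

// CPU backend: a plain loop (the real VectorHandlerCpu also provides axpy, scal, gemv, ...).
struct CpuImpl : Impl {
  double dot(const std::vector<double>& x, const std::vector<double>& y) override {
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) sum += x[i] * y[i];
    return sum;
  }
};

// Stand-in for the device backend; the real VectorHandlerCuda calls cuBLAS instead.
struct ReferenceImpl : Impl {
  double dot(const std::vector<double>& x, const std::vector<double>& y) override {
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) sum += x[i] * y[i];
    return sum;
  }
};

// Public handler: owns the backends and dispatches on a memory-space string.
class Handler {
public:
  Handler() : cpuImpl_(new CpuImpl), deviceImpl_(new ReferenceImpl) {}
  ~Handler() { delete cpuImpl_; delete deviceImpl_; }
  double dot(const std::vector<double>& x, const std::vector<double>& y,
             const std::string& memspace) {
    if (memspace == "cpu")    return cpuImpl_->dot(x, y);
    if (memspace == "device") return deviceImpl_->dot(x, y);
    throw std::runtime_error("Not implemented (yet)");
  }
private:
  Impl* cpuImpl_;
  Impl* deviceImpl_;
};

int main() {
  Handler handler;
  std::vector<double> x{1.0, 2.0, 3.0};
  std::vector<double> y{4.0, 5.0, 6.0};
  std::cout << handler.dot(x, y, "cpu") << "\n";  // prints 32
  return 0;
}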
2 changes: 2 additions & 0 deletions resolve/vector/CMakeLists.txt
@@ -10,6 +10,8 @@
set(Vector_SRC
Vector.cpp
VectorHandler.cpp
+ VectorHandlerCpu.cpp
+ VectorHandlerCuda.cpp
)


165 changes: 41 additions & 124 deletions resolve/vector/VectorHandler.cpp
@@ -4,6 +4,9 @@
#include <resolve/cuda/cudaKernels.h>
#include <resolve/vector/Vector.hpp>
#include <resolve/workspace/LinAlgWorkspaceFactory.hpp>
+ #include <resolve/vector/VectorHandlerImpl.hpp>
+ #include <resolve/vector/VectorHandlerCpu.hpp>
+ #include <resolve/vector/VectorHandlerCuda.hpp>
#include "VectorHandler.hpp"

namespace ReSolve {
@@ -14,16 +17,33 @@ namespace ReSolve {
*/
VectorHandler::VectorHandler()
{
+ cpuImpl_ = new VectorHandlerCpu();
+ isCpuEnabled_ = true;
}

/**
* @brief constructor
*
* @param new_workspace - workspace to be set
*/
- VectorHandler:: VectorHandler(LinAlgWorkspace* new_workspace)
+ VectorHandler::VectorHandler(LinAlgWorkspace* new_workspace)
{
- workspace_ = new_workspace;
+ cpuImpl_ = new VectorHandlerCpu(new_workspace);
+ isCpuEnabled_ = true;
}
+
+ /**
+ * @brief constructor
+ *
+ * @param new_workspace - workspace to be set
+ */
+ VectorHandler::VectorHandler(LinAlgWorkspaceCUDA* new_workspace)
+ {
+ cudaImpl_ = new VectorHandlerCuda(new_workspace);
+ cpuImpl_ = new VectorHandlerCpu();
+
+ isCudaEnabled_ = true;
+ isCpuEnabled_ = true;
+ }

/**
@@ -46,28 +66,11 @@

real_type VectorHandler::dot(vector::Vector* x, vector::Vector* y, std::string memspace)
{
- if (memspace == "cuda" ){
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- double nrm = 0.0;
- cublasStatus_t st= cublasDdot (handle_cublas, x->getSize(), x->getData("cuda"), 1, y->getData("cuda"), 1, &nrm);
- if (st!=0) {printf("dot product crashed with code %d \n", st);}
- return nrm;
+ if (memspace == "cuda" ) {
+ return cudaImpl_->dot(x, y);
} else {
if (memspace == "cpu") {
- real_type* x_data = x->getData("cpu");
- real_type* y_data = y->getData("cpu");
- real_type sum = 0.0;
- real_type c = 0.0;
- real_type t, y;
- for (int i = 0; i < x->getSize(); ++i){
- y = (x_data[i] * y_data[i]) - c;
- t = sum + y;
- c = (t - sum) - y;
- sum = t;
- // sum += (x_data[i] * y_data[i]);
- }
- return sum;
+ return cpuImpl_->dot(x, y);
} else {
out::error() << "Not implemented (yet)" << std::endl;
return NAN;
@@ -85,20 +88,11 @@
*/
void VectorHandler::scal(const real_type* alpha, vector::Vector* x, std::string memspace)
{
- if (memspace == "cuda" ) {
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- cublasStatus_t st = cublasDscal(handle_cublas, x->getSize(), alpha, x->getData("cuda"), 1);
- if (st!=0) {
- ReSolve::io::Logger::error() << "scal crashed with code " << st << "\n";
- }
+ if (memspace == "cuda" ) {
+ cudaImpl_->scal(alpha, x);
} else {
if (memspace == "cpu") {
- real_type* x_data = x->getData("cpu");
-
- for (int i = 0; i < x->getSize(); ++i){
- x_data[i] *= (*alpha);
- }
+ cpuImpl_->scal(alpha, x);
} else {
out::error() << "Not implemented (yet)" << std::endl;
}
@@ -114,26 +108,14 @@
* @param[in] memspace String containg memspace (cpu or cuda)
*
*/
- void VectorHandler::axpy(const real_type* alpha, vector::Vector* x, vector::Vector* y, std::string memspace )
+ void VectorHandler::axpy(const real_type* alpha, vector::Vector* x, vector::Vector* y, std::string memspace)
{
//AXPY: y = alpha * x + y
- if (memspace == "cuda" ) {
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- cublasDaxpy(handle_cublas,
- x->getSize(),
- alpha,
- x->getData("cuda"),
- 1,
- y->getData("cuda"),
- 1);
+ if (memspace == "cuda" ) {
+ cudaImpl_->axpy(alpha, x, y);
} else {
if (memspace == "cpu") {
- real_type* x_data = x->getData("cpu");
- real_type* y_data = y->getData("cpu");
- for (int i = 0; i < x->getSize(); ++i){
- y_data[i] = (*alpha) * x_data[i] + y_data[i];
- }
+ cpuImpl_->axpy(alpha, x, y);
} else {
out::error() <<"Not implemented (yet)" << std::endl;
}
@@ -160,38 +142,9 @@
void VectorHandler::gemv(std::string transpose, index_type n, index_type k, const real_type* alpha, const real_type* beta, vector::Vector* V, vector::Vector* y, vector::Vector* x, std::string memspace)
{
if (memspace == "cuda") {
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- if (transpose == "T") {
-
- cublasDgemv(handle_cublas,
- CUBLAS_OP_T,
- n,
- k,
- alpha,
- V->getData("cuda"),
- n,
- y->getData("cuda"),
- 1,
- beta,
- x->getData("cuda"),
- 1);
-
- } else {
- cublasDgemv(handle_cublas,
- CUBLAS_OP_N,
- n,
- k,
- alpha,
- V->getData("cuda"),
- n,
- y->getData("cuda"),
- 1,
- beta,
- x->getData("cuda"),
- 1);
- }
-
+ cudaImpl_->gemv(transpose, n, k, alpha, beta, V, y, x);
+ } else if (memspace == "cpu") {
+ cpuImpl_->gemv(transpose, n, k, alpha, beta, V, y, x);
} else {
out::error() << "Not implemented (yet)" << std::endl;
}
@@ -213,26 +166,9 @@
{
using namespace constants;
if (memspace == "cuda") {
- if (k < 200) {
- mass_axpy(size, k, x->getData("cuda"), y->getData("cuda"),alpha->getData("cuda"));
- } else {
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- cublasDgemm(handle_cublas,
- CUBLAS_OP_N,
- CUBLAS_OP_N,
- size, // m
- 1, // n
- k + 1, // k
- &MINUSONE, // alpha
- x->getData("cuda"), // A
- size, // lda
- alpha->getData("cuda"), // B
- k + 1, // ldb
- &ONE,
- y->getData("cuda"), // c
- size); // ldc
- }
+ cudaImpl_->massAxpy(size, alpha, k, x, y);
+ } else if (memspace == "cpu") {
+ cpuImpl_->massAxpy(size, alpha, k, x, y);
} else {
out::error() << "Not implemented (yet)" << std::endl;
}
@@ -254,29 +190,10 @@
*/
void VectorHandler::massDot2Vec(index_type size, vector::Vector* V, index_type k, vector::Vector* x, vector::Vector* res, std::string memspace)
{
- using namespace constants;
-
if (memspace == "cuda") {
- if (k < 200) {
- mass_inner_product_two_vectors(size, k, x->getData("cuda") , x->getData(1, "cuda"), V->getData("cuda"), res->getData("cuda"));
- } else {
- LinAlgWorkspaceCUDA* workspaceCUDA = (LinAlgWorkspaceCUDA*) workspace_;
- cublasHandle_t handle_cublas = workspaceCUDA->getCublasHandle();
- cublasDgemm(handle_cublas,
- CUBLAS_OP_T,
- CUBLAS_OP_N,
- k + 1, //m
- 2, //n
- size, //k
- &ONE, //alpha
- V->getData("cuda"), //A
- size, //lda
- x->getData("cuda"), //B
- size, //ldb
- &ZERO,
- res->getData("cuda"), //c
- k + 1); //ldc
- }
+ cudaImpl_->massDot2Vec(size, V, k, x, res);
+ } else if (memspace == "cpu") {
+ cpuImpl_->massDot2Vec(size, V, k, x, res);
} else {
out::error() << "Not implemented (yet)" << std::endl;
}
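After this refactoring every VectorHandler method above follows the same shape: branch on the memspace string, forward to cpuImpl_ or cudaImpl_, and report "Not implemented (yet)" otherwise. Calling code is unaffected; it still passes "cpu" or "cuda". As a caller-side illustration, here is a hypothetical helper that is not part of the commit: it uses only the dot and axpy signatures visible in the diff above, assumes the vectors already live in the requested memory space, and assumes the public headers are installed under resolve/vector/.

#include <string>

#include <resolve/vector/Vector.hpp>
#include <resolve/vector/VectorHandler.hpp>

// Hypothetical helper: y <- 2*x + y, then return the dot product of x and y.
// The handler routes both calls to its CPU or CUDA backend internally.
ReSolve::real_type axpyThenDot(ReSolve::VectorHandler& handler,
                               ReSolve::vector::Vector* x,
                               ReSolve::vector::Vector* y,
                               bool useGpu)
{
  const ReSolve::real_type alpha = 2.0;
  const std::string memspace = useGpu ? "cuda" : "cpu";
  handler.axpy(&alpha, x, y, memspace);   // y = alpha*x + y
  return handler.dot(x, y, memspace);     // x . y
}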
23 changes: 19 additions & 4 deletions resolve/vector/VectorHandler.hpp
@@ -7,7 +7,9 @@ namespace ReSolve
{
class Vector;
}
+ class VectorHandlerImpl;
class LinAlgWorkspace;
+ class LinAlgWorkspaceCUDA;
}


@@ -16,13 +18,14 @@ namespace ReSolve { //namespace vector {
public:
VectorHandler();
VectorHandler(LinAlgWorkspace* new_workspace);
+ VectorHandler(LinAlgWorkspaceCUDA* new_workspace);
~VectorHandler();

//y = alpha x + y
- void axpy(const real_type* alpha, vector::Vector* x, vector::Vector* y, std::string memspace );
+ void axpy(const real_type* alpha, vector::Vector* x, vector::Vector* y, std::string memspace);

//dot: x \cdot y
- real_type dot(vector::Vector* x, vector::Vector* y, std::string memspace );
+ real_type dot(vector::Vector* x, vector::Vector* y, std::string memspace);

//scal = alpha * x
void scal(const real_type* alpha, vector::Vector* x, std::string memspace);
@@ -40,9 +43,21 @@ namespace ReSolve { //namespace vector {
* if `transpose = T` (yes), `x = beta*x + alpha*V^T*y`,
* where `x` is `[k x 1]`, `V` is `[n x k]` and `y` is `[n x 1]`.
*/
- void gemv(std::string transpose, index_type n, index_type k, const real_type* alpha, const real_type* beta, vector::Vector* V, vector::Vector* y, vector::Vector* x, std::string memspace);
+ void gemv(std::string transpose,
+ index_type n,
+ index_type k,
+ const real_type* alpha,
+ const real_type* beta,
+ vector::Vector* V,
+ vector::Vector* y,
+ vector::Vector* x,
+ std::string memspace);
private:
- LinAlgWorkspace* workspace_;
+ VectorHandlerImpl* cpuImpl_{nullptr};
+ VectorHandlerImpl* cudaImpl_{nullptr};
+
+ bool isCpuEnabled_{false};
+ bool isCudaEnabled_{false};
};

} //} // namespace ReSolve::vector
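In the header, the raw LinAlgWorkspace* member is gone; the handler now stores two VectorHandlerImpl pointers plus flags recording which backends were enabled, and the constructor overload chosen by the caller decides what gets instantiated: the default constructor and the LinAlgWorkspace overload create only the CPU backend, while the LinAlgWorkspaceCUDA overload creates both. A construction sketch follows; it is illustrative only, since workspace setup and the exact include paths are not shown in this commit and are assumed here.

#include <resolve/vector/VectorHandler.hpp>
#include <resolve/workspace/LinAlgWorkspaceFactory.hpp>

int main()
{
  // CPU-only handler: only cpuImpl_ is created and isCpuEnabled_ is set.
  ReSolve::VectorHandler cpuHandler;

  // CUDA-enabled handler: both cudaImpl_ and cpuImpl_ are created.
  // (Assumes LinAlgWorkspaceCUDA is default-constructible and that any
  // further workspace initialization happens elsewhere.)
  ReSolve::LinAlgWorkspaceCUDA* workspace = new ReSolve::LinAlgWorkspaceCUDA();
  ReSolve::VectorHandler gpuHandler(workspace);

  delete workspace;
  return 0;
}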
