Skip to content

Commit

Permalink
...EJB
Browse files Browse the repository at this point in the history
  • Loading branch information
ebylaska committed May 20, 2024
1 parent 3ece142 commit 5c9ef68
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions Nwpw/nwpwlib/device/gdevices_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,115 @@ class Gdevices {
}


/**************************************
* *
* computeTrans3_Mult *
* *
**************************************/
/**
* @brief Computes the 3D transformation of projection data using transformation matrices.
*
* This function computes transformation sums for multiple projection (`prj`) functions and
* `psi` functions. It uses projection data (`prj`) and input data (`psi`) to calculate
* intermediate values, and then computes the transformation sums (`sum3`) using the
* transformation matrices (`Gx`, `Gy`, `Gz`).
*
* @param ne Number of `psi` functions to process.
* @param nprj Number of `prj` functions to process.
* @param psi Pointer to the input data array containing the `psi` functions.
* @param prj Pointer to the projection data array containing the `prj` functions.
* @param ng Grid size for the main computation (number of complex numbers).
* @param ng0 Reduced grid size for the secondary computation (number of complex numbers).
* @param Gx Transformation matrix in the x-direction.
* @param Gy Transformation matrix in the y-direction.
* @param Gz Transformation matrix in the z-direction.
* @param xtmp1 Buffer for intermediate computations.
* @param sum3 Array to store the computed transformation sums.
*
* @note This function can be optimized using SIMD instructions, OpenMP for parallelization,
* and GPU acceleration to enhance performance.
*/
void computeTrans3_Mult(const int ne, const int nprj,
const double *psi, const double *prj,
int ng, int ng0,
double *Gx, double *Gy, double *Gz, double *xtmp1,
double *sum3)
{
int one = 1;
int count3 = 0;
int nshift = 2*ng;

for (auto l=0; l<nprj; ++l)
for (auto n=0; n<ne; ++n)
{
// Perform cct_pack_iconjgMul
const double *a = prj + l*nshift;
const double *b = psi + n*nshift;
for (int i=0; i<ng; ++i)
xtmp1[i] = a[2*i]*b[2*i+1] - a[2*i+1]*b[2*i];

double tsumx = 2.0*DDOT_PWDFT(ng,Gx,one,xtmp1,one);
double tsumy = 2.0*DDOT_PWDFT(ng,Gy,one,xtmp1,one);
double tsumz = 2.0*DDOT_PWDFT(ng,Gz,one,xtmp1,one);
tsumx -= DDOT_PWDFT(ng0,Gx,one,xtmp1,one);
tsumy -= DDOT_PWDFT(ng0,Gy,one,xtmp1,one);
tsumz -= DDOT_PWDFT(ng0,Gz,one,xtmp1,one);

sum3[count3] = tsumx;
sum3[count3+1] = tsumy;
sum3[count3+2] = tsumz;
count3 += 3;
}
}



/**************************************
* *
* computeTrans_Mult *
* *
**************************************/
/**
* @brief Computes the transformation of projection data using matrix multiplication.
*
* This function computes matrix multiplication between projection (`prj`) and input
* (`psi`) data using a BLAS DGEMM routine, which is optimized for double-precision
* general matrix multiplication. It first performs the matrix multiplication with the
* full grid size (`npack2`), and if a reduced grid size (`npack0`) is specified,
* subtracts the result using another matrix multiplication.
*
* The BLAS DGEMM function can be optimized for various architectures, leveraging
* specialized hardware and libraries to enhance performance.
*
* @param ne Number of `psi` functions to process.
* @param nprj Number of `prj` functions to process.
* @param psi Pointer to the input data array containing the `psi` functions.
* @param prj Pointer to the projection data array containing the `prj` functions.
* @param ng Grid size for the main computation (number of complex numbers).
* @param ng0 Reduced grid size for the secondary computation (number of complex numbers).
* @param sum1 Array to store the computed matrix multiplication results.
*/
void computeTrans_Mult(int ne, int nprj, double alpha, double alpha1, int ng, int ng0,
double *psi, double *prj, double beta, double beta1, double *sum1)
{
int npack2 = 2*ng;
int npack0 = 2*ng0;
double rtwo = 2.0;
double rzero = 0.0;
double rone = 1.0;
double rmone = -1.0;

TN_dgemm(ne,nprj,npack2,alpha,psi,prj,beta,sum)
DGEMM_PWDFT((char *)"T", (char *)"N",ne,nprj,npack2,alpha,psi,npack2,prj,npack2,beta,sum1,ne);
if (npack0 > 0)
{
DGEMM_PWDFT((char *)"T", (char *)"N",ne,nprj,npack0,alpha1,psi,npack2,prj,npack2,beta1,sum1,ne);
}
}




/**************************************
* *
* WW6_zgemm *
Expand Down

0 comments on commit 5c9ef68

Please sign in to comment.