Skip to content

Commit

Permalink
Fixing zgemm calls to pass complex alpha and beta (two-double {re, im} values passed by pointer instead of real scalars) ... EJB
Browse files (browse the repository at this point in the history)
  • Loading branch information
ebylaska committed Dec 15, 2023
1 parent 8b0a116 commit 1398068
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 58 deletions.
33 changes: 25 additions & 8 deletions Nwpw/band/lib/cpsp/CPseudopotential.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,8 @@ void CPseudopotential::v_nonlocal(double *psi, double *Hpsi)
double scal = 1.0 / omega;
int one = 1;
int ntmp, nshift, nn;
double rone = 1.0;
double rmone = -1.0;
double rone[2] = {1.0,0.0};
double rmone[2] = {-1.0,0.0};

nn = mypneb->neq[0] + mypneb->neq[1];
nshift0 = mypneb->npack1_max();
Expand Down Expand Up @@ -1650,6 +1650,13 @@ void CPseudopotential::v_nonlocal(double *psi, double *Hpsi)
jend = ii;
mypneb->cc_pack_inprjzdot(nbq1, 2*nn, nprjall, psi, prjtmp, zsw1);
parall->Vector_SumAll(1, 2*nn*nprjall, zsw1);


std::cout << "VNONLOCAL nprjall=" << nprjall << " zsw1= ";
for (auto kk=0; kk<20; ++kk)
std::cout << zsw1[kk] << " ";
std::cout << std::endl;
std::cout << std::endl;

/* sw2 = Gijl*sw1 */
ll = 0;
Expand All @@ -1662,19 +1669,29 @@ void CPseudopotential::v_nonlocal(double *psi, double *Hpsi)
ll += nprj[ia];
}
}

std::cout << "VNONLOCAL nprjall=" << nprjall << " zsw2= ";
for (auto kk=0; kk<20; ++kk)
std::cout << zsw2[kk] << " ";
std::cout << std::endl;
std::cout << std::endl;

ntmp = 2*nn*nprjall;
DSCAL_PWDFT(ntmp, scal, zsw2, one);

// DGEMM_PWDFT((char*) "N",(char*) "T",nshift,nn,nprjall,
// DGEMM_PWDFT((char*) "N",(char*) "C",nshift,nn,nprjall,
// rmone,
// prjtmp,nshift,
// sw2, nn,
// zsw2, nn,
// rone,
// Hpsi,nshift);
mypneb->c3db::mygdevice.NC2_zgemm(nshift, nn, nprjall, rmone, prjtmp, zsw2, rone, Hpsi);
}
mypneb->c3db::mygdevice.hpsi_copy_gpu2host(nshift0, nn, Hpsi);
std::cout << "HPSI =";
for (auto kk=0; kk<20; ++kk)
std::cout << Hpsi[kk] << " ";
std::cout << std::endl << std::endl;
#else

for (ii = 0; ii < (myion->nion); ++ii)
Expand Down Expand Up @@ -1767,8 +1784,8 @@ void CPseudopotential::v_nonlocal_fion(double *psi, double *Hpsi,
int three = 3;
int ntmp;

double rone = 1.0;
double rmone = -1.0;
double rone[2] = {1.0,0.0};
double rmone[2] = {-1.0,0.0};

int nn = mypneb->neq[0] + mypneb->neq[1];
int ispin = mypneb->ispin;
Expand Down Expand Up @@ -2247,7 +2264,7 @@ double CPseudopotential::e_nonlocal(double *psi)
{
Multiply_Gijl_zsw1(nn, nprj[ia], nmax[ia], lmax[ia], n_projector[ia],
l_projector[ia], m_projector[ia], Gijl[ia],
zsw1+ll*nn, zsw2+ll*nn);
zsw1+(ll*2*nn), zsw2+(ll*2*nn));
ll += nprj[ia];
}
}
Expand All @@ -2260,7 +2277,7 @@ double CPseudopotential::e_nonlocal(double *psi)
auto ntmp = 2*nn*nprjall;
DSCAL_PWDFT(ntmp, scal, zsw2, one);

esum += DDOT_PWDFT(ntmp, zsw1, one, zsw2, one);
esum += ZDOTC_PWDFT(ntmp, zsw1, one, zsw2, one);
mypneb->c3db::mygdevice.T_free();
}
}
Expand Down
4 changes: 2 additions & 2 deletions Nwpw/nwpwlib/C3dB/CGrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,8 @@ void CGrid::cc_pack_inprjzdot(const int nb, int nn, int nprj, double *a,
double *b, double *sum)
{
int ng = (nidb[nb]);
double rone = 1.0;
double rzero = 0.0;
double rone[2] = {1.0,0.0};
double rzero[2] = {0.0,0.0};

c3db::mygdevice.CN2_zgemm(nn, nprj, ng, rone, a, b, rzero, sum);
}
Expand Down
24 changes: 12 additions & 12 deletions Nwpw/nwpwlib/C3dB/Cneb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -937,10 +937,10 @@ void Cneb::ggm_sym_Multiply(double *psi1, double *psi2, double *hml)
//int ng0 = 2*CGrid::nzero(1);

int one = 1;
double rzero = 0.0;
double rtwo = 2.0;
double rone = 1.0;
double rmone = -1.0;
double rzero[2] = {0.0,0.0};
double rtwo[2] = {2.0,0.0};
double rone[2] = {1.0,0.0};
double rmone[2] = {-1.0,0.0};

if (parallelized)
{
Expand Down Expand Up @@ -1026,10 +1026,10 @@ void Cneb::ggw_sym_Multiply(double *psi1, double *psi2, double *hml)
//int ng0 = 2*CGrid::nzero(1);

int one = 1;
double rzero = 0.0;
double rtwo = 2.0;
double rone = 1.0;
double rmone = -1.0;
double rzero[2] = {0.0,0.0};
double rtwo[2] = {2.0,0.0};
double rone[2] = {1.0,0.0};
double rmone[2] = {-1.0,0.0};

if (parallelized)
{
Expand Down Expand Up @@ -1116,10 +1116,10 @@ void Cneb::ggw_Multiply(double *psi1, double *psi2, double *hml)
int npack1 = 2 * CGrid::npack1_max();
//int ng0 = 2 * CGrid::nzero(1);

double rzero = 0.0;
double rtwo = 2.0;
double rone = 1.0;
double rmone = -1.0;
double rzero[2] = {0.0,0.0};
double rtwo[2] = {2.0,0.0};
double rone[2] = {1.0,0.0};
double rmone[2] = {-1.0,0.0};

if (parallelized)
{
Expand Down
2 changes: 1 addition & 1 deletion Nwpw/nwpwlib/blas/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ extern "C" void zheev_(char *, char *, int *, double *, int *, double *,
#define ZDOTC_PWDFT(n, a, ida, b, idb) zdotc_(&(n), a, &ida, b, &(idb))

#define ZGEMM_PWDFT(s1, s2, n, m, k, alpha, a, ida, b, idb, beta, c, idc) \
zgemm_(s1, s2, &(n), &(m), &(k), &(alpha), a, &(ida), b, &(idb), &(beta), c, \
zgemm_(s1, s2, &(n), &(m), &(k), alpha, a, &(ida), b, &(idb), beta, c, \
&(idc))

#define ZEIGEN_PWDFT(n, hml, eig, xtmp, nn, rtmp, ierr) \
Expand Down
18 changes: 9 additions & 9 deletions Nwpw/nwpwlib/device/gdevice2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,31 +78,31 @@ void gdevice2::NN_eigensolver(int ispin, int ne[], double *a, double *w) {
}


void gdevice2::CN1_zgemm(int npack, int ne, double alpha, double *a, double *b,
double beta, double *c) {
void gdevice2::CN1_zgemm(int npack, int ne, double *alpha, double *a, double *b,
double *beta, double *c) {
mygdevice2->CN1_zgemm(npack, ne, alpha, a, b, beta, c);
}

void gdevice2::CN2_zgemm(int npack, int ne, int nprj, double alpha, double *a, double *b,
double beta, double *c) {
void gdevice2::CN2_zgemm(int npack, int ne, int nprj, double *alpha, double *a, double *b,
double *beta, double *c) {
mygdevice2->CN2_zgemm(npack, ne, nprj, alpha, a, b, beta, c);
}

void gdevice2::NC2_zgemm(int npack, int ne, int nprj, double alpha, double *a, double *b,
double beta, double *c) {
void gdevice2::NC2_zgemm(int npack, int ne, int nprj, double *alpha, double *a, double *b,
double *beta, double *c) {
mygdevice2->NC2_zgemm(npack, ne, nprj, alpha, a, b, beta, c);
}


void gdevice2::NN_zgemm(int n, int m, int k, double alpha, double *a, int lda, double *b, int ldb, double beta, double *c, int ldc) {
void gdevice2::NN_zgemm(int n, int m, int k, double *alpha, double *a, int lda, double *b, int ldb, double *beta, double *c, int ldc) {
mygdevice2->NN_zgemm(n,m,k,alpha,a,lda,b,ldb,beta,c,ldc);
}

void gdevice2::NC_zgemm(int n, int m, int k, double alpha, double *a, int lda, double *b, int ldb, double beta, double *c, int ldc) {
void gdevice2::NC_zgemm(int n, int m, int k, double *alpha, double *a, int lda, double *b, int ldb, double *beta, double *c, int ldc) {
mygdevice2->NC_zgemm(n,m,k,alpha,a,lda,b,ldb,beta,c,ldc);
}

void gdevice2::CN_zgemm(int n, int m, int k, double alpha, double *a, int lda, double *b, int ldb, double beta, double *c, int ldc) {
void gdevice2::CN_zgemm(int n, int m, int k, double *alpha, double *a, int lda, double *b, int ldb, double *beta, double *c, int ldc) {
mygdevice2->CN_zgemm(n,m,k,alpha,a,lda,b,ldb,beta,c,ldc);
}

Expand Down
16 changes: 8 additions & 8 deletions Nwpw/nwpwlib/device/gdevice2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ Gdevices *mygdevice2;
void NN_eigensolver(int, int *, double *, double *);


void NN1_zgemm(int, int, double, double *, double *, double, double *, int *);
void CN1_zgemm(int, int, double, double *, double *, double, double *);
void CN2_zgemm(int, int, int, double, double *, double *, double, double *);
void NC2_zgemm(int, int, int, double, double *, double *, double, double *);

void NN_zgemm(int, int, int, double, double *, int, double *, int, double, double *,int);
void CN_zgemm(int, int, int, double, double *, int, double *, int, double, double *,int);
void NC_zgemm(int, int, int, double, double *, int, double *, int, double, double *,int);
void NN1_zgemm(int, int, double *, double *, double *, double *, double *, int *);
void CN1_zgemm(int, int, double *, double *, double *, double *, double *);
void CN2_zgemm(int, int, int, double *, double *, double *, double *, double *);
void NC2_zgemm(int, int, int, double *, double *, double *, double *, double *);

void NN_zgemm(int, int, int, double *, double *, int, double *, int, double *, double *,int);
void CN_zgemm(int, int, int, double *, double *, int, double *, int, double *, double *,int);
void NC_zgemm(int, int, int, double *, double *, int, double *, int, double *, double *,int);

void WW_eigensolver(int, int *, double *, double *);

Expand Down
35 changes: 17 additions & 18 deletions Nwpw/nwpwlib/device/gdevices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,56 +225,55 @@ class Gdevices {

/// DOUBLE COMPLEX BLAS

void NN1_zgemm(int npack, int ne, double alpha, double *host_a, double *host_b,
double beta, double *host_c) {
void NN1_zgemm(int npack, int ne, double *alpha, double *host_a, double *host_b,
double *beta, double *host_c) {
ZGEMM_PWDFT((char *)"N", (char *)"N", npack, ne, ne, alpha, host_a, npack,
host_b, ne, beta, host_c, npack);
}

void CN1_zgemm(int npack, int ne, double alpha, double *host_a,
double *host_b, double beta, double *host_c) {
void CN1_zgemm(int npack, int ne, double *alpha, double *host_a,
double *host_b, double *beta, double *host_c) {
ZGEMM_PWDFT((char *)"C", (char *)"N", ne, ne, npack, alpha, host_a, npack,
host_b, npack, beta, host_c, ne);
}

void CN2_zgemm(int ne, int nprj, int npack, double alpha, double *host_a,
double *host_b, double beta, double *host_c) {
void CN2_zgemm(int ne, int nprj, int npack, double *alpha, double *host_a,
double *host_b, double *beta, double *host_c) {
ZGEMM_PWDFT((char *)"C", (char *)"N", ne, nprj, npack, alpha, host_a, npack,
host_b, npack, beta, host_c, ne);
}

void NC2_zgemm(int ne, int nprj, int npack, double alpha, double *host_a,
double *host_b, double beta, double *host_c) {
ZGEMM_PWDFT((char *)"N", (char *)"C", ne, nprj, npack, alpha, host_a, npack,
host_b, npack, beta, host_c, ne);
void NC2_zgemm(int npack, int ne, int nprj, double *alpha, double *host_a,
double *host_b, double *beta, double *host_c) {
ZGEMM_PWDFT((char *)"N", (char *)"C", npack,ne,nprj,alpha, host_a, npack,
host_b, ne, beta, host_c, npack);
}


void NN_zgemm(int m, int n, int k,
double alpha,
double *alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *beta,
double *host_c,int ldc) {
ZGEMM_PWDFT((char *)"N", (char *)"N", m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
}

void CN_zgemm(int m, int n, int k,
double alpha,
double *alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *beta,
double *host_c,int ldc) {
ZGEMM_PWDFT((char *)"C", (char *)"N", m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
}




void NC_zgemm(int m, int n, int k,
double alpha,
double *alpha,
double *host_a, int lda,
double *host_b, int ldb,
double beta,
double *beta,
double *host_c,int ldc) {
ZGEMM_PWDFT((char *)"N", (char *)"C", m, n, k, alpha, host_a, lda, host_b, ldb, beta, host_c, ldc);
}
Expand Down

0 comments on commit 1398068

Please sign in to comment.