ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
math_tools.h
Go to the documentation of this file.
5#include "source_psi/psi.h"
6#include "source_base/timer.h"
7
8#ifdef __MPI
9inline void psiMulPsiMpi(const psi::Psi<double>& psi1,
10 const psi::Psi<double>& psi2,
11 ModuleBase::matrix& dm_out,
12 const int* desc_psi,
13 const int* desc_dm)
14{
15 ModuleBase::timer::tick("psiMulPsiMpi","pdgemm");
16 const double one_float = 1.0, zero_float = 0.0;
17 const int one_int = 1;
18 const char N_char = 'N', T_char = 'T';
19 const int nlocal = desc_dm[2];
20 const int nbands = desc_psi[3];
21 pdgemm_(&N_char,
22 &T_char,
23 &nlocal,
24 &nlocal,
25 &nbands,
26 &one_float,
27 psi1.get_pointer(),
28 &one_int,
29 &one_int,
30 desc_psi,
31 psi2.get_pointer(),
32 &one_int,
33 &one_int,
34 desc_psi,
35 &zero_float,
36 dm_out.c,
37 &one_int,
38 &one_int,
39 desc_dm);
40 ModuleBase::timer::tick("psiMulPsiMpi","pdgemm");
41}
42
43inline void psiMulPsiMpi(const psi::Psi<std::complex<double>>& psi1,
44 const psi::Psi<std::complex<double>>& psi2,
46 const int* desc_psi,
47 const int* desc_dm)
48{
49 ModuleBase::timer::tick("psiMulPsiMpi","pdgemm");
50 const std::complex<double> one_complex = {1.0, 0.0}, zero_complex = {0.0, 0.0};
51 const int one_int = 1;
52 const char N_char = 'N', T_char = 'T';
53 const int nlocal = desc_dm[2];
54 const int nbands = desc_psi[3];
55 pzgemm_(&N_char,
56 &T_char,
57 &nlocal,
58 &nlocal,
59 &nbands,
60 &one_complex,
61 psi1.get_pointer(),
62 &one_int,
63 &one_int,
64 desc_psi,
65 psi2.get_pointer(),
66 &one_int,
67 &one_int,
68 desc_psi,
69 &zero_complex,
70 dm_out.c,
71 &one_int,
72 &one_int,
73 desc_dm);
74 ModuleBase::timer::tick("psiMulPsiMpi","pdgemm");
75}
76
77#else
78inline void psiMulPsi(const psi::Psi<double>& psi1, const psi::Psi<double>& psi2, ModuleBase::matrix& dm_out)
79{
80 const double one_float = 1.0, zero_float = 0.0;
81 const int one_int = 1;
82 const char N_char = 'N', T_char = 'T';
83 const int nlocal = psi1.get_nbasis();
84 const int nbands = psi1.get_nbands();
85 dgemm_(&N_char,
86 &T_char,
87 &nlocal,
88 &nlocal,
89 &nbands,
90 &one_float,
91 psi1.get_pointer(),
92 &nlocal,
93 psi2.get_pointer(),
94 &nlocal,
95 &zero_float,
96 dm_out.c,
97 &nlocal);
98}
99
100inline void psiMulPsi(const psi::Psi<std::complex<double>>& psi1,
101 const psi::Psi<std::complex<double>>& psi2,
103{
104 const int one_int = 1;
105 const char N_char = 'N', T_char = 'T';
106 const int nlocal = psi1.get_nbasis();
107 const int nbands = psi1.get_nbands();
108 const std::complex<double> one_complex = {1.0, 0.0}, zero_complex = {0.0, 0.0};
109 zgemm_(&N_char,
110 &T_char,
111 &nlocal,
112 &nlocal,
113 &nbands,
114 &one_complex,
115 psi1.get_pointer(),
116 &nlocal,
117 psi2.get_pointer(),
118 &nlocal,
119 &zero_complex,
120 dm_out.c,
121 &nlocal);
122}
123
124#endif
void zgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc)
void dgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc)
Definition complexmatrix.h:14
std::complex< double > * c
Definition complexmatrix.h:21
Definition matrix.h:19
double * c
Definition matrix.h:25
static void tick(const std::string &class_name_in, const std::string &name_in)
Call twice per timed region: the first call sets start_flag to false; the second call calculates the elapsed time duration.
Definition timer.cpp:57
Definition psi.h:37
const int & get_nbands() const
Definition psi.cpp:342
const int & get_nbasis() const
Definition psi.cpp:348
T * get_pointer() const
Definition psi.cpp:272
void psiMulPsiMpi(const psi::Psi< double > &psi1, const psi::Psi< double > &psi2, ModuleBase::matrix &dm_out, const int *desc_psi, const int *desc_dm)
Definition math_tools.h:9
void psiMulPsi(const psi::Psi< double > &psi1, const psi::Psi< double > &psi2, double *dm_out)
Definition cal_dm_psi.cpp:227
void pdgemm_(const char *transa, const char *transb, const int *M, const int *N, const int *K, const double *alpha, const double *A, const int *IA, const int *JA, const int *DESCA, const double *B, const int *IB, const int *JB, const int *DESCB, const double *beta, double *C, const int *IC, const int *JC, const int *DESCC)
void pzgemm_(const char *transa, const char *transb, const int *M, const int *N, const int *K, const std::complex< double > *alpha, const std::complex< double > *A, const int *IA, const int *JA, const int *DESCA, const std::complex< double > *B, const int *IB, const int *JB, const int *DESCB, const std::complex< double > *beta, std::complex< double > *C, const int *IC, const int *JC, const int *DESCC)