ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cal_dm.h
Go to the documentation of this file.
1#ifndef CAL_DM_H
2#define CAL_DM_H
3
4#include "math_tools.h"
5#include "source_base/timer.h"
8
9namespace elecstate
10{
11
12// for gamma_only(double case) and multi-k(complex<double> case)
13inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, const psi::Psi<double>& wfc, std::vector<ModuleBase::matrix>& dm)
14{
15 ModuleBase::TITLE("elecstate", "cal_dm");
16 ModuleBase::timer::tick("elecstate","cal_dm");
17
18 //dm.resize(wfc.get_nk(), ParaV->ncol, ParaV->nrow);
19 const int nbands_local = wfc.get_nbands();
20 const int nbasis_local = wfc.get_nbasis();
21
22 // dm = wfc.T * wg * wfc.conj()
23 // dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
24 for (int ik = 0; ik < wfc.get_nk(); ++ik)
25 {
26 wfc.fix_k(ik);
27 //dm.fix_k(ik);
28 dm[ik].create(ParaV->ncol, ParaV->nrow);
29 // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
30 psi::Psi<double> wg_wfc(1,
31 wfc.get_nbands(),
32 wfc.get_nbasis(),
33 wfc.get_nbasis(),
34 true);
35 wg_wfc.set_all_psi(wfc.get_pointer(), wg_wfc.size());
36
37 int ib_global = 0;
38 for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
39 {
40 while (ib_local != ParaV->global2local_col(ib_global))
41 {
42 ++ib_global;
43 if (ib_global >= wg.nc)
44 {
45 break;
46 ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
47 }
48 }
49 if (ib_global >= wg.nc) { continue;
50}
51 const double wg_local = wg(ik, ib_global);
52 double* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
53 BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
54 }
55
56 // C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
57#ifdef __MPI
58 psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
59#else
60 psiMulPsi(wg_wfc, wfc, dm[ik]);
61#endif
62 }
63 ModuleBase::timer::tick("elecstate","cal_dm");
64
65 return;
66}
67
68inline void cal_dm(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& wg, const psi::Psi<std::complex<double>>& wfc, std::vector<ModuleBase::ComplexMatrix>& dm)
69{
70 ModuleBase::TITLE("elecstate", "cal_dm");
71 ModuleBase::timer::tick("elecstate","cal_dm");
72
73 //dm.resize(wfc.get_nk(), ParaV->ncol, ParaV->nrow);
74 const int nbands_local = wfc.get_nbands();
75 const int nbasis_local = wfc.get_nbasis();
76
77 // dm = wfc.T * wg * wfc.conj()
78 // dm[is](iw1,iw2) = \sum_{ib} wfc[is](ib,iw1).T * wg(is,ib) * wfc[is](ib,iw2).conj()
79 for (int ik = 0; ik < wfc.get_nk(); ++ik)
80 {
81 wfc.fix_k(ik);
82 //dm.fix_k(ik);
83 dm[ik].create(ParaV->ncol, ParaV->nrow);
84 // wg_wfc(ib,iw) = wg[ib] * wfc(ib,iw);
85 psi::Psi<std::complex<double>> wg_wfc(1, wfc.get_nbands(), wfc.get_nbasis(), wfc.get_nbasis(), true);
86 const std::complex<double>* pwfc = wfc.get_pointer();
87 std::complex<double>* pwg_wfc = wg_wfc.get_pointer();
88#ifdef _OPENMP
89#pragma omp parallel for schedule(static, 1024)
90#endif
91 for(int i = 0;i<wg_wfc.size();++i)
92 {
93 pwg_wfc[i] = conj(pwfc[i]);
94 }
95
96 int ib_global = 0;
97 for (int ib_local = 0; ib_local < nbands_local; ++ib_local)
98 {
99 while (ib_local != ParaV->global2local_col(ib_global))
100 {
101 ++ib_global;
102 if (ib_global >= wg.nc)
103 {
104 break;
105 ModuleBase::WARNING_QUIT("ElecStateLCAO::cal_dm", "please check global2local_col!");
106 }
107 }
108 if (ib_global >= wg.nc) { continue;
109}
110 const double wg_local = wg(ik, ib_global);
111 std::complex<double>* wg_wfc_pointer = &(wg_wfc(0, ib_local, 0));
112 BlasConnector::scal(nbasis_local, wg_local, wg_wfc_pointer, 1);
113 }
114
115 // C++: dm(iw1,iw2) = wfc(ib,iw1).T * wg_wfc(ib,iw2)
116#ifdef __MPI
117 psiMulPsiMpi(wg_wfc, wfc, dm[ik], ParaV->desc_wfc, ParaV->desc);
118#else
119 psiMulPsi(wg_wfc, wfc, dm[ik]);
120#endif
121 }
122
123 ModuleBase::timer::tick("elecstate","cal_dm");
124 return;
125}
126
127}//namespace elecstate
128
129#endif
static void scal(const int n, const float alpha, float *X, const int incX, base_device::AbacusDevice_t device_type=base_device::AbacusDevice_t::CpuDevice)
Definition blas_connector_vector.cpp:80
Definition matrix.h:19
int nc
Definition matrix.h:24
static void tick(const std::string &class_name_in, const std::string &name_in)
Call in pairs: the first call sets start_flag to false; the second call calculates the elapsed time.
Definition timer.cpp:57
int ncol
Definition parallel_2d.h:116
int nrow
local size (nloc = nrow * ncol)
Definition parallel_2d.h:115
int global2local_col(const int igc) const
get the local index of a global index (col)
Definition parallel_2d.h:51
int desc[9]
ScaLAPACK descriptor.
Definition parallel_2d.h:103
Definition parallel_orbitals.h:9
int desc_wfc[9]
Definition parallel_orbitals.h:37
Definition psi.h:37
const int & get_nbands() const
Definition psi.cpp:342
const int & get_nk() const
Definition psi.cpp:336
void set_all_psi(const T *another_pointer, const std::size_t size_in)
Definition psi.cpp:223
size_t size() const
Definition psi.cpp:354
const int & get_nbasis() const
Definition psi.cpp:348
T * get_pointer() const
Definition psi.cpp:272
void fix_k(const int ik) const
Definition psi.cpp:364
void WARNING_QUIT(const std::string &, const std::string &)
Combine the functions of WARNING and QUIT.
Definition test_delley.cpp:14
void TITLE(const std::string &class_name, const std::string &function_name, const bool disable)
Definition tool_title.cpp:18
Definition cal_dm.h:10
void cal_dm(const Parallel_Orbitals *ParaV, const ModuleBase::matrix &wg, const psi::Psi< double > &wfc, std::vector< ModuleBase::matrix > &dm)
Definition cal_dm.h:13
void psiMulPsi(const psi::Psi< double > &psi1, const psi::Psi< double > &psi2, double *dm_out)
Definition cal_dm_psi.cpp:227
void psiMulPsiMpi(const psi::Psi< double > &psi1, const psi::Psi< double > &psi2, double *dm_out, const int *desc_psi, const int *desc_dm)
Definition cal_dm_psi.cpp:156
double conj(double a)
Definition operator_lr_hxc.cpp:14