ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
diago_cg.h
Go to the documentation of this file.
1#ifndef MODULE_HSOLVER_DIAGO_CG_H_
2#define MODULE_HSOLVER_DIAGO_CG_H_
3
4#include <functional>
5#include <vector>
6
9
10#include <ATen/core/tensor.h>
12
13namespace hsolver {
14
// NOTE(review): this is a rendered source listing. The leading number on each line
// appears to be the header's own line number; gaps in that numbering mean some
// lines of the header are not shown here.
//
// DiagoCG<T, Device>: final class template declaring the diag() entry point of
// the CG method together with the private helpers and state it uses.
// Device defaults to base_device::DEVICE_CPU.
15template <typename T, typename Device = base_device::DEVICE_CPU>
16class DiagoCG final
17{
18 // Members declared before the first access specifier are private by default.
19 // Note: GetTypeReal<T>::type is
20 // T itself when T is a real type (float, double);
21 // otherwise it is the underlying real type of T (std::complex<float>, std::complex<double>).
22 using Real = typename GetTypeReal<T>::type;
24 public:
// Callback signatures. HPsiFunc and SPsiFunc are both
// void(T*, T*, const int, const int); SubspaceFunc takes one extra trailing const bool.
25 using HPsiFunc = std::function<void(T*, T*, const int, const int)>;
26 using SPsiFunc = std::function<void(T*, T*, const int, const int)>;
27 using SubspaceFunc = std::function<void(T*, T*, const int, const int, const bool)>;
28 // Constructor needs:
29 // 1. temporary mock of Hamiltonian "Hamilt_PW"
30 // 2. precondition pointer should point to place of precondition array.
// NOTE(review): neither constructor declared below takes a Hamiltonian or a
// precondition pointer, so the list above looks stale - confirm and update.
31 DiagoCG(const std::string& basis_type, const std::string& calculation);
32 DiagoCG(
33 const std::string& basis_type,
34 const std::string& calculation,
35 const bool& need_subspace,
36 const SubspaceFunc& subspace_func,
37 const Real& pw_diag_thr,
38 const int& pw_diag_nmax,
39 const int& nproc_in_pool);
40
41 ~DiagoCG();
42
43 // virtual void init(){};
44 // refactor hpsi_info
45 // diag(): entry point of the CG method.
46 // Returns avg_iter.
// prec is optional and defaults to nullptr.
47 double diag(const HPsiFunc& hpsi_func,
48 const SPsiFunc& spsi_func,
49 const int ld_psi,
50 const int nband,
51 const int dim,
52 T* psi_in,
53 Real* eigenvalue_in,
54 const std::vector<double>& ethr_band,
55 const Real* prec = nullptr);
56
57 private:
58 Device * ctx_ = {};
61 int notconv_ = 0;
64 int n_band_ = 0;
// column size of the input psi matrix
66 int n_basis_ = 0;
// average iteration steps for CG diagonalization
68 double avg_iter_ = 0;
// iteration count of each band
70 std::vector<int> iter_band;
// basis_type of psi
78 std::string basis_type_ = {};
// calculation type of ABACUS
80 std::string calculation_ = {};
81
82 bool need_subspace_ = false;
89
// Private helper declarations follow; only their signatures are visible in this listing.
90 void calc_grad(
91 const ct::Tensor& prec,
92 ct::Tensor& grad,
93 ct::Tensor& hphi,
94 ct::Tensor& sphi,
95 ct::Tensor& pphi);
96
97 void orth_grad(
98 const ct::Tensor& psi,
99 const int& m,
100 ct::Tensor& grad,
101 ct::Tensor& scg,
102 ct::Tensor& lagrange);
103
104 void calc_gamma_cg(
105 const int& iter,
106 const Real& cg_norm,
107 const Real& theta,
108 const ct::Tensor& prec,
109 const ct::Tensor& scg,
110 const ct::Tensor& grad,
111 const ct::Tensor& phi_m,
112 Real& gg_last,
113 ct::Tensor& g0,
114 ct::Tensor& cg);
115
// Returns bool; the meaning of the result is not visible here.
116 bool update_psi(
117 const ct::Tensor& pphi,
118 const ct::Tensor& cg,
119 const ct::Tensor& scg,
120 const double& ethreshold,
121 Real &cg_norm,
122 Real &theta,
123 Real &eigen,
124 ct::Tensor& phi_m,
125 ct::Tensor& sphi,
126 ct::Tensor& hphi);
127
128 void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
129
130 // used in diag() for template replace Hamilt with Hamilt_PW
// NOTE(review): listing line 132 is not shown here. The rendered signature summary
// for diag_once lists a `ct::Tensor& psi` parameter between prec and eigen -
// confirm against the real header before relying on this three-parameter form.
131 void diag_once(const ct::Tensor& prec,
133 ct::Tensor& eigen,
134 const std::vector<double>& ethr_band);
135
// const member function: does not modify the object.
136 bool test_exit_cond(const int& ntry, const int& notconv) const;
137
// Three pointers to const T, all initialized to nullptr in-class.
139 const T * one_ = nullptr, * zero_ = nullptr, * neg_one_ = nullptr;
140};
141
142} // namespace hsolver
143
144#endif // MODULE_HSOLVER_DIAGO_CG_H_
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Definition diago_cg.h:17
double diag(const HPsiFunc &hpsi_func, const SPsiFunc &spsi_func, const int ld_psi, const int nband, const int dim, T *psi_in, Real *eigenvalue_in, const std::vector< double > &ethr_band, const Real *prec=nullptr)
Definition diago_cg.cpp:582
bool need_subspace_
Definition diago_cg.h:82
Real pw_diag_thr_
threshold for cg diagonalization
Definition diago_cg.h:72
typename ct::PsiToContainer< Device >::type ct_Device
Definition diago_cg.h:23
HPsiFunc hpsi_func_
A function object that performs the hPsi calculation.
Definition diago_cg.h:84
void calc_grad(const ct::Tensor &prec, ct::Tensor &grad, ct::Tensor &hphi, ct::Tensor &sphi, ct::Tensor &pphi)
Definition diago_cg.cpp:216
void diag_once(const ct::Tensor &prec, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band)
Definition diago_cg.cpp:57
void schmit_orth(const int &m, const ct::Tensor &psi, const ct::Tensor &sphi, ct::Tensor &phi_m)
Definition diago_cg.cpp:477
std::function< void(T *, T *, const int, const int)> HPsiFunc
Definition diago_cg.h:25
int pw_diag_nmax_
maximum iteration steps for cg diagonalization
Definition diago_cg.h:74
SubspaceFunc subspace_func_
A function object that performs the subspace calculation.
Definition diago_cg.h:88
typename GetTypeReal< T >::type Real
Definition diago_cg.h:22
const T * one_
Definition diago_cg.h:139
bool test_exit_cond(const int &ntry, const int &notconv) const
Definition diago_cg.cpp:568
std::function< void(T *, T *, const int, const int)> SPsiFunc
Definition diago_cg.h:26
const T * neg_one_
Definition diago_cg.h:139
SPsiFunc spsi_func_
A function object that performs the sPsi calculation.
Definition diago_cg.h:86
int n_basis_
column size of the input psi matrix
Definition diago_cg.h:66
~DiagoCG()
Definition diago_cg.cpp:49
double avg_iter_
average iteration steps for cg diagonalization
Definition diago_cg.h:68
int n_band_
Definition diago_cg.h:64
void calc_gamma_cg(const int &iter, const Real &cg_norm, const Real &theta, const ct::Tensor &prec, const ct::Tensor &scg, const ct::Tensor &grad, const ct::Tensor &phi_m, Real &gg_last, ct::Tensor &g0, ct::Tensor &cg)
Definition diago_cg.cpp:311
std::string calculation_
calculation type of ABACUS
Definition diago_cg.h:80
int notconv_
Definition diago_cg.h:61
std::function< void(T *, T *, const int, const int, const bool)> SubspaceFunc
Definition diago_cg.h:27
std::string basis_type_
basis_type of psi
Definition diago_cg.h:78
Device * ctx_
Definition diago_cg.h:58
std::vector< int > iter_band
std::vector holding the iteration count of each band
Definition diago_cg.h:70
void orth_grad(const ct::Tensor &psi, const int &m, ct::Tensor &grad, ct::Tensor &scg, ct::Tensor &lagrange)
Definition diago_cg.cpp:261
const T * zero_
Definition diago_cg.h:139
bool update_psi(const ct::Tensor &pphi, const ct::Tensor &cg, const ct::Tensor &scg, const double &ethreshold, Real &cg_norm, Real &theta, Real &eigen, ct::Tensor &phi_m, ct::Tensor &sphi, ct::Tensor &hphi)
Definition diago_cg.cpp:393
int nproc_in_pool_
number of processors in a node
Definition diago_cg.h:76
#define T
Definition exp.cpp:237
Definition diag_comm_info.h:9
Definition exx_lip.h:23
T type
Definition macros.h:8
Definition math_kernel_op.h:167
Definition tensor_types.h:113
This file contains the definition of the DataType enum class.
int nproc_in_pool
Definition pw_test.cpp:12