ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
diago_cg.h
Go to the documentation of this file.
1#ifndef MODULE_HSOLVER_DIAGO_CG_H_
2#define MODULE_HSOLVER_DIAGO_CG_H_
3
4#include <functional>
5
8
9#include <ATen/core/tensor.h>
11
12namespace hsolver {
13
// DiagoCG: a `final` class template parameterized on the element type T and a
// Device tag (default: base_device::DEVICE_CPU). Only declarations are shown
// here; the member index of this listing places the definitions in diago_cg.cpp.
// NOTE(review): the leading numbers are the listing's own line numbers and several
// are skipped (e.g. 22, 46, 63-69, 79-80, 124), so some declarations may be
// absent from this view. The member index mentions ct_Device (line 22),
// pw_diag_thr_ (64), pw_diag_nmax_ (66), nproc_in_pool_ (68) and
// subspace_func_ (80) — confirm against the real header.
14template <typename T, typename Device = base_device::DEVICE_CPU>
15class DiagoCG final
16{
17 // Members declared before the first access specifier are private (the class default).
18 // Note: GetTypeReal<T>::type
19 // yields T itself when T is a real type (float, double);
20 // otherwise it yields the real type underlying T (T being std::complex<float> or std::complex<double>).
21 using Real = typename GetTypeReal<T>::type;
23 public:
// Func: callable taking a read-only tensor and a writable tensor, returning nothing.
// Used below as the type of the hPsi and sPsi callbacks.
24 using Func = std::function<void(const ct::Tensor&, ct::Tensor&)>;
// SubspaceFunc: same shape as Func plus a trailing bool flag.
// NOTE(review): the meaning of the bool is not visible in this header — confirm in diago_cg.cpp.
25 using SubspaceFunc = std::function<void(const ct::Tensor&, ct::Tensor&, const bool)>;
26 // The constructor needs:
27 // 1. a temporary mock of the Hamiltonian "Hamilt_PW"
28 // 2. a precondition pointer that points to the precondition array.
// NOTE(review): neither constructor below takes a Hamiltonian or a preconditioner
// pointer (prec is passed to diag() instead); the note above looks stale — confirm.
29 DiagoCG(const std::string& basis_type, const std::string& calculation);
// Extended constructor: additionally receives the subspace switch and callback,
// a threshold of type Real, an iteration limit and a processor count.
// NOTE(review): presumably these initialize the like-named members listed in the
// member index (need_subspace_, subspace_func_, pw_diag_thr_, pw_diag_nmax_,
// nproc_in_pool_) — verify in diago_cg.cpp.
30 DiagoCG(
31 const std::string& basis_type,
32 const std::string& calculation,
33 const bool& need_subspace,
34 const SubspaceFunc& subspace_func,
35 const Real& pw_diag_thr,
36 const int& pw_diag_nmax,
37 const int& nproc_in_pool);
38
39 ~DiagoCG();
40
41 // virtual void init(){};
42 // refactor hpsi_info
43 // This is the diag() entry point for the CG method.
// Takes the hPsi and sPsi callbacks, a writable `eigen` tensor, a per-call vector
// of double thresholds (ethr_band) and an optional `prec` tensor that defaults to
// an empty-brace-initialized ct::Tensor.
// NOTE(review): listing line 46 is missing here; the member index shows diag()
// with a `ct::Tensor &psi` parameter between spsi_func and eigen — confirm.
44 void diag(const Func& hpsi_func,
45 const Func& spsi_func,
47 ct::Tensor& eigen,
48 const std::vector<double>& ethr_band,
49 const ct::Tensor& prec = {});
50
51 private:
// Pointer to the Device tag type; value-initialized, i.e. null by default.
52 Device * ctx_ = {};
// NOTE(review): presumably the count of not-yet-converged bands — name-based, confirm.
55 int notconv_ = 0;
// NOTE(review): presumably the number of bands (rows of psi) — name-based, confirm.
58 int n_band_ = 0;
// Column size of the input psi matrix (per the member index).
60 int n_basis_ = 0;
// Average iteration steps for CG diagonalization (per the member index).
62 int avg_iter_ = 0;
// Basis type of psi (per the member index).
70 std::string basis_type_ = {};
// Calculation type of ABACUS (per the member index).
72 std::string calculation_ = {};
73
// Defaults to false; the extended constructor takes a matching need_subspace argument.
74 bool need_subspace_ = false;
// Function object that performs the hPsi calculation (per the member index); empty by default.
76 Func hpsi_func_ = nullptr;
// Function object that performs the sPsi calculation (per the member index); empty by default.
78 Func spsi_func_ = nullptr;
81
// Private helpers of the CG iteration. Only their signatures are visible here:
// const-reference parameters are read-only, non-const references may be written.
// NOTE(review): the one-line purposes below are inferred from names only —
// verify against the definitions in diago_cg.cpp.

// Presumably computes the (preconditioned) gradient into `grad`.
82 void calc_grad(
83 const ct::Tensor& prec,
84 ct::Tensor& grad,
85 ct::Tensor& hphi,
86 ct::Tensor& sphi,
87 ct::Tensor& pphi);
88
// Presumably orthogonalizes `grad` against the bands of `psi` selected via index m.
89 void orth_grad(
90 const ct::Tensor& psi,
91 const int& m,
92 ct::Tensor& grad,
93 ct::Tensor& scg,
94 ct::Tensor& lagrange);
95
// Presumably updates the conjugate direction `cg`; gg_last and g0 are in-out state.
96 void calc_gamma_cg(
97 const int& iter,
98 const Real& cg_norm,
99 const Real& theta,
100 const ct::Tensor& prec,
101 const ct::Tensor& scg,
102 const ct::Tensor& grad,
103 const ct::Tensor& phi_m,
104 Real& gg_last,
105 ct::Tensor& g0,
106 ct::Tensor& cg);
107
// Returns bool. NOTE(review): presumably true once the band meets `ethreshold` — confirm.
// Note that `ethreshold` is a double while the other scalars use Real.
108 bool update_psi(
109 const ct::Tensor& pphi,
110 const ct::Tensor& cg,
111 const ct::Tensor& scg,
112 const double& ethreshold,
113 Real &cg_norm,
114 Real &theta,
115 Real &eigen,
116 ct::Tensor& phi_m,
117 ct::Tensor& sphi,
118 ct::Tensor& hphi);
119
// NOTE(review): name suggests a Schmidt orthogonalization of phi_m — confirm.
120 void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
121
122 // Used in diag(); templated so that Hamilt can be replaced with Hamilt_PW.
// NOTE(review): listing line 124 is missing here; the member index shows
// diag_once() with a `ct::Tensor &psi` parameter between prec and eigen — confirm.
123 void diag_once(const ct::Tensor& prec,
125 ct::Tensor& eigen,
126 const std::vector<double>& ethr_band);
127
// Const member returning bool. NOTE(review): presumably decides from the retry count
// and the unconverged count whether to stop iterating — confirm in diago_cg.cpp.
128 bool test_exit_cond(const int& ntry, const int& notconv) const;
129
// Three pointers to const T, all null by default.
// NOTE(review): presumably they point at the constants 1, 0 and -1 of type T — confirm.
131 const T * one_ = nullptr, * zero_ = nullptr, * neg_one_ = nullptr;
132};
133
134} // namespace hsolver
135
136#endif // MODULE_HSOLVER_DIAGO_CG_H_
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Definition diago_cg.h:16
Func hpsi_func_
A function object that performs the hPsi calculation.
Definition diago_cg.h:76
bool need_subspace_
Definition diago_cg.h:74
Real pw_diag_thr_
threshold for cg diagonalization
Definition diago_cg.h:64
typename ct::PsiToContainer< Device >::type ct_Device
Definition diago_cg.h:22
void calc_grad(const ct::Tensor &prec, ct::Tensor &grad, ct::Tensor &hphi, ct::Tensor &sphi, ct::Tensor &pphi)
Definition diago_cg.cpp:215
void diag_once(const ct::Tensor &prec, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band)
Definition diago_cg.cpp:57
void schmit_orth(const int &m, const ct::Tensor &psi, const ct::Tensor &sphi, ct::Tensor &phi_m)
Definition diago_cg.cpp:476
Func spsi_func_
A function object that performs the sPsi calculation.
Definition diago_cg.h:78
std::function< void(const ct::Tensor &, ct::Tensor &)> Func
Definition diago_cg.h:24
int avg_iter_
average iteration steps for cg diagonalization
Definition diago_cg.h:62
int pw_diag_nmax_
maximum iteration steps for cg diagonalization
Definition diago_cg.h:66
SubspaceFunc subspace_func_
A function object that performs the subspace calculation.
Definition diago_cg.h:80
typename GetTypeReal< T >::type Real
Definition diago_cg.h:21
const T * one_
Definition diago_cg.h:131
std::function< void(const ct::Tensor &, ct::Tensor &, const bool)> SubspaceFunc
Definition diago_cg.h:25
bool test_exit_cond(const int &ntry, const int &notconv) const
Definition diago_cg.cpp:564
const T * neg_one_
Definition diago_cg.h:131
int n_basis_
column size of the input psi matrix
Definition diago_cg.h:60
~DiagoCG()
Definition diago_cg.cpp:49
int n_band_
Definition diago_cg.h:58
void calc_gamma_cg(const int &iter, const Real &cg_norm, const Real &theta, const ct::Tensor &prec, const ct::Tensor &scg, const ct::Tensor &grad, const ct::Tensor &phi_m, Real &gg_last, ct::Tensor &g0, ct::Tensor &cg)
Definition diago_cg.cpp:310
std::string calculation_
calculation type of ABACUS
Definition diago_cg.h:72
void diag(const Func &hpsi_func, const Func &spsi_func, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band, const ct::Tensor &prec={})
Definition diago_cg.cpp:578
int notconv_
Definition diago_cg.h:55
std::string basis_type_
basis_type of psi
Definition diago_cg.h:70
Device * ctx_
Definition diago_cg.h:52
void orth_grad(const ct::Tensor &psi, const int &m, ct::Tensor &grad, ct::Tensor &scg, ct::Tensor &lagrange)
Definition diago_cg.cpp:260
const T * zero_
Definition diago_cg.h:131
bool update_psi(const ct::Tensor &pphi, const ct::Tensor &cg, const ct::Tensor &scg, const double &ethreshold, Real &cg_norm, Real &theta, Real &eigen, ct::Tensor &phi_m, ct::Tensor &sphi, ct::Tensor &hphi)
Definition diago_cg.cpp:392
int nproc_in_pool_
number of processors in a pool
Definition diago_cg.h:68
#define T
Definition exp.cpp:237
Definition diag_comm_info.h:9
Definition exx_lip.h:23
T type
Definition macros.h:8
Definition math_kernel_op.h:168
Definition tensor_types.h:113
This file contains the definition of the DataType enum class.
int nproc_in_pool
Definition pw_test.cpp:12