ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
diago_cg.h
Go to the documentation of this file.
1#ifndef MODULE_HSOLVER_DIAGO_CG_H_
2#define MODULE_HSOLVER_DIAGO_CG_H_
3
#include <functional>
#include <string>
#include <vector>

// NOTE(review): the scraped listing dropped the original header lines 6-7 and 10.
// They presumably provide GetTypeReal, base_device::DEVICE_CPU and
// ct::PsiToContainer used below — restore them from the repository.
#include <ATen/core/tensor.h>
11
12namespace hsolver {
13
14template <typename T, typename Device = base_device::DEVICE_CPU>
15class DiagoCG final
16{
17 // private: accessibility within class is private by default
18 // Note GetTypeReal<T>::type will
19 // return T if T is real type(float, double),
20 // otherwise return the real type of T(complex<float>, std::complex<double>)
21 using Real = typename GetTypeReal<T>::type;
23 public:
24 using Func = std::function<void(const ct::Tensor&, ct::Tensor&)>;
25 // Constructor need:
26 // 1. temporary mock of Hamiltonian "Hamilt_PW"
27 // 2. precondition pointer should point to place of precondition array.
28 DiagoCG(const std::string& basis_type, const std::string& calculation);
29 DiagoCG(
30 const std::string& basis_type,
31 const std::string& calculation,
32 const bool& need_subspace,
33 const Func& subspace_func,
34 const Real& pw_diag_thr,
35 const int& pw_diag_nmax,
36 const int& nproc_in_pool);
37
38 ~DiagoCG();
39
40 // virtual void init(){};
41 // refactor hpsi_info
42 // this is the diag() function for CG method
43 void diag(const Func& hpsi_func,
44 const Func& spsi_func,
46 ct::Tensor& eigen,
47 const std::vector<double>& ethr_band,
48 const ct::Tensor& prec = {});
49
50 private:
51 Device * ctx_ = {};
54 int notconv_ = 0;
57 int n_band_ = 0;
59 int n_basis_ = 0;
61 int avg_iter_ = 0;
69 std::string basis_type_ = {};
71 std::string calculation_ = {};
72
73 bool need_subspace_ = false;
75 std::function<void(const ct::Tensor&, ct::Tensor&)> hpsi_func_ = nullptr;
77 std::function<void(const ct::Tensor&, ct::Tensor&)> spsi_func_ = nullptr;
79 std::function<void(const ct::Tensor&, ct::Tensor&)> subspace_func_ = nullptr;
80
81 void calc_grad(
82 const ct::Tensor& prec,
83 ct::Tensor& grad,
84 ct::Tensor& hphi,
85 ct::Tensor& sphi,
86 ct::Tensor& pphi);
87
88 void orth_grad(
89 const ct::Tensor& psi,
90 const int& m,
91 ct::Tensor& grad,
92 ct::Tensor& scg,
93 ct::Tensor& lagrange);
94
95 void calc_gamma_cg(
96 const int& iter,
97 const Real& cg_norm,
98 const Real& theta,
99 const ct::Tensor& prec,
100 const ct::Tensor& scg,
101 const ct::Tensor& grad,
102 const ct::Tensor& phi_m,
103 Real& gg_last,
104 ct::Tensor& g0,
105 ct::Tensor& cg);
106
107 bool update_psi(
108 const ct::Tensor& pphi,
109 const ct::Tensor& cg,
110 const ct::Tensor& scg,
111 const double& ethreshold,
112 Real &cg_norm,
113 Real &theta,
114 Real &eigen,
115 ct::Tensor& phi_m,
116 ct::Tensor& sphi,
117 ct::Tensor& hphi);
118
119 void schmit_orth(const int& m, const ct::Tensor& psi, const ct::Tensor& sphi, ct::Tensor& phi_m);
120
121 // used in diag() for template replace Hamilt with Hamilt_PW
122 void diag_mock(const ct::Tensor& prec,
124 ct::Tensor& eigen,
125 const std::vector<double>& ethr_band);
126
127 bool test_exit_cond(const int& ntry, const int& notconv) const;
128
130 const T * one_ = nullptr, * zero_ = nullptr, * neg_one_ = nullptr;
131};
132
133} // namespace hsolver
134
135#endif // MODULE_HSOLVER_DIAGO_CG_H_
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
Definition diago_cg.h:16
bool need_subspace_
Definition diago_cg.h:73
Real pw_diag_thr_
threshold for cg diagonalization
Definition diago_cg.h:63
typename ct::PsiToContainer< Device >::type ct_Device
Definition diago_cg.h:22
void calc_grad(const ct::Tensor &prec, ct::Tensor &grad, ct::Tensor &hphi, ct::Tensor &sphi, ct::Tensor &pphi)
Definition diago_cg.cpp:215
void schmit_orth(const int &m, const ct::Tensor &psi, const ct::Tensor &sphi, ct::Tensor &phi_m)
Definition diago_cg.cpp:476
std::function< void(const ct::Tensor &, ct::Tensor &)> spsi_func_
A function object that performs the sPsi calculation.
Definition diago_cg.h:77
std::function< void(const ct::Tensor &, ct::Tensor &)> hpsi_func_
A function object that performs the hPsi calculation.
Definition diago_cg.h:75
std::function< void(const ct::Tensor &, ct::Tensor &)> Func
Definition diago_cg.h:24
int avg_iter_
average iteration steps for cg diagonalization
Definition diago_cg.h:61
int pw_diag_nmax_
maximum iteration steps for cg diagonalization
Definition diago_cg.h:65
typename GetTypeReal< T >::type Real
Definition diago_cg.h:21
const T * one_
Definition diago_cg.h:130
bool test_exit_cond(const int &ntry, const int &notconv) const
Definition diago_cg.cpp:564
const T * neg_one_
Definition diago_cg.h:130
int n_basis_
column size of the input psi matrix
Definition diago_cg.h:59
~DiagoCG()
Definition diago_cg.cpp:49
int n_band_
Definition diago_cg.h:57
void calc_gamma_cg(const int &iter, const Real &cg_norm, const Real &theta, const ct::Tensor &prec, const ct::Tensor &scg, const ct::Tensor &grad, const ct::Tensor &phi_m, Real &gg_last, ct::Tensor &g0, ct::Tensor &cg)
Definition diago_cg.cpp:310
std::function< void(const ct::Tensor &, ct::Tensor &)> subspace_func_
A function object that performs the subspace calculation.
Definition diago_cg.h:79
std::string calculation_
calculation type of ABACUS
Definition diago_cg.h:71
void diag(const Func &hpsi_func, const Func &spsi_func, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band, const ct::Tensor &prec={})
Definition diago_cg.cpp:578
int notconv_
Definition diago_cg.h:54
std::string basis_type_
basis type of psi
Definition diago_cg.h:69
Device * ctx_
Definition diago_cg.h:51
void diag_mock(const ct::Tensor &prec, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band)
Definition diago_cg.cpp:57
void orth_grad(const ct::Tensor &psi, const int &m, ct::Tensor &grad, ct::Tensor &scg, ct::Tensor &lagrange)
Definition diago_cg.cpp:260
const T * zero_
Definition diago_cg.h:130
bool update_psi(const ct::Tensor &pphi, const ct::Tensor &cg, const ct::Tensor &scg, const double &ethreshold, Real &cg_norm, Real &theta, Real &eigen, ct::Tensor &phi_m, ct::Tensor &sphi, ct::Tensor &hphi)
Definition diago_cg.cpp:392
int nproc_in_pool_
number of processes in the pool
Definition diago_cg.h:67
#define T
Definition exp.cpp:237
Definition diag_comm_info.h:9
Definition exx_lip.h:23
T type
Definition macros.h:8
Definition math_kernel_op.h:168
Definition tensor_types.h:113
This file contains the definition of the DataType enum class.
int nproc_in_pool
Definition pw_test.cpp:12