1#ifndef MODULE_HSOLVER_DIAGO_CG_H_
2#define MODULE_HSOLVER_DIAGO_CG_H_
15template <
typename T,
typename Device = base_device::DEVICE_CPU>
25 using HPsiFunc = std::function<void(
T*,
T*,
const int,
const int)>;
26 using SPsiFunc = std::function<void(
T*,
T*,
const int,
const int)>;
27 using SubspaceFunc = std::function<void(
T*,
T*,
const int,
const int,
const bool)>;
31 DiagoCG(
const std::string& basis_type,
const std::string& calculation);
33 const std::string& basis_type,
34 const std::string& calculation,
35 const bool& need_subspace,
37 const Real& pw_diag_thr,
38 const int& pw_diag_nmax,
54 const std::vector<double>& ethr_band,
55 const Real* prec =
nullptr);
120 const double& ethreshold,
134 const std::vector<double>& ethr_band);
A multi-dimensional array of elements of a single data type.
Definition tensor.h:32
double diag(const HPsiFunc &hpsi_func, const SPsiFunc &spsi_func, const int ld_psi, const int nband, const int dim, T *psi_in, Real *eigenvalue_in, const std::vector< double > &ethr_band, const Real *prec=nullptr)
Definition diago_cg.cpp:582
bool need_subspace_
Definition diago_cg.h:82
Real pw_diag_thr_
threshold for cg diagonalization
Definition diago_cg.h:72
typename ct::PsiToContainer< Device >::type ct_Device
Definition diago_cg.h:23
HPsiFunc hpsi_func_
A function object that performs the hPsi calculation.
Definition diago_cg.h:84
void calc_grad(const ct::Tensor &prec, ct::Tensor &grad, ct::Tensor &hphi, ct::Tensor &sphi, ct::Tensor &pphi)
Definition diago_cg.cpp:216
void diag_once(const ct::Tensor &prec, ct::Tensor &psi, ct::Tensor &eigen, const std::vector< double > &ethr_band)
Definition diago_cg.cpp:57
void schmit_orth(const int &m, const ct::Tensor &psi, const ct::Tensor &sphi, ct::Tensor &phi_m)
Definition diago_cg.cpp:477
std::function< void(T *, T *, const int, const int)> HPsiFunc
Definition diago_cg.h:25
int pw_diag_nmax_
maximum iteration steps for cg diagonalization
Definition diago_cg.h:74
SubspaceFunc subspace_func_
A function object that performs the subspace calculation.
Definition diago_cg.h:88
typename GetTypeReal< T >::type Real
Definition diago_cg.h:22
const T * one_
Definition diago_cg.h:139
bool test_exit_cond(const int &ntry, const int &notconv) const
Definition diago_cg.cpp:568
std::function< void(T *, T *, const int, const int)> SPsiFunc
Definition diago_cg.h:26
const T * neg_one_
Definition diago_cg.h:139
SPsiFunc spsi_func_
A function object that performs the sPsi calculation.
Definition diago_cg.h:86
int n_basis_
col size for input psi matrix
Definition diago_cg.h:66
~DiagoCG()
Definition diago_cg.cpp:49
double avg_iter_
average iteration steps for cg diagonalization
Definition diago_cg.h:68
int n_band_
Definition diago_cg.h:64
void calc_gamma_cg(const int &iter, const Real &cg_norm, const Real &theta, const ct::Tensor &prec, const ct::Tensor &scg, const ct::Tensor &grad, const ct::Tensor &phi_m, Real &gg_last, ct::Tensor &g0, ct::Tensor &cg)
Definition diago_cg.cpp:311
std::string calculation_
calculation type of ABACUS
Definition diago_cg.h:80
int notconv_
Definition diago_cg.h:61
std::function< void(T *, T *, const int, const int, const bool)> SubspaceFunc
Definition diago_cg.h:27
std::string basis_type_
basis_type of psi
Definition diago_cg.h:78
Device * ctx_
Definition diago_cg.h:58
std::vector< int > iter_band
std::vector for iter count of each band
Definition diago_cg.h:70
void orth_grad(const ct::Tensor &psi, const int &m, ct::Tensor &grad, ct::Tensor &scg, ct::Tensor &lagrange)
Definition diago_cg.cpp:261
const T * zero_
Definition diago_cg.h:139
bool update_psi(const ct::Tensor &pphi, const ct::Tensor &cg, const ct::Tensor &scg, const double &ethreshold, Real &cg_norm, Real &theta, Real &eigen, ct::Tensor &phi_m, ct::Tensor &sphi, ct::Tensor &hphi)
Definition diago_cg.cpp:393
int nproc_in_pool_
number of processors in a node
Definition diago_cg.h:76
#define T
Definition exp.cpp:237
Definition diag_comm_info.h:9
T type
Definition macros.h:8
Definition math_kernel_op.h:167
Definition tensor_types.h:113
This file contains the definition of the DataType enum class.
int nproc_in_pool
Definition pw_test.cpp:12