1#ifndef MODULE_HSOLVER_BPCG_KERNEL_H
2#define MODULE_HSOLVER_BPCG_KERNEL_H
8template <
typename T,
typename Device>
29 const int& n_basis_max,
33template <
typename T,
typename Device>
58 const int& n_basis_max,
62template <
typename T,
typename Device>
71 const Real* eigenvalues);
74template <
typename T,
typename Device>
81 const Real* precondition,
82 const Real* eigenvalues);
85template <
typename T,
typename Device>
92 Real* psi_norm =
nullptr);
95#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
100 void operator()(
T *grad_out,
T *hgrad_out,
T *psi_out,
T *hpsi_out,
101 const int &n_basis,
const int &n_basis_max,
106struct calc_grad_with_block_op<
T,
base_device::DEVICE_GPU> {
109 T *psi_out,
T *hpsi_out,
T *grad_out,
T *grad_old_out,
110 const int &n_basis,
const int &n_basis_max,
115struct apply_eigenvalues_op<
T,
base_device::DEVICE_GPU> {
122 const Real* eigenvalues);
132 const Real* precondition,
133 const Real* eigenvalues);
143 Real* psi_norm =
nullptr);
#define T
Definition exp.cpp:237
Definition diag_comm_info.h:9
T type
Definition macros.h:8
Definition bpcg_kernel_op.h:64
typename GetTypeReal< T >::type Real
Definition bpcg_kernel_op.h:65
void operator()(const int &nbase, const int &nbase_x, const int ¬conv, T *result, const T *vectors, const Real *eigenvalues)
Definition bpcg_kernel_op.h:35
void operator()(const Real *prec_in, Real *err_out, Real *beta_out, T *psi_out, T *hpsi_out, T *grad_out, T *grad_old_out, const int &n_basis, const int &n_basis_max, const int &n_band)
typename GetTypeReal< T >::type Real
dot_real_op computes the dot product of the given complex arrays(treated as float arrays)....
Definition bpcg_kernel_op.h:49
Definition bpcg_kernel_op.h:10
void operator()(T *grad_out, T *hgrad_out, T *psi_out, T *hpsi_out, const int &n_basis, const int &n_basis_max, const int &n_band)
dot_real_op computes the dot product of the given complex arrays(treated as float arrays)....
Definition bpcg_kernel_op.h:86
typename GetTypeReal< T >::type Real
Definition bpcg_kernel_op.h:87
void operator()(const int &dim, T *psi_iter, const int &nbase, const int ¬conv, Real *psi_norm=nullptr)
Definition bpcg_kernel_op.h:75
void operator()(const int &dim, T *psi_iter, const int &nbase, const int ¬conv, const Real *precondition, const Real *eigenvalues)
typename GetTypeReal< T >::type Real
Definition bpcg_kernel_op.h:76