ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
esolver_ks_lcao_tddft.h
Go to the documentation of this file.
1#ifndef ESOLVER_KS_LCAO_TDDFT_H
2#define ESOLVER_KS_LCAO_TDDFT_H
3#include "esolver_ks.h"
4#include "esolver_ks_lcao.h"
7#include "source_psi/psi.h"
10
11namespace ModuleESolver
12{
13//------------------------ MPI gathering and distributing functions ------------------------//
// Holder for a matrix collected from all processes onto the root process.
//
// Ownership note: `p` and `desc` point to arrays allocated with new[]. Because
// their type is std::shared_ptr<T> / std::shared_ptr<int> (not the array forms),
// the default deleter would run `delete` instead of `delete[]`, which is
// undefined behaviour. Whoever resets them with new[] memory must also pass an
// array deleter, e.g. `p.reset(new T[n], std::default_delete<T[]>())`.
template <typename T>
struct Matrix_g
{
    std::shared_ptr<T> p;       // Matrix elements (row * col entries)
    std::size_t row = 0;        // Number of rows; 0 until the matrix is filled
    std::size_t col = 0;        // Number of columns; 0 until the matrix is filled
    std::shared_ptr<int> desc;  // 9-entry ScaLAPACK array descriptor of the matrix
};
23
24// Collect matrices from all processes to root process
25template <typename T>
26void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock<T>& mat_l, Matrix_g<T>& mat_g)
27{
28 const int* desca = mat_l.desc; // Obtain the descriptor of the local matrix
29 int ctxt = desca[1]; // BLACS context
30 int nrows = desca[2]; // Global matrix row number
31 int ncols = desca[3]; // Global matrix column number
32
33 if (myid == root_proc)
34 {
35 mat_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
36 }
37 else
38 {
39 mat_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
40 }
41
42 // Set the descriptor of the global matrix
43 mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
44 mat_g.row = nrows;
45 mat_g.col = ncols;
46
47 // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
48 Cpxgemr2d(nrows, ncols, mat_l.p, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
49}
50//------------------------ MPI gathering and distributing functions ------------------------//
51
// ESolver_KS_LCAO_TDDFT: energy solver for TDDFT calculations in the LCAO basis.
// It derives from ESolver_KS_LCAO, instantiated with std::complex<double> as the
// first template argument and TR as the second, and overrides the runner / SCF
// hooks declared below. Device defaults to base_device::DEVICE_CPU.
//
// NOTE(review): this listing is a rendered copy with fused line numbers and
// dropped source lines (56, 58, 74-75, 77, 80, 85, 89, 92-93, 97). According to
// the cross-reference entries, the class also declares a constructor and a
// destructor (defined in esolver_ks_lcao_tddft.cpp), psi_laststep (line 75,
// "wave functions of last time step") and velocity_mat (line 93,
// Velocity_op<TR>*, "Velocity matrix for calculating current"). These are not
// visible here; restore them from the repository before relying on this copy.
52template <typename TR, typename Device = base_device::DEVICE_CPU>
53class ESolver_KS_LCAO_TDDFT : public ESolver_KS_LCAO<std::complex<double>, TR>
54{
55 public:
57
59
// Initialization of the first-principles energy solver (overrides the base hook).
60 void before_all_runners(UnitCell& ucell, const Input_para& inp) override;
61
62 protected:
// Runs the energy solver for step `istep` on `cell`.
63 virtual void runner(UnitCell& cell, const int istep) override;
64
// Per-iteration hook called with step `istep`, iteration `iter` and threshold `ethr`.
// NOTE(review): from its name it presumably goes from the Hamiltonian to the
// charge density for a single iteration — confirm in esolver_ks_lcao_tddft.cpp.
65 virtual void hamilt2rho_single(UnitCell& ucell, const int istep, const int iter, const double ethr) override;
66
// Updates the potential; receives the current convergence flag.
// Marked temporary upstream: it should be replaced by a function in the Hamilt class.
67 virtual void update_pot(UnitCell& ucell, const int istep, const int iter, const bool conv_esolver) override;
68
// Work done after hamilt2rho in each iteration; `iter` and `conv_esolver` are
// passed by non-const reference, so this hook may modify them.
69 virtual void iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) override;
70
// Work done after the SCF iterations, whether converged or stopped at the maximum iteration count.
71 virtual void after_scf(UnitCell& ucell, const int istep, const bool conv_esolver) override;
72
// Prints information about the current step (defined in esolver_ks_lcao_tddft.cpp).
73 void print_step();
76
// Hamiltonian of the last time step.
// NOTE(review): raw pointer-to-pointer; allocation and release are not visible
// in this header — presumably handled in the .cpp; confirm.
78 std::complex<double>** Hk_laststep = nullptr;
79
// Overlap matrix of the last time step (same ownership caveat as Hk_laststep).
81 std::complex<double>** Sk_laststep = nullptr;
82
// Constant set to 1; the meaning of the value is not documented here — TODO confirm.
83 const int td_htype = 1;
84
// Controls heterogeneous computing of the TDDFT solver (default: off).
86 bool use_tensor = false;
// Default off; presumably selects a LAPACK-based code path — confirm.
87 bool use_lapack = false;
88
// Total number of steps for evolving the wave function; initialized to -1.
90 int totstep = -1;
91
94
// Pointer to a TD_info object (td_info.h); ownership is not visible here.
95 TD_info* td_p = nullptr;
96
// Initialized to false; the upstream brief only says "doubt".
// NOTE(review): presumably records whether a restart has been performed — confirm.
98 bool restart_done = false;
99
100 private:
// Private helper taking the unit cell (defined in esolver_ks_lcao_tddft.cpp).
// NOTE(review): the name suggests it weights the density matrix and builds rho — confirm.
101 void weight_dm_rho(const UnitCell& ucell);
102};
103
104} // namespace ModuleESolver
105#endif
106
Definition esolver_ks_lcao_tddft.h:54
std::complex< double > ** Sk_laststep
Overlap matrix of last time step.
Definition esolver_ks_lcao_tddft.h:81
~ESolver_KS_LCAO_TDDFT()
Definition esolver_ks_lcao_tddft.cpp:58
virtual void hamilt2rho_single(UnitCell &ucell, const int istep, const int iter, const double ethr) override
Definition esolver_ks_lcao_tddft.cpp:235
bool restart_done
doubt
Definition esolver_ks_lcao_tddft.h:98
virtual void runner(UnitCell &cell, const int istep) override
run energy solver
Definition esolver_ks_lcao_tddft.cpp:114
std::complex< double > ** Hk_laststep
Hamiltonian of last time step.
Definition esolver_ks_lcao_tddft.h:78
void print_step()
Definition esolver_ks_lcao_tddft.cpp:228
psi::Psi< std::complex< double > > * psi_laststep
wave functions of last time step
Definition esolver_ks_lcao_tddft.h:75
void weight_dm_rho(const UnitCell &ucell)
Definition esolver_ks_lcao_tddft.cpp:536
const int td_htype
Definition esolver_ks_lcao_tddft.h:83
virtual void iter_finish(UnitCell &ucell, const int istep, int &iter, bool &conv_esolver) override
Operations to perform after the hamilt2rho function in each iteration of the SCF loop.
Definition esolver_ks_lcao_tddft.cpp:311
Velocity_op< TR > * velocity_mat
Velocity matrix for calculating current.
Definition esolver_ks_lcao_tddft.h:93
ESolver_KS_LCAO_TDDFT()
Definition esolver_ks_lcao_tddft.cpp:43
virtual void after_scf(UnitCell &ucell, const int istep, const bool conv_esolver) override
Something to do after SCF iterations when SCF is converged or comes to the max iter step.
Definition esolver_ks_lcao_tddft.cpp:469
virtual void update_pot(UnitCell &ucell, const int istep, const int iter, const bool conv_esolver) override
<Temporary> It should be replaced by a function in Hamilt Class
Definition esolver_ks_lcao_tddft.cpp:340
bool use_tensor
Control heterogeneous computing of the TDDFT solver.
Definition esolver_ks_lcao_tddft.h:86
TD_info * td_p
Definition esolver_ks_lcao_tddft.h:95
bool use_lapack
Definition esolver_ks_lcao_tddft.h:87
void before_all_runners(UnitCell &ucell, const Input_para &inp) override
Initialization of the first-principles energy solver.
Definition esolver_ks_lcao_tddft.cpp:85
int totstep
Total steps for evolving the wave function.
Definition esolver_ks_lcao_tddft.h:90
Definition esolver_ks_lcao.h:50
bool conv_esolver
Definition esolver.h:44
Definition td_info.h:10
Definition unitcell.h:16
Definition velocity_op.h:15
Definition psi.h:37
#define T
Definition exp.cpp:237
plane wave basis
Definition opt_test_tools.cpp:93
void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock< T > &mat_l, Matrix_g< T > &mat_g)
Definition esolver_ks_lcao_tddft.h:26
std::enable_if< block2d_data_type< T >::value, void >::type Cpxgemr2d(int M, int N, T *A, int IA, int JA, int *DESCA, T *B, int IB, int JB, int *DESCB, int ICTXT)
Definition scalapack_connector.h:186
Definition input_parameter.h:12
Definition esolver_ks_lcao_tddft.h:17
std::shared_ptr< T > p
Definition esolver_ks_lcao_tddft.h:18
size_t row
Definition esolver_ks_lcao_tddft.h:19
size_t col
Definition esolver_ks_lcao_tddft.h:20
std::shared_ptr< int > desc
Definition esolver_ks_lcao_tddft.h:21
Definition matrixblock.h:9
T * p
Definition matrixblock.h:12
const int * desc
Definition matrixblock.h:15