ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
setup_psi_pw.h
Go to the documentation of this file.
1#ifndef SETUP_PSI_PW_H
2#define SETUP_PSI_PW_H
3
6#include "source_cell/klist.h"
13
15{
16 public:
17
20
21 //------------
22 // public types
23 //------------
24
25 // Precision type: 0 = float, 1 = double, 2 = complex<float>, 3 = complex<double>
26 enum class PrecisionType {
27 Float = 0,
28 Double = 1,
29 ComplexFloat = 2,
31 };
32
33 //------------
34 // variables
35 // psi_cpu, complex<double> on cpu
36 //------------
37
38 // This member was originally named psi;
39 // in the PW code it is called psi_cpu.
40 psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi_cpu = nullptr;
41
42 // psi_initializer controller
44
45 //------------
46 // functions
47 //------------
48
49 void before_runner(
50 const UnitCell &ucell,
51 const K_Vectors &kv,
52 const Structure_Factor &sf,
53 const ModulePW::PW_Basis_K &pw_wfc,
54 const pseudopot_cell_vnl &ppcell,
55 const Input_para &inp);
56
57 void init(hamilt::HamiltBase* p_hamilt);
58
59 void update_psi_d();
60
61 // Transfer data from device to host in pw basis
62 void copy_d2h();
63
64 void clean();
65
66 //------------
67 // accessor functions
68 //------------
69
70 // Get basic information, forwarded to psi_cpu (no type conversion needed); psi_cpu must not be null when these are called
71 int get_nbands() const { return this->psi_cpu->get_nbands(); }
72 int get_nk() const { return this->psi_cpu->get_nk(); }
73 int get_nbasis() const { return this->psi_cpu->get_nbasis(); }
74 size_t size() const { return this->psi_cpu->size(); }
75
76 // Get runtime type information
79
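80 // Get psi_t as a psi::Psi<T, Device> pointer (template version, for backward compatibility); psi_t is stored as void*, so T and Device must match the type actually stored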
81 template <typename T, typename Device>
83
84 template <typename T, typename Device>
85 const psi::Psi<T, Device>* get_psi_t() const { return static_cast<const psi::Psi<T, Device>*>(psi_t); }
86
87 // Get psi_d as a psi::Psi<std::complex<double>, Device> pointer (template version, for backward compatibility); T does not affect the returned type
88 template <typename T, typename Device>
90 return static_cast<psi::Psi<std::complex<double>, Device>*>(psi_d);
91 }
92
93 template <typename T, typename Device>
94 const psi::Psi<std::complex<double>, Device>* get_psi_d() const {
95 return static_cast<const psi::Psi<std::complex<double>, Device>*>(psi_d);
96 }
97
98 private:
99
100 //------------
101 // private variables
102 //------------
103
104 // This member was originally named kspw_psi.
105 // On CPU, kspw_psi is the same as psi; otherwise, kspw_psi holds a separate copy.
106 void* psi_t = nullptr; // Use void* to store pointer, runtime type information records actual type
107
108 // This member was originally named __kspw_psi.
109 void* psi_d = nullptr; // Use void* to store pointer, runtime type information records actual type
110
111 bool already_initpsi = false;
112
113 //------------
114 // runtime type information
115 //------------
118
119 //------------
120 // private functions
121 //------------
122
123 template <typename T, typename Device>
125 const UnitCell &ucell,
126 const K_Vectors &kv,
127 const Structure_Factor &sf,
128 const ModulePW::PW_Basis_K &pw_wfc,
129 const pseudopot_cell_vnl &ppcell,
130 const Input_para &inp);
131
132 template <typename T, typename Device>
133 void init_impl(hamilt::Hamilt<T, Device>* p_hamilt);
134
135 template <typename T, typename Device>
136 void update_psi_d_impl();
137
138 template <typename T, typename Device>
139 void clean_impl();
140
141 template <typename T, typename Device>
142 void copy_d2h_impl();
143
144 template <typename T, typename Device>
145 void castmem_d2h_impl(std::complex<double>* dst, const std::complex<double>* src, const size_t size);
146
147 template <typename T, typename Device>
148 void castmem_d2h_impl(std::complex<double>* dst, const std::complex<float>* src, const size_t size);
149
150};
151
152
153#endif
Definition klist.h:12
Special pw_basis class. It includes different k-points.
Definition pw_basis_k.h:56
Definition setup_psi_pw.h:15
psi::PSIPrepareBase * p_psi_init
Definition setup_psi_pw.h:43
~Setup_Psi_pw()
Definition setup_psi_pw.cpp:6
void * psi_t
Definition setup_psi_pw.h:106
void clean()
Definition setup_psi_pw.cpp:233
void update_psi_d_impl()
Definition setup_psi_pw.cpp:84
void update_psi_d()
Definition setup_psi_pw.cpp:99
psi::Psi< std::complex< double >, base_device::DEVICE_CPU > * psi_cpu
Definition setup_psi_pw.h:40
PrecisionType
Definition setup_psi_pw.h:26
bool already_initpsi
Definition setup_psi_pw.h:111
void before_runner(const UnitCell &ucell, const K_Vectors &kv, const Structure_Factor &sf, const ModulePW::PW_Basis_K &pw_wfc, const pseudopot_cell_vnl &ppcell, const Input_para &inp)
Definition setup_psi_pw.cpp:49
int get_nbasis() const
Definition setup_psi_pw.h:73
void clean_impl()
Definition setup_psi_pw.cpp:218
psi::Psi< std::complex< double >, Device > * get_psi_d()
Definition setup_psi_pw.h:89
void init(hamilt::HamiltBase *p_hamilt)
Definition setup_psi_pw.cpp:138
base_device::AbacusDevice_t device_type_
Definition setup_psi_pw.h:116
size_t size() const
Definition setup_psi_pw.h:74
void * psi_d
Definition setup_psi_pw.h:109
psi::Psi< T, Device > * get_psi_t()
Definition setup_psi_pw.h:82
Setup_Psi_pw()
Definition setup_psi_pw.cpp:4
void copy_d2h()
Definition setup_psi_pw.cpp:186
PrecisionType precision_type_
Definition setup_psi_pw.h:117
const psi::Psi< T, Device > * get_psi_t() const
Definition setup_psi_pw.h:85
PrecisionType get_precision_type() const
Definition setup_psi_pw.h:78
void init_impl(hamilt::Hamilt< T, Device > *p_hamilt)
Definition setup_psi_pw.cpp:128
int get_nk() const
Definition setup_psi_pw.h:72
void castmem_d2h_impl(std::complex< double > *dst, const std::complex< double > *src, const size_t size)
Definition setup_psi_pw.cpp:206
void copy_d2h_impl()
Definition setup_psi_pw.cpp:178
base_device::AbacusDevice_t get_device_type() const
Definition setup_psi_pw.h:77
const psi::Psi< std::complex< double >, Device > * get_psi_d() const
Definition setup_psi_pw.h:94
void before_runner_impl(const UnitCell &ucell, const K_Vectors &kv, const Structure_Factor &sf, const ModulePW::PW_Basis_K &pw_wfc, const pseudopot_cell_vnl &ppcell, const Input_para &inp)
Definition setup_psi_pw.cpp:9
int get_nbands() const
Definition setup_psi_pw.h:71
Definition structure_factor.h:10
Definition unitcell.h:15
Base class for Hamiltonian.
Definition hamilt_base.h:17
Definition hamilt.h:17
Definition vnl_pw.h:21
Base class for PSIPrepare without template parameters.
Definition psi_prepare_base.h:15
Definition psi.h:37
const int & get_nbands() const
Definition psi.cpp:341
const int & get_nk() const
Definition psi.cpp:335
size_t size() const
Definition psi.cpp:353
const int & get_nbasis() const
Definition psi.cpp:347
AbacusDevice_t
Definition types.h:12
@ CpuDevice
Definition types.h:14
Definition input_parameter.h:12