ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
psi.h
Go to the documentation of this file.
1#ifndef PSI_H
2#define PSI_H
3
6
7#include <tuple>
8#include <vector>
9
10namespace psi
11{
12
/// Describes a slice of a Psi object, to be resolved by Psi::to_range().
///
/// Two index orders are supported: k-point index first, or band index first.
/// Psi::to_range() turns a Range into (pointer to the first selected element,
/// number of bands or k-points selected).
struct Range
{
    /// Index-order flag of the target Psi:
    ///   true  (1): Psi(nks, nbands, nbasis)
    ///   false (0): Psi(nbands, nks, nbasis)
    bool k_first;

    /// Target value of the first index.
    /// NOTE(review): this member is unsigned (std::size_t), so it can never be
    /// negative; a "< 0 means unused" convention cannot hold as written —
    /// confirm the actual sentinel against psi.cpp.
    std::size_t index_1;

    /// Bounds on the second index (band or k-point, depending on k_first).
    /// Presumably the first/last index of the selected span — the exact
    /// convention is fixed by the constructors / Psi::to_range() in psi.cpp.
    std::size_t range_1;
    std::size_t range_2;

    /// Short-hand constructor used when building a range for an hPsi result.
    /// Not marked explicit, so a bare size converts implicitly to a Range.
    Range(const std::size_t range_in);

    /// Fully specified constructor: index-order flag, first index, and the two
    /// bounds on the second index.
    Range(const bool k_first_in,
          const std::size_t index_1_in,
          const std::size_t range_1_in,
          const std::size_t range_2_in);
};
32
33// Container for wavefunction coefficients (psi), held behind a single raw T* buffer.
// Index order is (nk, nbands, nbasis) when k_first is true and
// (nbands, nk, nbasis) when it is false.
34// The basic operations acting on psi are defined in the Operator class, not here.
//
// Template parameters:
//   T      - element type of the coefficients.
//   Device - device tag (default base_device::DEVICE_CPU); see ctx below.
//
// Access cursor: fix_k / fix_b / fix_kb are const member functions, so the only
// data members they can reassign are the mutable ones (current_k, current_b,
// current_nbasis, psi_current, psi_bias). The one-argument operator()(ibasis)
// addresses an element at the current k and current band; the other
// accessors presumably also work relative to this cursor — confirm in psi.cpp.
// NOTE(review): because the cursor is mutable state, concurrent use of one Psi
// object from several threads is not safe without external synchronization.
//
// NOTE(review): a destructor, copy constructor and copy assignment are
// user-declared but no move constructor / move assignment is, so moving a Psi
// falls back to the copy operations.
35template <typename T, typename Device = base_device::DEVICE_CPU>
36class Psi
37{
38 public:
39 // Constructor 0: default constructor (body in psi.cpp).
40 Psi();
41
42 // Constructor 1: nk_in k-points, nbd_in bands, nbs_in basis functions;
 // k_first_in selects the index order. ngk_in presumably holds the number of
 // basis functions per k-point (cf. get_ngk).
 // NOTE(review): ngk is stored as a non-owning const int*; if it aliases
 // ngk_in's data, the caller's vector must outlive this Psi — confirm in psi.cpp.
43 Psi(const int nk_in, const int nbd_in, const int nbs_in, const std::vector<int>& ngk_in, const bool k_first_in);
44
45 // Constructor 2-1: copy constructor; initialize a new psi from the given psi_in.
46 Psi(const Psi& psi_in);
47
48 // Constructor 2-2: initialize a new psi from a psi_in with different template arguments;
49 // in this case psi_in may also live on a different device type.
50 template <typename T_in, typename Device_in = Device>
51 Psi(const Psi<T_in, Device_in>& psi_in);
52
53 // Constructor 3-1: 2D Psi version wrapping an existing buffer psi_pointer.
54 // Used in the hsolver-pw function pointer and elsewhere.
 // NOTE(review): presumably sets allocate_inside = false so the destructor does
 // not free psi_pointer — confirm in psi.cpp.
55 Psi(T* psi_pointer,
56 const int nk_in,
57 const int nbd_in,
58 const int nbs_in,
59 const int current_nbasis_in,
60 const bool k_first_in = true);
61
62 // Constructor 3-2: 2D Psi version that does not take an external buffer.
63 Psi(const int nk_in, const int nbd_in, const int nbs_in, const int current_nbasis_in, const bool k_first_in);
64
65 // Destructor that deletes the psi array manually (presumably only when allocate_inside is true).
66 ~Psi();
67
68 // Set psi values from another_pointer (size_in elements); see psi.cpp for bounds handling.
69 void set_all_psi(const T* another_pointer, const std::size_t size_in);
70
71 // Set psi values to zero (presumably every element).
72 void zero_out();
73
74 // Number of stored elements (body in psi.cpp). psi is a raw T*, not an STL container,
 // so the old inline body "return this->psi.size();" no longer applies.
75 size_t size() const;
76
77 // Copy assignment operator.
78 Psi& operator=(const Psi& psi_in);
79
80 // Allocate psi for three dimensions (nks_in, nbands_in, nbasis_in).
81 void resize(const int nks_in, const int nbands_in, const int nbasis_in);
82
83 // Get the pointer for the 1st index.
 // Const member function, yet returns a mutable T*.
84 T* get_pointer() const;
85
86 // Get the pointer for the 2nd index (iband for k_first = true, ik for k_first = false).
87 T* get_pointer(const int& ikb) const;
88
89 // Interface to get the three dimension sizes (returned by const reference to the members).
90 const int& get_nk() const;
91 const int& get_nbands() const;
92 const int& get_nbasis() const;
93
 // Move the access cursor to k-point ik, band ib, or both. These are const;
 // they can only reassign the mutable cursor members declared below.
96 void fix_k(const int ik) const;
99 void fix_b(const int ib) const;
102 void fix_kb(const int ik, const int ib) const;
103
 // Element access. The 3-index form takes (ikb1, ikb2, ibasis) in the order
 // chosen by k_first; the 2-index form takes (ikb2, ibasis), with the first
 // index presumably taken from the cursor. All overloads are const but return a
 // non-const T&, so elements can be modified through a const Psi.
107 T& operator()(const int ikb1, const int ikb2, const int ibasis) const;
111 T& operator()(const int ikb2, const int ibasis) const;
112 // Use operator "(ibasis)" to reach the target element for the current k and current band.
113 T& operator()(const int ibasis) const;
114
115 // Return the current k-point index.
116 int get_current_k() const;
117 // Return the current band index.
118 int get_current_b() const;
119 // Return the current ngk (number of basis functions) for the PW basis.
120 int get_current_nbas() const;
121
 // Number of basis functions at k-point ik_in (presumably read from ngk).
122 const int& get_ngk(const int ik_in) const;
123
 // Raw, non-owning pointer to the per-k basis sizes (the ngk member).
124 const int* get_ngk_pointer() const;
125
126 // Return k_first (index-order flag).
127 const bool& get_k_first() const;
128
129 // Return the device context of psi (presumably ctx).
130 const Device* get_device() const;
131
132 // Return psi_bias (offset of psi_current from psi).
133 const int& get_psi_bias() const;
134
 // Current ngk as a const reference (see also get_current_nbas).
135 const int& get_current_ngk() const;
136
137 // Resolve a Range: returns (pointer to the beginning, number of bands or k-points).
138 std::tuple<const T*, int> to_range(const Range& range) const;
139
 // Presumably the number of spinor components (npol); body in psi.cpp.
140 int get_npol() const;
141
142 private:
143 T* psi = nullptr; // raw coefficient buffer; deliberately avoids C++ STL containers
144
145 Device* ctx = {}; // a context identifier for obtaining the device variable (value-initialized to nullptr)
146
147 // dimensions
148 int nk = 1; // number of k points
149 int nbands = 1; // number of bands
150 int nbasis = 1; // number of basis functions
151
 // Access cursor; mutable so the const fix_* functions can update it.
152 mutable int current_k = 0; // current k point
153 mutable int current_b = 0; // current band index
154 mutable int current_nbasis = 1; // current number of basis functions of current_k
155
156 // current pointer for getting the psi
157 mutable T* psi_current = nullptr;
158 // invariant: psi_current = psi + psi_bias;
159 mutable int psi_bias = 0;
160
 // Non-owning pointer to per-k basis sizes.
161 const int* ngk = nullptr;
162
 // true: index order (nk, nbands, nbasis); false: (nbands, nk, nbasis).
163 bool k_first = true;
164
 // whether psi is allocated inside the Psi class
165 bool allocate_inside = true;
166
167#ifdef __DSP
168 using delete_memory_op = base_device::memory::delete_memory_op_mt<T, Device>;
169 using resize_memory_op = base_device::memory::resize_memory_op_mt<T, Device>;
170#else
 // NOTE(review): original header lines 171-172 are missing from this extraction
 // (presumably the non-__DSP definitions of delete_memory_op / resize_memory_op).
173#endif
 // NOTE(review): original header lines 174-175 are also missing from this extraction.
176};
177
178} // end of namespace psi
179
180#endif
Definition psi.h:37
const int & get_nbands() const
Definition psi.cpp:342
int get_current_b() const
Definition psi.cpp:468
int get_current_nbas() const
Definition psi.cpp:474
int nbasis
Definition psi.h:150
int current_k
Definition psi.h:152
const int & get_nk() const
Definition psi.cpp:336
Psi & operator=(const Psi &psi_in)
Definition psi.cpp:230
T & operator()(const int ikb1, const int ikb2, const int ibasis) const
Definition psi.cpp:437
~Psi()
Definition psi.cpp:38
int get_npol() const
Definition psi.cpp:323
void set_all_psi(const T *another_pointer, const std::size_t size_in)
Definition psi.cpp:223
int nbands
Definition psi.h:149
std::tuple< const T *, int > to_range(const Range &range) const
Definition psi.cpp:494
int nk
Definition psi.h:148
size_t size() const
Definition psi.cpp:354
bool allocate_inside
whether psi is allocated inside the Psi class
Definition psi.h:165
bool k_first
Definition psi.h:163
void fix_kb(const int ik, const int ib) const
Definition psi.cpp:419
const bool & get_k_first() const
Definition psi.cpp:286
const int & get_psi_bias() const
Definition psi.cpp:304
const int & get_nbasis() const
Definition psi.cpp:348
void resize(const int nks_in, const int nbands_in, const int nbasis_in)
Definition psi.cpp:254
const int * get_ngk_pointer() const
Definition psi.cpp:298
int get_current_k() const
Definition psi.cpp:462
T * get_pointer() const
Definition psi.cpp:272
void fix_b(const int ib) const
Definition psi.cpp:395
int current_nbasis
Definition psi.h:154
const int * ngk
Definition psi.h:161
void zero_out()
Definition psi.cpp:487
int psi_bias
Definition psi.h:159
const int & get_ngk(const int ik_in) const
Definition psi.cpp:480
const int & get_current_ngk() const
Definition psi.cpp:310
Device * ctx
Definition psi.h:145
const Device * get_device() const
Definition psi.cpp:292
void fix_k(const int ik) const
Definition psi.cpp:364
T * psi_current
Definition psi.h:157
int current_b
Definition psi.h:153
Psi()
Definition psi.cpp:33
#define T
Definition exp.cpp:237
Definition exx_lip.h:23
Definition memory_op.h:77
Definition memory_op.h:17
Definition memory_op.h:31
Definition psi.h:16
bool k_first
k_first = 0 (false): Psi(nbands, nks, nbasis); k_first = 1 (true): Psi(nks, nbands, nbasis)
Definition psi.h:18
size_t index_1
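index_1 >= 0: target first index; index_1 < 0: no use. (Note: index_1 is declared as size_t and can never be negative, so this "no use" convention cannot hold as written; the real sentinel must be confirmed in psi.cpp.)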
Definition psi.h:20
size_t range_1
Definition psi.h:23
size_t range_2
Definition psi.h:26