ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
math_chebyshev.h
Go to the documentation of this file.
1#ifndef STO_CHEBYCHEV_H
2#define STO_CHEBYCHEV_H
3#include "fftw3.h"
7
8#include <complex>
9#include <functional>
10
11namespace ModuleBase
12{
// Primary template of the FFTW workspace wrapper. It is only declared here;
// the usable versions are the explicit specializations FFTW<double> (and
// FFTW<float> when __ENABLE_FLOAT_FFTW is defined) that appear later in this file.
template <class T>
class FFTW;
16
// Chebyshev expansion of a function in Chebyshev polynomials T_n:
//   f(x) ~= \sum_{n=0}^{norder-1} C_n[f] T_n(x),
// evaluated either for scalars or for a linear operator A supplied as a
// callback funA, so that f(A)v is assembled from the vectors T_n(A)v.
// REAL is the floating-point type; Device tags where the "[Device]" buffers
// (coef_real, coef_complex) live, base_device::DEVICE_CPU by default.
// NOTE(review): this listing omits header line 83 (presumably `class Chebyshev`)
// and header lines 230-241 of the private section.
// NOTE(review): the class holds raw pointers (coef_real, coef_complex,
// coefr_cpu, coefc_cpu, polytrace) and an FFTW<REAL> member, declares a
// destructor, but declares no copy constructor / copy assignment. If
// ~Chebyshev() (math_chebyshev.cpp) releases these buffers, an implicit copy
// would release them twice -- consider `= delete` on the copy operations.
82template <typename REAL, typename Device = base_device::DEVICE_CPU>
84{
85
86 public:
87 // constructor and destructor; norder is the order of the Chebyshev expansion
88 Chebyshev(const int norder);
89 ~Chebyshev();
90
91 public:
92 // I. Expansion coefficients C_n.
93 // Calculate coefficients C_n[f], where f is a function of a real number.
94 void calcoef_real(std::function<REAL(REAL)> fun);
95
96 // Calculate coefficients C_n[g], where g is a function of a complex number.
97 void calcoef_complex(std::function<std::complex<REAL>(std::complex<REAL>)> fun);
98
99 // Calculate coefficients C_n[g], where g is a general complex function given as a pair g(x)=(g1(x), g2(x)),
100 // e.g. exp(ix)=(cos(x),sin(x)).
101 void calcoef_pair(std::function<REAL(REAL)> fun1, std::function<REAL(REAL)> fun2);
102
103 // II. Apply the expansion to vectors.
104 // Calculate the final vector f(A)v = \sum_{n=0}^{norder-1} C_n[f]*v_n, with v_n = T_n(A)v.
105 // funA(in, out, m) maps a block of m vectors [v1,...,vm] -> A[v1,...,vm].
106 // m is the number of vectors treated at the same time, i.e. f(A)[v1,...,vm] is computed.
107 // N is the dimension of each vector; LDA is the distance (in elements) between the first
108 // entries of consecutive vectors v_i and v_{i+1}, with LDA >= max(1, N), as in BLAS.
109 // This overload uses C_n[f], where f is a function of a real number, and A is a real operator.
110 void calfinalvec_real(std::function<void(REAL*, REAL*, const int)> funA,
111 REAL* wavein,
112 REAL* waveout,
113 const int N,
114 const int LDA = 1,
115 const int m = 1); // declared only: not implemented yet
116
117 // calfinalvec_real uses C_n[f], where f is a function of a real number, and A is a complex operator.
118 void calfinalvec_real(std::function<void(std::complex<REAL>*, std::complex<REAL>*, const int)> funA,
119 std::complex<REAL>* wavein,
120 std::complex<REAL>* waveout,
121 const int N,
122 const int LDA = 1,
123 const int m = 1);
124
125 // calfinalvec_complex uses C_n[g], where g is a function of a complex number, and A is a complex operator.
126 void calfinalvec_complex(std::function<void(std::complex<REAL>*, std::complex<REAL>*, const int)> funA,
127 std::complex<REAL>* wavein,
128 std::complex<REAL>* waveout,
129 const int N,
130 const int LDA = 1,
131 const int m = 1);
132
133 // III. Traces of T_n(A) over the vectors v_1,...,v_m.
134 // \sum_i v_i^+f(A)v_i = \sum_{i,n=0}^{norder-1} C_n[f]*v_i^+v_{i,n} = \sum_{n=0}^{norder-1} C_n[f] * w_n
135 // tracepolyA calculates w_n = \sum_i v_i^+ * T_n(A) * v_i, the trace of T_n(A) in the v-representation,
136 // for i = 1,2,...,m (see member polytrace).
137 void tracepolyA(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
138 std::complex<REAL>* wavein,
139 const int N,
140 const int LDA = 1,
141 const int m = 1);
142
143 // Evaluate the polynomials T_n at the point x and write them to polyval.
 // NOTE(review): N presumably is the number of orders evaluated -- confirm in math_chebyshev.cpp.
144 void getpolyval(REAL x, REAL* polyval, const int N);
145
146 // Get every order of the vector: {T_0(A)v, T_1(A)v, ..., T_n(A)v}.
147 // Note: use it carefully, it stores one vector per order and costs a lot of memory!
148 // calpolyvec_real: f(x) = \sum_n C_n*T_n(x), f is a real function
149 void calpolyvec_real(std::function<void(REAL* in, REAL* out, const int)> funA,
150 REAL* wavein,
151 REAL* waveout,
152 const int N,
153 const int LDA = 1,
154 const int m = 1); // declared only: not implemented yet
155
156 // calpolyvec_complex: f(x) = \sum_n C_n*T_n(x), f is a complex function
157 void calpolyvec_complex(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
158 std::complex<REAL>* wavein,
159 std::complex<REAL>* waveout,
160 const int N,
161 const int LDA = 1,
162 const int m = 1);
163
164 // IV. Three-term recurrence.
165 // recurrence formula: v_{n+1} = 2Av_n - v_{n-1}
166 // get v_{n+1} (arraynp1) from v_n (arrayn) and v_{n-1} (arrayn_1)
167 // recurs_real: A is a real operator
168 void recurs_real(std::function<void(REAL* in, REAL* out, const int)> funA,
169 REAL* arraynp1,
170 REAL* arrayn,
171 REAL* arrayn_1,
172 const int N,
173 const int LDA = 1,
174 const int m = 1);
175
176 // recurs_complex: A is a complex operator
177 void recurs_complex(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
178 std::complex<REAL>* arraynp1,
179 std::complex<REAL>* arrayn,
180 std::complex<REAL>* arrayn_1,
181 const int N,
182 const int LDA = 1,
183 const int m = 1);
184
185 // Scalar recurrence: return 2*x*Tn - Tn_1, i.e. T_{n+1}(x) from T_n(x) and T_{n-1}(x).
186 REAL recurs(const REAL x, const REAL Tn, const REAL Tn_1);
187
188 // V.
189 // auxiliary function
190 // The absolute value of every eigenvalue of A should be less than 1.
191 // Thus A is rescaled as \hat{A} = \frac{A - (tmax+tmin)/2}{(tmax-tmin)/2},
192 // where tmax >= all eigenvalues and tmin <= all eigenvalues.
193 // checkconverge tests whether the trial values tmax (tmin) are upper (lower) bounds of the eigenvalues;
 // tmax and tmin are passed by reference and may be updated (see stept below).
 // NOTE(review): confirm the meaning of the returned bool against math_chebyshev.cpp.
194 bool checkconverge(std::function<void(std::complex<REAL>* in, std::complex<REAL>* out, const int)> funA,
195 std::complex<REAL>* wavein,
196 const int N,
197 const int LDA,
198 REAL& tmax, // trial number for upper bound
199 REAL& tmin, // trial number for lower bound
200 REAL stept); // tmax = max() + stept, tmin = min() - stept
201
202 public:
203 // Members:
204 int norder; // order of Chebyshev expansion
205 int norder2; // 2 * norder * EXTEND
206
207 REAL* coef_real = nullptr; //[Device] expansion coefficient of each order
208 std::complex<REAL>* coef_complex = nullptr; //[Device] expansion coefficient of each order
209 REAL* coefr_cpu = nullptr; //[CPU] expansion coefficient of each order
210 std::complex<REAL>* coefc_cpu = nullptr; //[CPU] expansion coefficient of each order
211
212 FFTW<REAL> fftw; // FFTW workspace, presumably used by the calcoef_* functions
 // NOTE(review): unlike the coef_* pointers above, polytrace, getcoef_real and getcoef_complex have no
 // default member initializer; presumably they are set in the constructor -- confirm.
213 REAL* polytrace; //[CPU] w_n = \sum_i v_i^+ * T_n(A) * v_i (see tracepolyA)
214
215 bool getcoef_real; // coef_real has been calculated
216 bool getcoef_complex; // coef_complex has been calculated
217
218 public:
219 // SI.
220 // Calculate the dot product <psi_L|psi_R> over m vectors of dimension N (stride LDA).
 // NOTE(review): the result type is REAL -- presumably the real part of the complex product; confirm.
221 REAL ddot_real(const std::complex<REAL>* psi_L,
222 const std::complex<REAL>* psi_R,
223 const int N,
224 const int LDA = 1,
225 const int m = 1);
226
227 private:
 // Device context tags (presumably passed to device-dispatched memory operations).
228 Device* ctx = {};
229 base_device::DEVICE_CPU* cpu_ctx = {};
242};
243
244template <>
245class FFTW<double>
246{
247 public:
248 FFTW(const int norder2_in);
249 ~FFTW();
250 void execute_fftw();
251 double* dcoef; //[norder2]
252 fftw_complex* ccoef;
253 fftw_plan coef_plan;
254};
255
#ifdef __ENABLE_FLOAT_FFTW
// Single-precision FFTW workspace (fftwf_* API), the float counterpart of FFTW<double>.
// It holds raw FFTW resources (dcoef, ccoef, coef_plan), presumably released in
// ~FFTW() (math_chebyshev.cpp). Copying is disabled: an implicit member-wise copy
// would share these resources and release them twice.
template <>
class FFTW<float>
{
  public:
    // norder2_in: size of the workspace; dcoef is annotated as having norder2 entries.
    explicit FFTW(const int norder2_in);
    ~FFTW();

    FFTW(const FFTW&) = delete;
    FFTW& operator=(const FFTW&) = delete;

    // Presumably executes coef_plan on the buffers below.
    void execute_fftw();

    float* dcoef = nullptr;         //[norder2]
    fftwf_complex* ccoef = nullptr; // complex FFTW buffer
    fftwf_plan coef_plan = nullptr; // FFTW plan, presumably created in the constructor
};
#endif
269
270} // namespace ModuleBase
271
272#endif
A class implementing the Chebyshev polynomial expansion of functions, both of scalars and of a linear operator A.
Definition math_chebyshev.h:84
void getpolyval(REAL x, REAL *polyval, const int N)
Definition math_chebyshev.cpp:101
REAL * coef_real
Definition math_chebyshev.h:207
void calfinalvec_real(std::function< void(REAL *, REAL *, const int)> funA, REAL *wavein, REAL *waveout, const int N, const int LDA=1, const int m=1)
void recurs_real(std::function< void(REAL *in, REAL *out, const int)> funA, REAL *arraynp1, REAL *arrayn, REAL *arrayn_1, const int N, const int LDA=1, const int m=1)
void calfinalvec_complex(std::function< void(std::complex< REAL > *, std::complex< REAL > *, const int)> funA, std::complex< REAL > *wavein, std::complex< REAL > *waveout, const int N, const int LDA=1, const int m=1)
Definition math_chebyshev.cpp:472
void calpolyvec_real(std::function< void(REAL *in, REAL *out, const int)> funA, REAL *wavein, REAL *waveout, const int N, const int LDA=1, const int m=1)
int norder
Definition math_chebyshev.h:204
void calcoef_pair(std::function< REAL(REAL)> fun1, std::function< REAL(REAL)> fun2)
Definition math_chebyshev.cpp:312
bool getcoef_complex
Definition math_chebyshev.h:216
REAL recurs(const REAL x, const REAL Tn, const REAL Tn_1)
Definition math_chebyshev.cpp:111
void calpolyvec_complex(std::function< void(std::complex< REAL > *in, std::complex< REAL > *out, const int)> funA, std::complex< REAL > *wavein, std::complex< REAL > *waveout, const int N, const int LDA=1, const int m=1)
Definition math_chebyshev.cpp:537
~Chebyshev()
Definition math_chebyshev.cpp:82
REAL * coefr_cpu
Definition math_chebyshev.h:209
void recurs_complex(std::function< void(std::complex< REAL > *in, std::complex< REAL > *out, const int)> funA, std::complex< REAL > *arraynp1, std::complex< REAL > *arrayn, std::complex< REAL > *arrayn_1, const int N, const int LDA=1, const int m=1)
Definition math_chebyshev.cpp:628
Device * ctx
Definition math_chebyshev.h:228
base_device::DEVICE_CPU * cpu_ctx
Definition math_chebyshev.h:229
std::complex< REAL > * coef_complex
Definition math_chebyshev.h:208
REAL * polytrace
Definition math_chebyshev.h:213
void calcoef_real(std::function< REAL(REAL)> fun)
Definition math_chebyshev.cpp:160
typename container::PsiToContainer< Device >::type ct_Device
Definition math_chebyshev.h:230
void tracepolyA(std::function< void(std::complex< REAL > *in, std::complex< REAL > *out, const int)> funA, std::complex< REAL > *wavein, const int N, const int LDA=1, const int m=1)
Definition math_chebyshev.cpp:577
void calcoef_complex(std::function< std::complex< REAL >(std::complex< REAL >)> fun)
Definition math_chebyshev.cpp:222
REAL ddot_real(const std::complex< REAL > *psi_L, const std::complex< REAL > *psi_R, const int N, const int LDA=1, const int m=1)
Definition math_chebyshev.cpp:117
std::complex< REAL > * coefc_cpu
Definition math_chebyshev.h:210
FFTW< REAL > fftw
Definition math_chebyshev.h:212
bool getcoef_real
Definition math_chebyshev.h:215
bool checkconverge(std::function< void(std::complex< REAL > *in, std::complex< REAL > *out, const int)> funA, std::complex< REAL > *wavein, const int N, const int LDA, REAL &tmax, REAL &tmin, REAL stept)
Definition math_chebyshev.cpp:658
int norder2
Definition math_chebyshev.h:205
fftw_plan coef_plan
Definition math_chebyshev.h:253
fftw_complex * ccoef
Definition math_chebyshev.h:252
double * dcoef
Definition math_chebyshev.h:251
Definition math_chebyshev.h:15
#define N
Definition exp.cpp:24
Definition array_pool.h:6
Definition memory_op.h:77
Definition memory_op.h:17
Definition memory_op.h:31
T type
Definition tensor_types.h:114
This file contains the definition of the DataType enum class.