ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
force_op.h
Go to the documentation of this file.
1#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
2#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
4
5#include "source_psi/psi.h"
6
7#include <complex>
8
9namespace hamilt
10{
11
/**
 * @brief The prestep to calculate the final forces.
 *
 * Primary (device-generic) template. Only the declaration of operator() is in
 * this header; its definition is provided out of line. GPU builds use the
 * cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU> partial specialization that
 * is declared later in this header.
 *
 * @tparam FPTYPE  real scalar type; complex data is std::complex<FPTYPE>.
 * @tparam Device  device tag type; it appears only as the pointee type of ctx.
 */
template <typename FPTYPE, typename Device>
struct cal_vkb1_nl_op
{
    /**
     * @param ctx            device context pointer (pointee type is Device).
     * @param nkb            integer extent — NOTE(review): array layout is not
     *                       visible in this header; confirm against the
     *                       implementation.
     * @param npwx           integer extent (see note on nkb).
     * @param vkb_nc         integer extent (see note on nkb).
     * @param nbasis         integer extent (see note on nkb).
     * @param ipol           NOTE(review): presumably the Cartesian direction that
     *                       selects a component of gcar — confirm.
     * @param NEG_IMAG_UNIT  complex constant supplied by the caller (its name
     *                       suggests -i).
     * @param vkb            read-only complex input array.
     * @param gcar           read-only real input array.
     * @param vkb1           output array — the only non-const pointer argument.
     */
    void operator()(const Device* ctx,
                    const int& nkb,
                    const int& npwx,
                    const int& vkb_nc,
                    const int& nbasis,
                    const int& ipol,
                    const std::complex<FPTYPE>& NEG_IMAG_UNIT,
                    const std::complex<FPTYPE>* vkb,
                    const FPTYPE* gcar,
                    std::complex<FPTYPE>* vkb1);
};
41
42template <typename FPTYPE, typename Device>
44{
72 void operator()(const base_device::DEVICE_CPU* ctx,
73 const bool& nondiagonal,
74 const int& nbands_occ,
75 const int& ntype,
76 const int& spin,
77 const int& deeq_2,
78 const int& deeq_3,
79 const int& deeq_4,
80 const int& forcenl_nc,
81 const int& nbands,
82 const int& nkb,
83 const int* atom_nh,
84 const int* atom_na,
85 const FPTYPE& tpiba,
86 const FPTYPE* d_wg,
87 const bool& occ,
88 const FPTYPE* d_ekb,
89 const FPTYPE* qq_nt,
90 const FPTYPE* deeq,
91 const std::complex<FPTYPE>* becp,
92 const std::complex<FPTYPE>* dbecp,
93 FPTYPE* force);
94 // interface for nspin=4 only
95 void operator()(const base_device::DEVICE_CPU* ctx,
96 const int& nbands_occ,
97 const int& ntype,
98 const int& deeq_2,
99 const int& deeq_3,
100 const int& deeq_4,
101 const int& forcenl_nc,
102 const int& nbands,
103 const int& nkb,
104 const int* atom_nh,
105 const int* atom_na,
106 const FPTYPE& tpiba,
107 const FPTYPE* d_wg,
108 const bool& occ,
109 const FPTYPE* d_ekb,
110 const FPTYPE* qq_nt,
111 const std::complex<FPTYPE>* deeq_nc,
112 const std::complex<FPTYPE>* becp,
113 const std::complex<FPTYPE>* dbecp,
114 FPTYPE* force);
116 void operator()(const base_device::DEVICE_CPU* ctx,
117 const int& nbands_occ,
118 const int& wg_nc,
119 const int& ntype,
120 const int& forcenl_nc,
121 const int& nbands,
122 const int& ik,
123 const int& nkb,
124 const int* atom_nh,
125 const int* atom_na,
126 const FPTYPE& tpiba,
127 const FPTYPE* d_wg,
128 const std::complex<FPTYPE>* vu,
129 const int* orbital_corr,
130 const std::complex<FPTYPE>* becp,
131 const std::complex<FPTYPE>* dbecp,
132 FPTYPE* force);
134 void operator()(const base_device::DEVICE_CPU* ctx,
135 const int& nbands_occ,
136 const int& wg_nc,
137 const int& ntype,
138 const int& forcenl_nc,
139 const int& nbands,
140 const int& ik,
141 const int& nkb,
142 const int* atom_nh,
143 const int* atom_na,
144 const FPTYPE& tpiba,
145 const FPTYPE* d_wg,
146 const FPTYPE* lambda,
147 const std::complex<FPTYPE>* becp,
148 const std::complex<FPTYPE>* dbecp,
149 FPTYPE* force);
150};
151
/**
 * @brief Device-generic force operator that reads @p vloc and writes @p forcelc.
 *
 * NOTE(review): the name and the vloc/forcelc arguments suggest this is the
 * local-potential force term — confirm.
 *
 * The primary template's operator() has an empty inline body. Calling it
 * therefore leaves @p forcelc unchanged and reads none of its inputs.
 * GPU builds use the cal_force_loc_op<FPTYPE, base_device::DEVICE_GPU> partial
 * specialization declared later in this header.
 * NOTE(review): confirm that the CPU path is supplied by a specialization in the
 * implementation file. Otherwise, CPU calls through this template are a silent
 * no-op.
 *
 * @param forcelc  output buffer — the only non-const pointer argument.
 */
template <typename FPTYPE, typename Device>
struct cal_force_loc_op
{
    void operator()(const int nat,
                    const int npw,
                    const FPTYPE tpiba_omega,
                    const int* iat2it,
                    const int* ig2igg,
                    const FPTYPE* gcar,
                    const FPTYPE* tau,
                    const std::complex<FPTYPE>* aux,
                    const FPTYPE* vloc,
                    const int vloc_nr,
                    FPTYPE* forcelc)
    {
        // Intentionally empty in the primary template (see class comment).
    }
};
167
/**
 * @brief Device-generic force operator that writes @p forceion.
 *
 * NOTE(review): "ew" presumably stands for the Ewald (ion-ion) term — confirm.
 *
 * The primary template's operator() has an empty inline body. Calling it
 * therefore leaves @p forceion unchanged and reads none of its inputs.
 * GPU builds use the cal_force_ew_op<FPTYPE, base_device::DEVICE_GPU> partial
 * specialization declared later in this header.
 * NOTE(review): confirm that the CPU path is supplied by a specialization in the
 * implementation file. Otherwise, CPU calls through this template are a silent
 * no-op.
 *
 * @param forceion  output buffer — the only non-const pointer argument.
 */
template <typename FPTYPE, typename Device>
struct cal_force_ew_op
{
    void operator()(const int nat,
                    const int npw,
                    const int ig_gge0,
                    const int* iat2it,
                    const FPTYPE* gcar,
                    const FPTYPE* tau,
                    const FPTYPE* it_fact,
                    const std::complex<FPTYPE>* aux,
                    FPTYPE* forceion)
    {
        // Intentionally empty in the primary template (see class comment).
    }
};
182#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
183template <typename FPTYPE>
184struct cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU>
185{
186 void operator()(const base_device::DEVICE_GPU* ctx,
187 const int& nkb,
188 const int& npwx,
189 const int& vkb_nc,
190 const int& nbasis,
191 const int& ipol,
192 const std::complex<FPTYPE>& NEG_IMAG_UNIT,
193 const std::complex<FPTYPE>* vkb,
194 const FPTYPE* gcar,
195 std::complex<FPTYPE>* vkb1);
196};
197
198template <typename FPTYPE>
199struct cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>
200{
201 void operator()(const base_device::DEVICE_GPU* ctx,
202 const bool& nondiagonal,
203 const int& nbands_occ,
204 const int& ntype,
205 const int& spin,
206 const int& deeq_2,
207 const int& deeq_3,
208 const int& deeq_4,
209 const int& forcenl_nc,
210 const int& nbands,
211 const int& nkb,
212 const int* atom_nh,
213 const int* atom_na,
214 const FPTYPE& tpiba,
215 const FPTYPE* d_wg,
216 const bool& occ,
217 const FPTYPE* d_ekb,
218 const FPTYPE* qq_nt,
219 const FPTYPE* deeq,
220 const std::complex<FPTYPE>* becp,
221 const std::complex<FPTYPE>* dbecp,
222 FPTYPE* force);
223 // interface for nspin=4 only
224 void operator()(const base_device::DEVICE_GPU* ctx,
225 const int& nbands_occ,
226 const int& ntype,
227 const int& deeq_2,
228 const int& deeq_3,
229 const int& deeq_4,
230 const int& forcenl_nc,
231 const int& nbands,
232 const int& nkb,
233 const int* atom_nh,
234 const int* atom_na,
235 const FPTYPE& tpiba,
236 const FPTYPE* d_wg,
237 const bool& occ,
238 const FPTYPE* d_ekb,
239 const FPTYPE* qq_nt,
240 const std::complex<FPTYPE>* deeq_nc,
241 const std::complex<FPTYPE>* becp,
242 const std::complex<FPTYPE>* dbecp,
243 FPTYPE* force);
245 void operator()(const base_device::DEVICE_GPU* ctx,
246 const int& nbands_occ,
247 const int& wg_nc,
248 const int& ntype,
249 const int& forcenl_nc,
250 const int& nbands,
251 const int& ik,
252 const int& nkb,
253 const int* atom_nh,
254 const int* atom_na,
255 const FPTYPE& tpiba,
256 const FPTYPE* d_wg,
257 const std::complex<FPTYPE>* vu,
258 const int* orbital_corr,
259 const std::complex<FPTYPE>* becp,
260 const std::complex<FPTYPE>* dbecp,
261 FPTYPE* force);
263 void operator()(const base_device::DEVICE_GPU* ctx,
264 const int& nbands_occ,
265 const int& wg_nc,
266 const int& ntype,
267 const int& forcenl_nc,
268 const int& nbands,
269 const int& ik,
270 const int& nkb,
271 const int* atom_nh,
272 const int* atom_na,
273 const FPTYPE& tpiba,
274 const FPTYPE* d_wg,
275 const FPTYPE* lambda,
276 const std::complex<FPTYPE>* becp,
277 const std::complex<FPTYPE>* dbecp,
278 FPTYPE* force);
279};
280
/**
 * @brief Writes values back into @p vkb_ptr (declaration only).
 *
 * Declared inside this header's CUDA/ROCm-only region. The definition is not in
 * this header.
 *
 * NOTE(review): the semantics below are inferred from the parameter names and
 * const-qualification only — confirm against the implementation.
 *  - @p vkb_ptr is the only non-const pointer, so it is the array modified.
 *  - @p vkb_save_ptr is read-only. It is presumably the buffer filled earlier by
 *    saveVkbValues().
 *  - @p gcar_zero_ptrs (@p gcar_zero_counts entries) presumably lists the
 *    positions whose values are restored.
 *  - @p coeff is a complex factor (passed by value), presumably applied to the
 *    restored values.
 *
 * @param gcar_zero_ptrs    read-only index array.
 * @param vkb_ptr           array that is written.
 * @param vkb_save_ptr      read-only source array.
 * @param nkb               integer extent.
 * @param gcar_zero_counts  number of entries in gcar_zero_ptrs (presumably).
 * @param npw               integer extent.
 * @param ipol              integer selector.
 * @param npwx              integer extent.
 * @param coeff             complex coefficient.
 */
template <typename FPTYPE>
void revertVkbValues(const int* gcar_zero_ptrs,
                     std::complex<FPTYPE>* vkb_ptr,
                     const std::complex<FPTYPE>* vkb_save_ptr,
                     int nkb,
                     int gcar_zero_counts,
                     int npw,
                     int ipol,
                     int npwx,
                     const std::complex<FPTYPE> coeff);
294
/**
 * @brief Copies values out of @p vkb_ptr into @p vkb_save_ptr (declaration only).
 *
 * Declared inside this header's CUDA/ROCm-only region. The definition is not in
 * this header.
 *
 * NOTE(review): the semantics below are inferred from the parameter names and
 * const-qualification only — confirm against the implementation.
 *  - @p vkb_ptr is read-only here, and @p vkb_save_ptr is the only non-const
 *    pointer (the destination).
 *  - @p gcar_zero_ptrs (@p gcar_zero_counts entries) presumably lists the
 *    positions whose values are saved, for later restoration by
 *    revertVkbValues().
 *
 * @param gcar_zero_ptrs    read-only index array.
 * @param vkb_ptr           read-only source array.
 * @param vkb_save_ptr      destination array that is written.
 * @param nkb               integer extent.
 * @param gcar_zero_counts  number of entries in gcar_zero_ptrs (presumably).
 * @param npw               integer extent.
 * @param ipol              integer selector.
 * @param npwx              integer extent.
 */
template <typename FPTYPE>
void saveVkbValues(const int* gcar_zero_ptrs,
                   const std::complex<FPTYPE>* vkb_ptr,
                   std::complex<FPTYPE>* vkb_save_ptr,
                   int nkb,
                   int gcar_zero_counts,
                   int npw,
                   int ipol,
                   int npwx);
307
308template <typename FPTYPE>
309struct cal_force_loc_op<FPTYPE, base_device::DEVICE_GPU>{
310 void operator()(
311 const int nat,
312 const int npw,
313 const FPTYPE tpiba_omega,
314 const int* iat2it,
315 const int* ig2igg,
316 const FPTYPE* gcar,
317 const FPTYPE* tau,
318 const std::complex<FPTYPE>* aux,
319 const FPTYPE* vloc,
320 const int vloc_nr,
321 FPTYPE* forcelc);
322};
323
324template <typename FPTYPE>
325struct cal_force_ew_op<FPTYPE, base_device::DEVICE_GPU>{
326 void operator()(
327 const int nat,
328 const int npw,
329 const int ig_gge0,
330 const int* iat2it,
331 const FPTYPE* gcar,
332 const FPTYPE* tau,
333 const FPTYPE* it_fact,
334 const std::complex<FPTYPE>* aux,
335 FPTYPE* forceion
336 );
337};
338#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
339} // namespace hamilt
340#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
Definition device.cpp:21
Definition hamilt.h:12
Definition force_op.h:169
void operator()(const int nat, const int npw, const int ig_gge0, const int *iat2it, const FPTYPE *gcar, const FPTYPE *tau, const FPTYPE *it_fact, const std::complex< FPTYPE > *aux, FPTYPE *forceion)
Definition force_op.h:170
Definition force_op.h:153
void operator()(const int nat, const int npw, const FPTYPE tpiba_omega, const int *iat2it, const int *ig2igg, const FPTYPE *gcar, const FPTYPE *tau, const std::complex< FPTYPE > *aux, const FPTYPE *vloc, const int vloc_nr, FPTYPE *forcelc)
Definition force_op.h:154
Definition force_op.h:44
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &ntype, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int &forcenl_nc, const int &nbands, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const std::complex< FPTYPE > *deeq_nc, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
void operator()(const base_device::DEVICE_CPU *ctx, const bool &nondiagonal, const int &nbands_occ, const int &ntype, const int &spin, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int &forcenl_nc, const int &nbands, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const FPTYPE *deeq, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
Calculate the final forces; this operator is provided for multiple devices.
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &wg_nc, const int &ntype, const int &forcenl_nc, const int &nbands, const int &ik, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const std::complex< FPTYPE > *vu, const int *orbital_corr, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
kernel for DFT+U
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &wg_nc, const int &ntype, const int &forcenl_nc, const int &nbands, const int &ik, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const FPTYPE *lambda, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
kernel for DeltaSpin
Definition force_op.h:14
void operator()(const Device *ctx, const int &nkb, const int &npwx, const int &vkb_nc, const int &nbasis, const int &ipol, const std::complex< FPTYPE > &NEG_IMAG_UNIT, const std::complex< FPTYPE > *vkb, const FPTYPE *gcar, std::complex< FPTYPE > *vkb1)
Preparatory step for calculating the final forces.