ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
force_op.h
Go to the documentation of this file.
1#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
2#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
3
5
6#include <complex>
7
8namespace hamilt
9{
10
// Primary functor template, parameterised on the floating-point type (FPTYPE)
// and a device tag type (Device).
// The index at the end of this page describes its call operator as
// "The prestep to calculate the final forces."
// NOTE(review): listing line 12 (the `struct <name>` line) is missing from this
// extract. The name is presumably cal_vkb1_nl_op, because the DEVICE_GPU
// specialization of that name further down has the same parameter list --
// confirm against the original header.
11template <typename FPTYPE, typename Device>
13{
// Declaration only: no body for this operator appears in this listing.
// ctx is a pointer to the Device tag type; the five int arguments and
// NEG_IMAG_UNIT are taken by const reference.
29 void operator()(const Device* ctx,
30 const int& nkb,
31 const int& npwx,
32 const int& vkb_nc,
33 const int& nbasis,
34 const int& ipol,
35 const std::complex<FPTYPE>& NEG_IMAG_UNIT,
// vkb and gcar are pointers to const data; vkb1 is the only pointer through
// which the callee is allowed to write.
36 const std::complex<FPTYPE>* vkb,
37 const FPTYPE* gcar,
38 std::complex<FPTYPE>* vkb1);
39};
40
// Primary functor template with four operator() overloads, all declared
// without a body in this listing. Every overload writes through the single
// non-const pointer `force`; all other pointer arguments point to const data.
// NOTE(review): listing line 42 (the `struct <name>` line) is missing from this
// extract. The name is presumably cal_force_nl_op, because the DEVICE_GPU
// specialization of that name further down declares the same four overloads --
// confirm against the original header.
// NOTE(review): the overloads take `const base_device::DEVICE_CPU*` rather than
// `const Device*`, although the template has a Device parameter -- confirm this
// is intended.
41template <typename FPTYPE, typename Device>
43{
// Overload 1. The index at the end of this page describes it as
// "Calculate the final forces for multi-device."
// It is the only overload taking `nondiagonal` and `spin`, and it takes a
// real-valued `deeq` array.
71 void operator()(const base_device::DEVICE_CPU* ctx,
72 const bool& nondiagonal,
73 const int& nbands_occ,
74 const int& ntype,
75 const int& spin,
76 const int& deeq_2,
77 const int& deeq_3,
78 const int& deeq_4,
79 const int& forcenl_nc,
80 const int& nbands,
81 const int& nkb,
82 const int* atom_nh,
83 const int* atom_na,
84 const FPTYPE& tpiba,
85 const FPTYPE* d_wg,
86 const bool& occ,
87 const FPTYPE* d_ekb,
88 const FPTYPE* qq_nt,
89 const FPTYPE* deeq,
90 const std::complex<FPTYPE>* becp,
91 const std::complex<FPTYPE>* dbecp,
92 FPTYPE* force);
// Overload 2. Compared with overload 1 it drops `nondiagonal` and `spin` and
// takes a complex-valued `deeq_nc` array in place of the real-valued `deeq`.
93 // interface for nspin=4 only
94 void operator()(const base_device::DEVICE_CPU* ctx,
95 const int& nbands_occ,
96 const int& ntype,
97 const int& deeq_2,
98 const int& deeq_3,
99 const int& deeq_4,
100 const int& forcenl_nc,
101 const int& nbands,
102 const int& nkb,
103 const int* atom_nh,
104 const int* atom_na,
105 const FPTYPE& tpiba,
106 const FPTYPE* d_wg,
107 const bool& occ,
108 const FPTYPE* d_ekb,
109 const FPTYPE* qq_nt,
110 const std::complex<FPTYPE>* deeq_nc,
111 const std::complex<FPTYPE>* becp,
112 const std::complex<FPTYPE>* dbecp,
113 FPTYPE* force);
// Overload 3. The index at the end of this page describes it as
// "kernel for DFT+U". It is distinguished by the complex `vu` array and the
// integer `orbital_corr` array; it also takes `wg_nc` and `ik`.
115 void operator()(const base_device::DEVICE_CPU* ctx,
116 const int& nbands_occ,
117 const int& wg_nc,
118 const int& ntype,
119 const int& forcenl_nc,
120 const int& nbands,
121 const int& ik,
122 const int& nkb,
123 const int* atom_nh,
124 const int* atom_na,
125 const FPTYPE& tpiba,
126 const FPTYPE* d_wg,
127 const std::complex<FPTYPE>* vu,
128 const int* orbital_corr,
129 const std::complex<FPTYPE>* becp,
130 const std::complex<FPTYPE>* dbecp,
131 FPTYPE* force);
// Overload 4. The index at the end of this page describes it as
// "kernel for DeltaSpin". It has the same leading arguments as overload 3 but
// takes a real-valued `lambda` array instead of `vu` and `orbital_corr`.
133 void operator()(const base_device::DEVICE_CPU* ctx,
134 const int& nbands_occ,
135 const int& wg_nc,
136 const int& ntype,
137 const int& forcenl_nc,
138 const int& nbands,
139 const int& ik,
140 const int& nkb,
141 const int* atom_nh,
142 const int* atom_na,
143 const FPTYPE& tpiba,
144 const FPTYPE* d_wg,
145 const FPTYPE* lambda,
146 const std::complex<FPTYPE>* becp,
147 const std::complex<FPTYPE>* dbecp,
148 FPTYPE* force);
149};
150
// Primary functor template whose call operator is defined inline with an empty
// body (`{}`), so this generic version does nothing when called.
// NOTE(review): listing lines 152-153 (the `struct <name>` line and the start
// of `void operator()(`) are missing from this extract. The index at the end of
// this page places the struct at line 152 and an operator() with exactly this
// parameter list at line 153. The name is presumably cal_force_loc_op, because
// the DEVICE_GPU specialization of that name further down has the same
// parameter list -- confirm against the original header.
151template <typename FPTYPE, typename Device>
// Unlike the context-taking functors above, this parameter list has no device
// context argument, and its scalars are passed by value.
154 const int nat,
155 const int npw,
156 const FPTYPE tpiba_omega,
157 const int* iat2it,
158 const int* ig2igg,
159 const FPTYPE* gcar,
160 const FPTYPE* tau,
161 const std::complex<FPTYPE>* aux,
162 const FPTYPE* vloc,
163 const int vloc_nr,
// forcelc is the only pointer to non-const data; the empty body never touches it.
164 FPTYPE* forcelc) {};
165};
166
// Primary functor template whose call operator is defined inline with an empty
// body (`{}`), so this generic version does nothing when called.
// NOTE(review): listing lines 168-169 (the `struct <name>` line and the start
// of `void operator()(`) are missing from this extract. The index at the end of
// this page places the struct at line 168 and an operator() with exactly this
// parameter list at line 169. The name is presumably cal_force_ew_op, because
// the DEVICE_GPU specialization of that name further down has the same
// parameter list -- confirm against the original header.
167template <typename FPTYPE, typename Device>
// No device context argument; scalars are passed by value.
170 const int nat,
171 const int npw,
172 const int ig_gge0,
173 const int* iat2it,
174 const FPTYPE* gcar,
175 const FPTYPE* tau,
176 const FPTYPE* it_fact,
177 const std::complex<FPTYPE>* aux,
// forceion is the only pointer to non-const data; the empty body never touches it.
178 FPTYPE* forceion
179 ) {};
180};
181#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
// Partial specialization of cal_vkb1_nl_op for base_device::DEVICE_GPU.
// It is only declared when one of __CUDA, __UT_USE_CUDA, __ROCM or
// __UT_USE_ROCM is set, because it sits inside the enclosing #if block.
182template <typename FPTYPE>
183struct cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU>
184{
// Declaration only: no body appears in this listing. The parameter list is
// that of the primary template's operator with Device fixed to DEVICE_GPU.
// vkb1 is the only pointer to non-const data.
185 void operator()(const base_device::DEVICE_GPU* ctx,
186 const int& nkb,
187 const int& npwx,
188 const int& vkb_nc,
189 const int& nbasis,
190 const int& ipol,
191 const std::complex<FPTYPE>& NEG_IMAG_UNIT,
192 const std::complex<FPTYPE>* vkb,
193 const FPTYPE* gcar,
194 std::complex<FPTYPE>* vkb1);
195};
196
// Partial specialization of cal_force_nl_op for base_device::DEVICE_GPU.
// It is only declared when one of __CUDA, __UT_USE_CUDA, __ROCM or
// __UT_USE_ROCM is set, because it sits inside the enclosing #if block.
// It declares four operator() overloads, none with a body in this listing.
// Their parameter lists match the four DEVICE_CPU-context overloads earlier in
// this listing, with ctx typed as DEVICE_GPU. In every overload `force` is the
// only pointer to non-const data.
197template <typename FPTYPE>
198struct cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>
199{
// Overload 1. It has the same parameter list as the CPU-context overload that
// the page index describes as "Calculate the final forces for multi-device."
// It takes `nondiagonal`, `spin` and a real-valued `deeq`.
200 void operator()(const base_device::DEVICE_GPU* ctx,
201 const bool& nondiagonal,
202 const int& nbands_occ,
203 const int& ntype,
204 const int& spin,
205 const int& deeq_2,
206 const int& deeq_3,
207 const int& deeq_4,
208 const int& forcenl_nc,
209 const int& nbands,
210 const int& nkb,
211 const int* atom_nh,
212 const int* atom_na,
213 const FPTYPE& tpiba,
214 const FPTYPE* d_wg,
215 const bool& occ,
216 const FPTYPE* d_ekb,
217 const FPTYPE* qq_nt,
218 const FPTYPE* deeq,
219 const std::complex<FPTYPE>* becp,
220 const std::complex<FPTYPE>* dbecp,
221 FPTYPE* force);
// Overload 2. It drops `nondiagonal` and `spin` and takes a complex-valued
// `deeq_nc` in place of `deeq`.
222 // interface for nspin=4 only
223 void operator()(const base_device::DEVICE_GPU* ctx,
224 const int& nbands_occ,
225 const int& ntype,
226 const int& deeq_2,
227 const int& deeq_3,
228 const int& deeq_4,
229 const int& forcenl_nc,
230 const int& nbands,
231 const int& nkb,
232 const int* atom_nh,
233 const int* atom_na,
234 const FPTYPE& tpiba,
235 const FPTYPE* d_wg,
236 const bool& occ,
237 const FPTYPE* d_ekb,
238 const FPTYPE* qq_nt,
239 const std::complex<FPTYPE>* deeq_nc,
240 const std::complex<FPTYPE>* becp,
241 const std::complex<FPTYPE>* dbecp,
242 FPTYPE* force);
// Overload 3. It has the same parameter list as the CPU-context overload that
// the page index describes as "kernel for DFT+U". It takes `vu` and
// `orbital_corr`.
244 void operator()(const base_device::DEVICE_GPU* ctx,
245 const int& nbands_occ,
246 const int& wg_nc,
247 const int& ntype,
248 const int& forcenl_nc,
249 const int& nbands,
250 const int& ik,
251 const int& nkb,
252 const int* atom_nh,
253 const int* atom_na,
254 const FPTYPE& tpiba,
255 const FPTYPE* d_wg,
256 const std::complex<FPTYPE>* vu,
257 const int* orbital_corr,
258 const std::complex<FPTYPE>* becp,
259 const std::complex<FPTYPE>* dbecp,
260 FPTYPE* force);
// Overload 4. It has the same parameter list as the CPU-context overload that
// the page index describes as "kernel for DeltaSpin". It takes `lambda`.
262 void operator()(const base_device::DEVICE_GPU* ctx,
263 const int& nbands_occ,
264 const int& wg_nc,
265 const int& ntype,
266 const int& forcenl_nc,
267 const int& nbands,
268 const int& ik,
269 const int& nkb,
270 const int* atom_nh,
271 const int* atom_na,
272 const FPTYPE& tpiba,
273 const FPTYPE* d_wg,
274 const FPTYPE* lambda,
275 const std::complex<FPTYPE>* becp,
276 const std::complex<FPTYPE>* dbecp,
277 FPTYPE* force);
278};
279
// Free function template, declared without a body in this listing. It sits
// inside the CUDA/ROCm #if block, so it is only declared for those builds.
// vkb_ptr is the only pointer to non-const data; gcar_zero_ptrs and
// vkb_save_ptr point to const data.
// NOTE(review): the original documentation lines (listing 280-282) are missing
// from this extract. From the names, this presumably writes values kept in
// vkb_save_ptr back into vkb_ptr at the positions listed in gcar_zero_ptrs,
// scaled by coeff -- TODO confirm against the implementation.
283template <typename FPTYPE>
284void revertVkbValues(const int* gcar_zero_ptrs,
285 std::complex<FPTYPE>* vkb_ptr,
286 const std::complex<FPTYPE>* vkb_save_ptr,
287 int nkb,
288 int gcar_zero_counts,
289 int npw,
290 int ipol,
291 int npwx,
292 const std::complex<FPTYPE> coeff);
293
// Free function template, declared without a body in this listing. It sits
// inside the CUDA/ROCm #if block, so it is only declared for those builds.
// The constness is the mirror image of revertVkbValues: vkb_ptr points to const
// data and vkb_save_ptr is the only pointer to non-const data. There is no
// coeff argument.
// NOTE(review): the original documentation lines (listing 294-296) are missing
// from this extract. From the names, this presumably copies the vkb_ptr values
// at the positions listed in gcar_zero_ptrs into vkb_save_ptr -- TODO confirm
// against the implementation.
297template <typename FPTYPE>
298void saveVkbValues(const int* gcar_zero_ptrs,
299 const std::complex<FPTYPE>* vkb_ptr,
300 std::complex<FPTYPE>* vkb_save_ptr,
301 int nkb,
302 int gcar_zero_counts,
303 int npw,
304 int ipol,
305 int npwx);
306
// Partial specialization of cal_force_loc_op for base_device::DEVICE_GPU.
// It is only declared when one of __CUDA, __UT_USE_CUDA, __ROCM or
// __UT_USE_ROCM is set, because it sits inside the enclosing #if block.
307template <typename FPTYPE>
308struct cal_force_loc_op<FPTYPE, base_device::DEVICE_GPU>{
// Declaration only (ends in `);`), so a definition must be supplied elsewhere.
// The generic operator earlier in this listing with the same parameter list
// has an empty inline body instead.
// No device context argument; scalars are passed by value.
// forcelc is the only pointer to non-const data.
309 void operator()(
310 const int nat,
311 const int npw,
312 const FPTYPE tpiba_omega,
313 const int* iat2it,
314 const int* ig2igg,
315 const FPTYPE* gcar,
316 const FPTYPE* tau,
317 const std::complex<FPTYPE>* aux,
318 const FPTYPE* vloc,
319 const int vloc_nr,
320 FPTYPE* forcelc);
321};
322
// Partial specialization of cal_force_ew_op for base_device::DEVICE_GPU.
// It is only declared when one of __CUDA, __UT_USE_CUDA, __ROCM or
// __UT_USE_ROCM is set, because it sits inside the enclosing #if block.
323template <typename FPTYPE>
324struct cal_force_ew_op<FPTYPE, base_device::DEVICE_GPU>{
// Declaration only (ends in `);`), so a definition must be supplied elsewhere.
// The generic operator earlier in this listing with the same parameter list
// has an empty inline body instead.
// No device context argument; scalars are passed by value.
// forceion is the only pointer to non-const data.
325 void operator()(
326 const int nat,
327 const int npw,
328 const int ig_gge0,
329 const int* iat2it,
330 const FPTYPE* gcar,
331 const FPTYPE* tau,
332 const FPTYPE* it_fact,
333 const std::complex<FPTYPE>* aux,
334 FPTYPE* forceion
335 );
336};
337#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
338} // namespace hamilt
339#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
Definition device.cpp:21
Definition hamilt.h:13
hamilt::cal_force_ew_op
Definition force_op.h:168
void operator()(const int nat, const int npw, const int ig_gge0, const int *iat2it, const FPTYPE *gcar, const FPTYPE *tau, const FPTYPE *it_fact, const std::complex< FPTYPE > *aux, FPTYPE *forceion)
Definition force_op.h:169
hamilt::cal_force_loc_op
Definition force_op.h:152
void operator()(const int nat, const int npw, const FPTYPE tpiba_omega, const int *iat2it, const int *ig2igg, const FPTYPE *gcar, const FPTYPE *tau, const std::complex< FPTYPE > *aux, const FPTYPE *vloc, const int vloc_nr, FPTYPE *forcelc)
Definition force_op.h:153
hamilt::cal_force_nl_op
Definition force_op.h:43
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &ntype, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int &forcenl_nc, const int &nbands, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const std::complex< FPTYPE > *deeq_nc, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
void operator()(const base_device::DEVICE_CPU *ctx, const bool &nondiagonal, const int &nbands_occ, const int &ntype, const int &spin, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int &forcenl_nc, const int &nbands, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const FPTYPE *deeq, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
Calculate the final forces for multi-device.
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &wg_nc, const int &ntype, const int &forcenl_nc, const int &nbands, const int &ik, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const std::complex< FPTYPE > *vu, const int *orbital_corr, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
kernel for DFT+U
void operator()(const base_device::DEVICE_CPU *ctx, const int &nbands_occ, const int &wg_nc, const int &ntype, const int &forcenl_nc, const int &nbands, const int &ik, const int &nkb, const int *atom_nh, const int *atom_na, const FPTYPE &tpiba, const FPTYPE *d_wg, const FPTYPE *lambda, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *force)
kernel for DeltaSpin
hamilt::cal_vkb1_nl_op
Definition force_op.h:13
void operator()(const Device *ctx, const int &nkb, const int &npwx, const int &vkb_nc, const int &nbasis, const int &ipol, const std::complex< FPTYPE > &NEG_IMAG_UNIT, const std::complex< FPTYPE > *vkb, const FPTYPE *gcar, std::complex< FPTYPE > *vkb1)
The prestep to calculate the final forces.