ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
stress_op.h
Go to the documentation of this file.
1#ifndef SRC_PW_STRESS_MULTI_DEVICE_H
2#define SRC_PW_STRESS_MULTI_DEVICE_H
4
5#include "source_psi/psi.h"
6
7#include <complex>
9
10namespace hamilt
11{
12
13template <typename FPTYPE, typename Device>
15{
37 void operator()(const Device* ctx,
38 const int& ipol,
39 const int& jpol,
40 const int& nkb,
41 const int& npw,
42 const int& npwx,
43 const int& ik,
44 const FPTYPE& tpiba,
45 const FPTYPE* gcar,
46 const FPTYPE* kvec_c,
47 std::complex<FPTYPE>* vkbi,
48 std::complex<FPTYPE>* vkbj,
49 std::complex<FPTYPE>* vkb,
50 std::complex<FPTYPE>* vkb1,
51 std::complex<FPTYPE>* vkb2,
52 std::complex<FPTYPE>* dbecp_noevc);
53};
54
55template <typename FPTYPE, typename Device>
57{
84 void operator()(const Device* ctx,
85 const bool& nondiagonal,
86 const int& ipol,
87 const int& jpol,
88 const int& nkb,
89 const int& nbands_occ,
90 const int& ntype,
91 const int& spin,
92 const int& deeq_2,
93 const int& deeq_3,
94 const int& deeq_4,
95 const int* atom_nh,
96 const int* atom_na,
97 const FPTYPE* d_wg,
98 const bool& occ,
99 const FPTYPE* d_ekb,
100 const FPTYPE* qq_nt,
101 const FPTYPE* deeq,
102 const std::complex<FPTYPE>* becp,
103 const std::complex<FPTYPE>* dbecp,
104 FPTYPE* stress);
105 // interface for nspin=4 only
106 void operator()(const Device* ctx,
107 const int& ipol,
108 const int& jpol,
109 const int& nkb,
110 const int& nbands_occ,
111 const int& ntype,
112 const int& deeq_2,
113 const int& deeq_3,
114 const int& deeq_4,
115 const int* atom_nh,
116 const int* atom_na,
117 const FPTYPE* d_wg,
118 const bool& occ,
119 const FPTYPE* d_ekb,
120 const FPTYPE* qq_nt,
121 const std::complex<FPTYPE>* deeq_nc,
122 const std::complex<FPTYPE>* becp,
123 const std::complex<FPTYPE>* dbecp,
124 FPTYPE* stress);
125 // kernel for DFT+U
126 void operator()(const base_device::DEVICE_CPU* ctx,
127 const int& nkb,
128 const int& nbands_occ,
129 const int& ntype,
130 const int& wg_nc,
131 const int& ik,
132 const int* atom_nh,
133 const int* atom_na,
134 const FPTYPE* d_wg,
135 const std::complex<FPTYPE>* vu,
136 const int* orbital_corr,
137 const std::complex<FPTYPE>* becp,
138 const std::complex<FPTYPE>* dbecp,
139 FPTYPE* stress);
140 // kernel for DeltaSpin
141 void operator()(const base_device::DEVICE_CPU* ctx,
142 const int& nkb,
143 const int& nbands_occ,
144 const int& ntype,
145 const int& wg_nc,
146 const int& ik,
147 const int* atom_nh,
148 const int* atom_na,
149 const FPTYPE* d_wg,
150 const double* lambda,
151 const std::complex<FPTYPE>* becp,
152 const std::complex<FPTYPE>* dbecp,
153 FPTYPE* stress);
154};
155
156template <typename T, typename Device>
158{
159 using Real = typename GetTypeReal<T>::type;
160 void operator()(const int& spin, const int& nrxx, const Real& w1, const T* gradwfc, Real* crosstaus);
161};
162
163// cpu version first, gpu version later
164template <typename FPTYPE, typename Device>
166{
167 void operator()(const Device* ctx,
168 const int nh,
169 const int npw,
170 const int* indexes,
171 const FPTYPE* vqs_in,
172 const FPTYPE* ylms_in,
173 const std::complex<FPTYPE>* sk_in,
174 const std::complex<FPTYPE>* pref_in,
175 std::complex<FPTYPE>* vkbs_out);
176};
177
178// cpu version first, gpu version later
179template <typename FPTYPE, typename Device>
181{
182 void operator()(const Device* ctx,
183 const int nh,
184 const int npw,
185 const int ipol,
186 const int jpol,
187 const int* indexes,
188 const FPTYPE* vqs_in,
189 const FPTYPE* vqs_deri_in,
190 const FPTYPE* ylms_in,
191 const FPTYPE* ylms_deri_in,
192 const std::complex<FPTYPE>* sk_in,
193 const std::complex<FPTYPE>* pref_in,
194 const FPTYPE* gk_in,
195 std::complex<FPTYPE>* vkbs_out);
196};
197
198// cpu version first, gpu version later
199template <typename FPTYPE, typename Device>
201{
202 void operator()(const Device* ctx,
203 const FPTYPE* tab,
204 int it,
205 const FPTYPE* gk,
206 int npw,
207 const int tab_2,
208 const int tab_3,
209 const FPTYPE table_interval,
210 const int nbeta,
211 FPTYPE* vq);
212};
213
214// cpu version first, gpu version later
215template <typename FPTYPE, typename Device>
217{
218 void operator()(const Device* ctx,
219 const FPTYPE* tab,
220 int it,
221 const FPTYPE* gk,
222 int npw,
223 const int tab_2,
224 const int tab_3,
225 const FPTYPE table_interval,
226 const int nbeta,
227 FPTYPE* vq);
228};
229
230
231template <typename FPTYPE, typename Device>
234 const FPTYPE* r, const FPTYPE* rhoc,
235 const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg,
236 const int mesh, const int igl0, const int ngg, const double omega,
237 int type
238 );
239};
240
241template <typename FPTYPE, typename Device>
243 void operator()(const std::complex<FPTYPE> *psiv,
244 const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
245 const FPTYPE* rhocgigg_vec,
246 FPTYPE* force,
247 const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz,
248 const int npw,
249 const FPTYPE omega, const FPTYPE tpiba
250 );
251};
252
253template <typename FPTYPE, typename Device>
255 FPTYPE operator()(const int& npw,
256 const FPTYPE& fac,
257 const FPTYPE* gk1,
258 const FPTYPE* gk2,
259 const FPTYPE* d_kfac,
260 const std::complex<FPTYPE>* psi);
261};
262
263
264#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
265template <typename FPTYPE>
266struct cal_dbecp_noevc_nl_op<FPTYPE, base_device::DEVICE_GPU>
267{
268 void operator()(const base_device::DEVICE_GPU* ctx,
269 const int& ipol,
270 const int& jpol,
271 const int& nkb,
272 const int& npw,
273 const int& npwx,
274 const int& ik,
275 const FPTYPE& tpiba,
276 const FPTYPE* gcar,
277 const FPTYPE* kvec_c,
278 std::complex<FPTYPE>* vkbi,
279 std::complex<FPTYPE>* vkbj,
280 std::complex<FPTYPE>* vkb,
281 std::complex<FPTYPE>* vkb1,
282 std::complex<FPTYPE>* vkb2,
283 std::complex<FPTYPE>* dbecp_noevc);
284};
285
286template <typename FPTYPE>
287struct cal_stress_nl_op<FPTYPE, base_device::DEVICE_GPU>
288{
289 void operator()(const base_device::DEVICE_GPU* ctx,
290 const bool& nondiagonal,
291 const int& ipol,
292 const int& jpol,
293 const int& nkb,
294 const int& nbands_occ,
295 const int& ntype,
296 const int& spin,
297 const int& deeq_2,
298 const int& deeq_3,
299 const int& deeq_4,
300 const int* atom_nh,
301 const int* atom_na,
302 const FPTYPE* d_wg,
303 const bool& occ,
304 const FPTYPE* d_ekb,
305 const FPTYPE* qq_nt,
306 const FPTYPE* deeq,
307 const std::complex<FPTYPE>* becp,
308 const std::complex<FPTYPE>* dbecp,
309 FPTYPE* stress);
310 // interface for nspin=4 only
311 void operator()(const base_device::DEVICE_GPU* ctx,
312 const int& ipol,
313 const int& jpol,
314 const int& nkb,
315 const int& nbands_occ,
316 const int& ntype,
317 const int& deeq_2,
318 const int& deeq_3,
319 const int& deeq_4,
320 const int* atom_nh,
321 const int* atom_na,
322 const FPTYPE* d_wg,
323 const bool& occ,
324 const FPTYPE* d_ekb,
325 const FPTYPE* qq_nt,
326 const std::complex<FPTYPE>* deeq_nc,
327 const std::complex<FPTYPE>* becp,
328 const std::complex<FPTYPE>* dbecp,
329 FPTYPE* stress);
330 // kernel for DFT+U
331 void operator()(const base_device::DEVICE_GPU* ctx,
332 const int& nkb,
333 const int& nbands_occ,
334 const int& ntype,
335 const int& wg_nc,
336 const int& ik,
337 const int* atom_nh,
338 const int* atom_na,
339 const FPTYPE* d_wg,
340 const std::complex<FPTYPE>* vu,
341 const int* orbital_corr,
342 const std::complex<FPTYPE>* becp,
343 const std::complex<FPTYPE>* dbecp,
344 FPTYPE* stress);
345 // kernel for DeltaSpin
346 void operator()(const base_device::DEVICE_GPU* ctx,
347 const int& nkb,
348 const int& nbands_occ,
349 const int& ntype,
350 const int& wg_nc,
351 const int& ik,
352 const int* atom_nh,
353 const int* atom_na,
354 const FPTYPE* d_wg,
355 const double* lambda,
356 const std::complex<FPTYPE>* becp,
357 const std::complex<FPTYPE>* dbecp,
358 FPTYPE* stress);
359};
360
361// cpu version first, gpu version later
362template <typename FPTYPE>
363struct cal_vkb_op<FPTYPE, base_device::DEVICE_GPU>
364{
365 void operator()(const base_device::DEVICE_GPU* ctx,
366 const int nh,
367 const int npw,
368 const int* indexes,
369 const FPTYPE* vqs_in,
370 const FPTYPE* ylms_in,
371 const std::complex<FPTYPE>* sk_in,
372 const std::complex<FPTYPE>* pref_in,
373 std::complex<FPTYPE>* vkbs_out);
374};
375
376template <typename FPTYPE>
377struct cal_vkb_deri_op<FPTYPE, base_device::DEVICE_GPU>
378{
379 void operator()(const base_device::DEVICE_GPU* ctx,
380 const int nh,
381 const int npw,
382 const int ipol,
383 const int jpol,
384 const int* indexes,
385 const FPTYPE* vqs_in,
386 const FPTYPE* vqs_deri_in,
387 const FPTYPE* ylms_in,
388 const FPTYPE* ylms_deri_in,
389 const std::complex<FPTYPE>* sk_in,
390 const std::complex<FPTYPE>* pref_in,
391 const FPTYPE* gk_in,
392 std::complex<FPTYPE>* vkbs_out);
393};
394
395// cpu version first, gpu version later
396template <typename FPTYPE>
397struct cal_vq_op<FPTYPE, base_device::DEVICE_GPU>
398{
399 void operator()(const base_device::DEVICE_GPU* ctx,
400 const FPTYPE* tab,
401 int it,
402 const FPTYPE* gk,
403 int npw,
404 const int tab_2,
405 const int tab_3,
406 const FPTYPE table_interval,
407 const int nbeta,
408 FPTYPE* vq);
409};
410
411// cpu version first, gpu version later
412template <typename FPTYPE>
413struct cal_vq_deri_op<FPTYPE, base_device::DEVICE_GPU>
414{
415 void operator()(const base_device::DEVICE_GPU* ctx,
416 const FPTYPE* tab,
417 int it,
418 const FPTYPE* gk,
419 int npw,
420 const int tab_2,
421 const int tab_3,
422 const FPTYPE table_interval,
423 const int nbeta,
424 FPTYPE* vq);
425};
426
427template <typename FPTYPE>
428struct cal_multi_dot_op<FPTYPE, base_device::DEVICE_GPU>{
429 FPTYPE operator()(const int& npw,
430 const FPTYPE& fac,
431 const FPTYPE* gk1,
432 const FPTYPE* gk2,
433 const FPTYPE* d_kfac,
434 const std::complex<FPTYPE>* psi);
435};
436
460template <typename FPTYPE>
461struct cal_stress_drhoc_aux_op<FPTYPE, base_device::DEVICE_GPU>{
462 void operator()(
463 const FPTYPE* r, const FPTYPE* rhoc,
464 const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg,
465 const int mesh, const int igl0, const int ngg, const double omega,
466 int type
467 );
468};
469
470
480template <typename FPTYPE>
481struct cal_force_npw_op<FPTYPE, base_device::DEVICE_GPU>{
482 void operator()(const std::complex<FPTYPE> *psiv,
483 const FPTYPE* gv_x, const FPTYPE* gv_y, const FPTYPE* gv_z,
484 const FPTYPE* rhocgigg_vec,
485 FPTYPE* force,
486 const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz,
487 const int npw,
488 const FPTYPE omega, const FPTYPE tpiba
489 );
490};
491
492
493
494#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
495
496template <typename Device>
498{
499 void operator()(void** ptr, const int n);
500};
501
502template <typename Device>
504{
505 void operator()(void** ptr_out, const void** ptr_in, const int size);
506};
507
508} // namespace hamilt
509#endif // SRC_PW_STRESS_MULTI_DEVICE_H
#define T
Definition exp.cpp:237
Definition device.cpp:21
Definition hamilt.h:12
Definition exx_lip.h:23
T type
Definition macros.h:8
Definition stress_op.h:15
void operator()(const Device *ctx, const int &ipol, const int &jpol, const int &nkb, const int &npw, const int &npwx, const int &ik, const FPTYPE &tpiba, const FPTYPE *gcar, const FPTYPE *kvec_c, std::complex< FPTYPE > *vkbi, std::complex< FPTYPE > *vkbj, std::complex< FPTYPE > *vkb, std::complex< FPTYPE > *vkb1, std::complex< FPTYPE > *vkb2, std::complex< FPTYPE > *dbecp_noevc)
The prestep to calculate the final stresses.
Definition stress_op.h:242
void operator()(const std::complex< FPTYPE > *psiv, const FPTYPE *gv_x, const FPTYPE *gv_y, const FPTYPE *gv_z, const FPTYPE *rhocgigg_vec, FPTYPE *force, const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz, const int npw, const FPTYPE omega, const FPTYPE tpiba)
Definition stress_op.h:254
FPTYPE operator()(const int &npw, const FPTYPE &fac, const FPTYPE *gk1, const FPTYPE *gk2, const FPTYPE *d_kfac, const std::complex< FPTYPE > *psi)
Definition stress_op.h:232
void operator()(const FPTYPE *r, const FPTYPE *rhoc, const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg, const int mesh, const int igl0, const int ngg, const double omega, int type)
Definition stress_op.h:158
void operator()(const int &spin, const int &nrxx, const Real &w1, const T *gradwfc, Real *crosstaus)
Definition stress_op.cpp:363
typename GetTypeReal< T >::type Real
Definition stress_op.h:159
Definition stress_op.h:57
void operator()(const Device *ctx, const bool &nondiagonal, const int &ipol, const int &jpol, const int &nkb, const int &nbands_occ, const int &ntype, const int &spin, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const FPTYPE *deeq, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
Calculate the final stresses for multi-device.
void operator()(const Device *ctx, const int &ipol, const int &jpol, const int &nkb, const int &nbands_occ, const int &ntype, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const std::complex< FPTYPE > *deeq_nc, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
void operator()(const base_device::DEVICE_CPU *ctx, const int &nkb, const int &nbands_occ, const int &ntype, const int &wg_nc, const int &ik, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const double *lambda, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
void operator()(const base_device::DEVICE_CPU *ctx, const int &nkb, const int &nbands_occ, const int &ntype, const int &wg_nc, const int &ik, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const std::complex< FPTYPE > *vu, const int *orbital_corr, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
Definition stress_op.h:181
void operator()(const Device *ctx, const int nh, const int npw, const int ipol, const int jpol, const int *indexes, const FPTYPE *vqs_in, const FPTYPE *vqs_deri_in, const FPTYPE *ylms_in, const FPTYPE *ylms_deri_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, const FPTYPE *gk_in, std::complex< FPTYPE > *vkbs_out)
Definition stress_op.h:166
void operator()(const Device *ctx, const int nh, const int npw, const int *indexes, const FPTYPE *vqs_in, const FPTYPE *ylms_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, std::complex< FPTYPE > *vkbs_out)
Definition stress_op.h:217
void operator()(const Device *ctx, const FPTYPE *tab, int it, const FPTYPE *gk, int npw, const int tab_2, const int tab_3, const FPTYPE table_interval, const int nbeta, FPTYPE *vq)
Definition stress_op.h:201
void operator()(const Device *ctx, const FPTYPE *tab, int it, const FPTYPE *gk, int npw, const int tab_2, const int tab_3, const FPTYPE table_interval, const int nbeta, FPTYPE *vq)
Definition stress_op.h:498
void operator()(void **ptr, const int n)
Definition stress_op.h:504
void operator()(void **ptr_out, const void **ptr_in, const int size)