1#ifndef SRC_PW_STRESS_MULTI_DEVICE_H
2#define SRC_PW_STRESS_MULTI_DEVICE_H
13template <
typename FPTYPE,
typename Device>
47 std::complex<FPTYPE>* vkbi,
48 std::complex<FPTYPE>* vkbj,
49 std::complex<FPTYPE>* vkb,
50 std::complex<FPTYPE>* vkb1,
51 std::complex<FPTYPE>* vkb2,
52 std::complex<FPTYPE>* dbecp_noevc);
55template <
typename FPTYPE,
typename Device>
85 const bool& nondiagonal,
89 const int& nbands_occ,
102 const std::complex<FPTYPE>* becp,
103 const std::complex<FPTYPE>* dbecp,
110 const int& nbands_occ,
121 const std::complex<FPTYPE>* deeq_nc,
122 const std::complex<FPTYPE>* becp,
123 const std::complex<FPTYPE>* dbecp,
128 const int& nbands_occ,
135 const std::complex<FPTYPE>* vu,
136 const int* orbital_corr,
137 const std::complex<FPTYPE>* becp,
138 const std::complex<FPTYPE>* dbecp,
143 const int& nbands_occ,
150 const double* lambda,
151 const std::complex<FPTYPE>* becp,
152 const std::complex<FPTYPE>* dbecp,
156template <
typename T,
typename Device>
160 void operator()(
const int& spin,
const int& nrxx,
const Real& w1,
const T* gradwfc,
Real* crosstaus);
164template <
typename FPTYPE,
typename Device>
171 const FPTYPE* vqs_in,
172 const FPTYPE* ylms_in,
173 const std::complex<FPTYPE>* sk_in,
174 const std::complex<FPTYPE>* pref_in,
175 std::complex<FPTYPE>* vkbs_out);
179template <
typename FPTYPE,
typename Device>
188 const FPTYPE* vqs_in,
189 const FPTYPE* vqs_deri_in,
190 const FPTYPE* ylms_in,
191 const FPTYPE* ylms_deri_in,
192 const std::complex<FPTYPE>* sk_in,
193 const std::complex<FPTYPE>* pref_in,
195 std::complex<FPTYPE>* vkbs_out);
199template <
typename FPTYPE,
typename Device>
209 const FPTYPE table_interval,
215template <
typename FPTYPE,
typename Device>
225 const FPTYPE table_interval,
231template <
typename FPTYPE,
typename Device>
234 const FPTYPE* r,
const FPTYPE* rhoc,
235 const FPTYPE *gx_arr,
const FPTYPE *rab, FPTYPE *drhocg,
236 const int mesh,
const int igl0,
const int ngg,
const double omega,
241template <
typename FPTYPE,
typename Device>
244 const FPTYPE* gv_x,
const FPTYPE* gv_y,
const FPTYPE* gv_z,
245 const FPTYPE* rhocgigg_vec,
247 const FPTYPE pos_x,
const FPTYPE pos_y,
const FPTYPE pos_xz,
249 const FPTYPE omega,
const FPTYPE tpiba
253template <
typename FPTYPE,
typename Device>
259 const FPTYPE* d_kfac,
260 const std::complex<FPTYPE>*
psi);
264#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
265template <
typename FPTYPE>
268 void operator()(
const base_device::DEVICE_GPU* ctx,
277 const FPTYPE* kvec_c,
278 std::complex<FPTYPE>* vkbi,
279 std::complex<FPTYPE>* vkbj,
280 std::complex<FPTYPE>* vkb,
281 std::complex<FPTYPE>* vkb1,
282 std::complex<FPTYPE>* vkb2,
283 std::complex<FPTYPE>* dbecp_noevc);
286template <
typename FPTYPE>
287struct cal_stress_nl_op<FPTYPE,
base_device::DEVICE_GPU>
289 void operator()(
const base_device::DEVICE_GPU* ctx,
290 const bool& nondiagonal,
294 const int& nbands_occ,
307 const std::complex<FPTYPE>* becp,
308 const std::complex<FPTYPE>* dbecp,
311 void operator()(
const base_device::DEVICE_GPU* ctx,
315 const int& nbands_occ,
326 const std::complex<FPTYPE>* deeq_nc,
327 const std::complex<FPTYPE>* becp,
328 const std::complex<FPTYPE>* dbecp,
331 void operator()(
const base_device::DEVICE_GPU* ctx,
333 const int& nbands_occ,
340 const std::complex<FPTYPE>* vu,
341 const int* orbital_corr,
342 const std::complex<FPTYPE>* becp,
343 const std::complex<FPTYPE>* dbecp,
346 void operator()(
const base_device::DEVICE_GPU* ctx,
348 const int& nbands_occ,
355 const double* lambda,
356 const std::complex<FPTYPE>* becp,
357 const std::complex<FPTYPE>* dbecp,
362template <
typename FPTYPE>
365 void operator()(
const base_device::DEVICE_GPU* ctx,
369 const FPTYPE* vqs_in,
370 const FPTYPE* ylms_in,
371 const std::complex<FPTYPE>* sk_in,
372 const std::complex<FPTYPE>* pref_in,
373 std::complex<FPTYPE>* vkbs_out);
376template <
typename FPTYPE>
377struct cal_vkb_deri_op<FPTYPE,
base_device::DEVICE_GPU>
379 void operator()(
const base_device::DEVICE_GPU* ctx,
385 const FPTYPE* vqs_in,
386 const FPTYPE* vqs_deri_in,
387 const FPTYPE* ylms_in,
388 const FPTYPE* ylms_deri_in,
389 const std::complex<FPTYPE>* sk_in,
390 const std::complex<FPTYPE>* pref_in,
392 std::complex<FPTYPE>* vkbs_out);
396template <
typename FPTYPE>
399 void operator()(
const base_device::DEVICE_GPU* ctx,
406 const FPTYPE table_interval,
412template <
typename FPTYPE>
413struct cal_vq_deri_op<FPTYPE,
base_device::DEVICE_GPU>
415 void operator()(
const base_device::DEVICE_GPU* ctx,
422 const FPTYPE table_interval,
427template <
typename FPTYPE>
428struct cal_multi_dot_op<FPTYPE,
base_device::DEVICE_GPU>{
433 const FPTYPE* d_kfac,
434 const std::complex<FPTYPE>*
psi);
460template <
typename FPTYPE>
461struct cal_stress_drhoc_aux_op<FPTYPE,
base_device::DEVICE_GPU>{
463 const FPTYPE* r,
const FPTYPE* rhoc,
464 const FPTYPE *gx_arr,
const FPTYPE *rab, FPTYPE *drhocg,
465 const int mesh,
const int igl0,
const int ngg,
const double omega,
480template <
typename FPTYPE>
481struct cal_force_npw_op<FPTYPE,
base_device::DEVICE_GPU>{
482 void operator()(
const std::complex<FPTYPE> *psiv,
483 const FPTYPE* gv_x,
const FPTYPE* gv_y,
const FPTYPE* gv_z,
484 const FPTYPE* rhocgigg_vec,
486 const FPTYPE pos_x,
const FPTYPE pos_y,
const FPTYPE pos_xz,
488 const FPTYPE omega,
const FPTYPE tpiba
496template <
typename Device>
502template <
typename Device>
505 void operator()(
void** ptr_out,
const void** ptr_in,
const int size);
#define T
Definition exp.cpp:237
T type
Definition macros.h:8
Definition stress_op.h:15
void operator()(const Device *ctx, const int &ipol, const int &jpol, const int &nkb, const int &npw, const int &npwx, const int &ik, const FPTYPE &tpiba, const FPTYPE *gcar, const FPTYPE *kvec_c, std::complex< FPTYPE > *vkbi, std::complex< FPTYPE > *vkbj, std::complex< FPTYPE > *vkb, std::complex< FPTYPE > *vkb1, std::complex< FPTYPE > *vkb2, std::complex< FPTYPE > *dbecp_noevc)
The prestep to calculate the final stresses.
Definition stress_op.h:242
void operator()(const std::complex< FPTYPE > *psiv, const FPTYPE *gv_x, const FPTYPE *gv_y, const FPTYPE *gv_z, const FPTYPE *rhocgigg_vec, FPTYPE *force, const FPTYPE pos_x, const FPTYPE pos_y, const FPTYPE pos_xz, const int npw, const FPTYPE omega, const FPTYPE tpiba)
Definition stress_op.h:254
FPTYPE operator()(const int &npw, const FPTYPE &fac, const FPTYPE *gk1, const FPTYPE *gk2, const FPTYPE *d_kfac, const std::complex< FPTYPE > *psi)
Definition stress_op.h:232
void operator()(const FPTYPE *r, const FPTYPE *rhoc, const FPTYPE *gx_arr, const FPTYPE *rab, FPTYPE *drhocg, const int mesh, const int igl0, const int ngg, const double omega, int type)
Definition stress_op.h:158
void operator()(const int &spin, const int &nrxx, const Real &w1, const T *gradwfc, Real *crosstaus)
Definition stress_op.cpp:363
typename GetTypeReal< T >::type Real
Definition stress_op.h:159
Definition stress_op.h:57
void operator()(const Device *ctx, const bool &nondiagonal, const int &ipol, const int &jpol, const int &nkb, const int &nbands_occ, const int &ntype, const int &spin, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const FPTYPE *deeq, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
Calculate the final stresses for multi-device.
void operator()(const Device *ctx, const int &ipol, const int &jpol, const int &nkb, const int &nbands_occ, const int &ntype, const int &deeq_2, const int &deeq_3, const int &deeq_4, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const bool &occ, const FPTYPE *d_ekb, const FPTYPE *qq_nt, const std::complex< FPTYPE > *deeq_nc, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
void operator()(const base_device::DEVICE_CPU *ctx, const int &nkb, const int &nbands_occ, const int &ntype, const int &wg_nc, const int &ik, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const double *lambda, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
void operator()(const base_device::DEVICE_CPU *ctx, const int &nkb, const int &nbands_occ, const int &ntype, const int &wg_nc, const int &ik, const int *atom_nh, const int *atom_na, const FPTYPE *d_wg, const std::complex< FPTYPE > *vu, const int *orbital_corr, const std::complex< FPTYPE > *becp, const std::complex< FPTYPE > *dbecp, FPTYPE *stress)
Definition stress_op.h:181
void operator()(const Device *ctx, const int nh, const int npw, const int ipol, const int jpol, const int *indexes, const FPTYPE *vqs_in, const FPTYPE *vqs_deri_in, const FPTYPE *ylms_in, const FPTYPE *ylms_deri_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, const FPTYPE *gk_in, std::complex< FPTYPE > *vkbs_out)
Definition stress_op.h:166
void operator()(const Device *ctx, const int nh, const int npw, const int *indexes, const FPTYPE *vqs_in, const FPTYPE *ylms_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, std::complex< FPTYPE > *vkbs_out)
Definition stress_op.h:217
void operator()(const Device *ctx, const FPTYPE *tab, int it, const FPTYPE *gk, int npw, const int tab_2, const int tab_3, const FPTYPE table_interval, const int nbeta, FPTYPE *vq)
Definition stress_op.h:201
void operator()(const Device *ctx, const FPTYPE *tab, int it, const FPTYPE *gk, int npw, const int tab_2, const int tab_3, const FPTYPE table_interval, const int nbeta, FPTYPE *vq)
Definition stress_op.h:498
void operator()(void **ptr, const int n)
Definition stress_op.h:504
void operator()(void **ptr_out, const void **ptr_in, const int size)