ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
vnl_op.h
Go to the documentation of this file.
1#ifndef W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_VNL_OP_H
2#define W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_VNL_OP_H
3
4#include "source_psi/psi.h"
5
6#include <complex>
7
8namespace hamilt
9{
10
11template <typename FPTYPE, typename Device>
13{
40 void operator()(const Device* ctx,
41 const int& ntype,
42 const int& npw,
43 const int& npwx,
44 const int& nhm,
45 const int& tab_2,
46 const int& tab_3,
47 const int* atom_na,
48 const int* atom_nb,
49 const int* atom_nh,
50 const FPTYPE& DQ,
51 const FPTYPE& tpiba,
52 const std::complex<FPTYPE>& NEG_IMAG_UNIT,
53 const FPTYPE* gk,
54 const FPTYPE* ylm,
55 const FPTYPE* indv,
56 const FPTYPE* nhtol,
57 const FPTYPE* nhtolm,
58 const FPTYPE* tab,
59 FPTYPE* vkb1,
60 const std::complex<FPTYPE>* sk,
61 std::complex<FPTYPE>* vkb_in);
62};
63
64#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
65template <typename FPTYPE>
66struct cal_vnl_op<FPTYPE, base_device::DEVICE_GPU>
67{
68 void operator()(const base_device::DEVICE_GPU* ctx,
69 const int& ntype,
70 const int& npw,
71 const int& npwx,
72 const int& nhm,
73 const int& tab_2,
74 const int& tab_3,
75 const int* atom_na,
76 const int* atom_nb,
77 const int* atom_nh,
78 const FPTYPE& DQ,
79 const FPTYPE& tpiba,
80 const std::complex<FPTYPE>& NEG_IMAG_UNIT,
81 const FPTYPE* gk,
82 const FPTYPE* ylm,
83 const FPTYPE* indv,
84 const FPTYPE* nhtol,
85 const FPTYPE* nhtolm,
86 const FPTYPE* tab,
87 FPTYPE* vkb1,
88 const std::complex<FPTYPE>* sk,
89 std::complex<FPTYPE>* vkb_in);
90};
91#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
92} // namespace hamilt
93#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_VNL_OP_H
Definition device.cpp:21
Definition hamilt.h:12
Definition vnl_op.h:13
void operator()(const Device *ctx, const int &ntype, const int &npw, const int &npwx, const int &nhm, const int &tab_2, const int &tab_3, const int *atom_na, const int *atom_nb, const int *atom_nh, const FPTYPE &DQ, const FPTYPE &tpiba, const std::complex< FPTYPE > &NEG_IMAG_UNIT, const FPTYPE *gk, const FPTYPE *ylm, const FPTYPE *indv, const FPTYPE *nhtol, const FPTYPE *nhtolm, const FPTYPE *tab, FPTYPE *vkb1, const std::complex< FPTYPE > *sk, std::complex< FPTYPE > *vkb_in)
Calculate getvnl on multiple device types.