1#ifndef HAMILTPW_NONLOCAL_MATHS_H
2#define HAMILTPW_NONLOCAL_MATHS_H
15template <
typename FPTYPE,
typename Device>
21 this->
device = base_device::get_device_type<Device>(this->
ctx);
28 this->
device = base_device::get_device_type<Device>(this->
ctx);
65 void cal_ylm(
int lmax,
int npw,
const FPTYPE* gk_in, FPTYPE* ylm);
67 void cal_ylm_deri(
int lmax,
int npw,
const FPTYPE* gk_in, FPTYPE* ylm_deri);
69 std::vector <std::complex<FPTYPE>>
cal_pref(
int it,
const int nh);
77 const std::complex<FPTYPE>* sk_in,
78 const std::complex<FPTYPE>* pref_in,
79 std::complex<FPTYPE>* vkb_out);
87 const FPTYPE* vq_deri_in,
89 const FPTYPE* ylm_deri_in,
90 const std::complex<FPTYPE>* sk_in,
91 const std::complex<FPTYPE>* pref_in,
93 std::complex<FPTYPE>* vkb_out);
101 std::complex<FPTYPE>* vkb_out,
102 std::complex<FPTYPE>** vkb_ptrs,
119 static void dylmr2(
const int nylm,
const int ngy,
const FPTYPE* gk, FPTYPE* dylm,
const int ipol);
124 const FPTYPE& table_interval,
130template <
typename FPTYPE,
typename Device>
133 int npw = pw_basis->
npwk[ik];
134 std::vector<FPTYPE> gk(npw * 5);
136 for (
int ig = 0; ig < npw; ++ig)
141 gk[ig * 3 + 1] = q.
y;
142 gk[ig * 3 + 2] = q.
z;
146 gk[3 * npw + ig] =
norm * this->ucell_->tpiba;
147 gk[4 * npw + ig] =
norm < 1e-8 ? 0.0 : 1.0 /
norm * this->ucell_->tpiba;
154template <
typename FPTYPE,
typename Device>
157 const int ntot_ylm = (lmax + 1) * (lmax + 1);
163 std::vector<FPTYPE> ylm_cpu(ntot_ylm * npw);
167 syncmem_var_h2d_op()(ylm, ylm_cpu.data(), ylm_cpu.size());
179template <
typename FPTYPE,
typename Device>
182 const int ntot_ylm = (lmax + 1) * (lmax + 1);
189 std::vector<FPTYPE> dylmdq_cpu(3 * ntot_ylm * npw);
191 for (
int ipol = 0; ipol < 3; ipol++)
196 syncmem_var_h2d_op()(out, dylmdq_cpu.data(), dylmdq_cpu.size());
200 for (
int ipol = 0; ipol < 3; ipol++)
209template <
typename FPTYPE,
typename Device>
215 std::vector<std::complex<FPTYPE>> pref(nh);
216 for (
int ih = 0; ih < nh; ih++)
218 pref[ih] = std::pow(std::complex<FPTYPE>(0.0, -1.0), this->nhtol_(it, ih));
226template <
typename FPTYPE,
typename Device>
231 const FPTYPE* ylm_in,
232 const std::complex<FPTYPE>* sk_in,
233 const std::complex<FPTYPE>* pref_in,
234 std::complex<FPTYPE>* vkb_out)
238 for (
int ib = 0; ib < this->ucell_->atoms[it].ncpp.nbeta; ib++)
240 int l = this->nhtol_(it, ih);
242 for (
int m = 0; m < 2 * l + 1; m++)
245 std::complex<FPTYPE>* vkb_ptr = &vkb_out[ih * npw];
246 const FPTYPE* ylm_ptr = &ylm_in[lm * npw];
247 const FPTYPE* vq_ptr = &vq_in[ib * npw];
249 for (
int ig = 0; ig < npw; ig++)
251 vkb_ptr[ig] = ylm_ptr[ig] * vq_ptr[ig] * sk_in[ig] * pref_in[ih];
260template <
typename FPTYPE,
typename Device>
267 const FPTYPE* vq_deri_in,
268 const FPTYPE* ylm_in,
269 const FPTYPE* ylm_deri_in,
270 const std::complex<FPTYPE>* sk_in,
271 const std::complex<FPTYPE>* pref_in,
273 std::complex<FPTYPE>* vkb_out)
275 const int x1 = (this->lmax_ + 1) * (this->lmax_ + 1);
278 for (
int nb = 0; nb < this->ucell_->atoms[it].ncpp.nbeta; nb++)
280 const int l = this->nhtol_(it, ih);
282 for (
int m = 0; m < 2 * l + 1; m++)
284 const int lm = l * l + m;
285 std::complex<FPTYPE>* vkb_ptr = &vkb_out[ih * npw];
286 const FPTYPE* ylm_ptr = &ylm_in[lm * npw];
287 const FPTYPE* vq_ptr = &vq_in[nb * npw];
289 for (
int ig = 0; ig < npw; ig++)
291 vkb_ptr[ig] = std::complex<FPTYPE>(0.0, 0.0);
297 for (
int ig = 0; ig < npw; ig++)
299 vkb_ptr[ig] -= ylm_ptr[ig] * vq_ptr[ig] * sk_in[ig] * pref_in[ih];
304 const FPTYPE* ylm_deri_ptr1 = &ylm_deri_in[(ipol * x1 + lm) * npw];
305 const FPTYPE* ylm_deri_ptr2 = &ylm_deri_in[(jpol * x1 + lm) * npw];
306 const FPTYPE* vq_deri_ptr = &vq_deri_in[nb * npw];
307 const FPTYPE* qnorm = &gk_in[4 * npw];
308 for (
int ig = 0; ig < npw; ig++)
310 vkb_ptr[ig] -= (gk_in[ig * 3 + ipol] * ylm_deri_ptr2[ig] + gk_in[ig * 3 + jpol] * ylm_deri_ptr1[ig])
311 * vq_ptr[ig] * sk_in[ig] * pref_in[ih];
315 for (
int ig = 0; ig < npw; ig++)
317 vkb_ptr[ig] -= 2.0 * ylm_ptr[ig] * vq_deri_ptr[ig] * sk_in[ig] * pref_in[ih] * gk_in[ig * 3 + ipol]
318 * gk_in[ig * 3 + jpol] * qnorm[ig];
325template <
typename FPTYPE,
typename Device>
331 std::complex<FPTYPE>* vkb_out,
332 std::complex<FPTYPE>** vkb_ptrs,
342 for (
int nb = 0; nb < nbeta; nb++)
344 int l = nhtol[it * nhtol_nc + ih];
345 for (
int m = 0; m < 2 * l + 1; m++)
348 vkb_ptrs[ih] = &vkb_out[ih * npw];
349 ylm_ptrs[ih] = &ylm_in[lm * npw];
350 vq_ptrs[ih] = &vq_in[nb * npw];
356template <
typename FPTYPE,
typename Device>
367 const int x1 = (this->lmax_ + 1) * (this->lmax_ + 1);
368 for (
int nb = 0; nb < nbeta; nb++)
370 int l = nhtol[it * nhtol_nc + ih];
371 for (
int m = 0; m < 2 * l + 1; m++)
375 indexes[ih * 4] = lm;
376 indexes[ih * 4 + 1] = nb;
377 indexes[ih * 4 + 2] = (ipol * x1 + lm);
378 indexes[ih * 4 + 3] = (jpol * x1 + lm);
385template <
typename FPTYPE,
typename Device>
407 const FPTYPE delta = 1e-6;
408 const FPTYPE small = 1e-15;
416 std::vector<FPTYPE> gx(ngy * 3);
418 std::vector<FPTYPE> dg(ngy);
419 std::vector<FPTYPE> dgi(ngy);
427#pragma omp parallel for
429 for (
int ig = 0; ig < 3 * ngy; ig++)
435#pragma omp parallel for
437 for (
int ig = 0; ig < ngy; ig++)
439 const int igx = ig * 3, igy = ig * 3 + 1, igz = ig * 3 + 2;
440 FPTYPE norm2 = gx[igx] * gx[igx] + gx[igy] * gx[igy] + gx[igz] * gx[igz];
441 dg[ig] = delta * sqrt(norm2);
444 dgi[ig] = 1.0 / dg[ig];
455#pragma omp parallel for
457 for (
int ig = 0; ig < ngy; ig++)
459 const int index = ig * 3 + ipol;
460 gx[index] = gk[index] + dg[ig];
464 base_device::DEVICE_CPU* cpu = {};
468#pragma omp parallel for
470 for (
int ig = 0; ig < ngy; ig++)
472 const int index = ig * 3 + ipol;
473 gx[index] = gk[index] - dg[ig];
481#pragma omp parallel for collapse(2)
483 for (
int lm = 0; lm < nylm; lm++)
485 for (
int ig = 0; ig < ngy; ig++)
487 dylm[lm * ngy + ig] -= ylmaux(lm, ig);
488 dylm[lm * ngy + ig] *= 0.5 * dgi[ig];
495template <
typename FPTYPE,
typename Device>
499 const FPTYPE& table_interval,
504 assert(table_interval > 0.0);
505 const FPTYPE position = x / table_interval;
506 const int iq =
static_cast<int>(position);
508 const FPTYPE x0 = position -
static_cast<FPTYPE
>(iq);
509 const FPTYPE x1 = 1.0 - x0;
510 const FPTYPE x2 = 2.0 - x0;
511 const FPTYPE x3 = 3.0 - x0;
512 const FPTYPE y = (table(dim1, dim2, iq) * (-x2 * x3 - x1 * x3 - x1 * x2) / 6.0
513 + table(dim1, dim2, iq + 1) * (+x2 * x3 - x0 * x3 - x0 * x2) / 2.0
514 - table(dim1, dim2, iq + 2) * (+x1 * x3 - x0 * x3 - x0 * x1) / 2.0
515 + table(dim1, dim2, iq + 3) * (+x1 * x2 - x0 * x2 - x0 * x1) / 6.0)
3 elements vector
Definition vector3.h:22
T norm2(void) const
Get the square of the norm of a Vector3.
Definition vector3.h:177
T x
Definition vector3.h:24
T y
Definition vector3.h:25
T z
Definition vector3.h:26
static void Ylm_Real(const int lmax2, const int ng, const ModuleBase::Vector3< double > *g, matrix &ylm)
spherical harmonic function (real form) for an array of vectors
Definition math_ylmreal.cpp:357
void zero_out(void)
Definition matrix.cpp:281
double * c
Definition matrix.h:25
void create(const int nrow, const int ncol, const bool flag_zero=true)
Definition matrix.cpp:122
double float array
Definition realarray.h:21
Special pw_basis class. It includes different k-points.
Definition pw_basis_k.h:57
int * npwk
Definition pw_basis_k.h:78
ModuleBase::Vector3< double > getgpluskcar(const int ik, const int igl) const
Definition pw_basis_k.cpp:390
Definition nonlocal_maths.hpp:17
static void dylmr2(const int nylm, const int ngy, const FPTYPE *gk, FPTYPE *dylm, const int ipol)
Definition nonlocal_maths.hpp:386
void cal_ylm_deri(int lmax, int npw, const FPTYPE *gk_in, FPTYPE *ylm_deri)
calculate the derivative of the spherical Bessel function for projections
Definition nonlocal_maths.hpp:180
base_device::DEVICE_CPU * cpu_ctx
Definition nonlocal_maths.hpp:40
void cal_vkb(int it, int ia, int npw, const FPTYPE *vq_in, const FPTYPE *ylm_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, std::complex< FPTYPE > *vkb_out)
Definition nonlocal_maths.hpp:227
Nonlocal_maths(const ModuleBase::matrix &nhtol, const int lmax, const UnitCell *ucell_in)
Definition nonlocal_maths.hpp:26
static FPTYPE Polynomial_Interpolation_nl(const ModuleBase::realArray &table, const int &dim1, const int &dim2, const FPTYPE &table_interval, const FPTYPE &x)
polynomial interpolation tool for calculating the derivative of vq
Definition nonlocal_maths.hpp:496
int lmax_
Definition nonlocal_maths.hpp:36
base_device::AbacusDevice_t device
Definition nonlocal_maths.hpp:41
void prepare_vkb_ptr(int nbeta, double *nhtol, int nhtol_nc, int npw, int it, std::complex< FPTYPE > *vkb_out, std::complex< FPTYPE > **vkb_ptrs, FPTYPE *ylm_in, FPTYPE **ylm_ptrs, FPTYPE *vq_in, FPTYPE **vq_ptrs)
calculate the pointers used in vkb_op
Definition nonlocal_maths.hpp:326
const UnitCell * ucell_
Definition nonlocal_maths.hpp:37
void cal_dvkb_index(const int nbeta, const double *nhtol, const int nhtol_nc, const int npw, const int it, const int ipol, const int jpol, int *indexes)
Definition nonlocal_maths.hpp:357
std::vector< std::complex< FPTYPE > > cal_pref(int it, const int nh)
calculate the (-i)^l factors
Definition nonlocal_maths.hpp:210
Device * ctx
Definition nonlocal_maths.hpp:39
Nonlocal_maths(const pseudopot_cell_vnl *nlpp_in, const UnitCell *ucell_in)
Definition nonlocal_maths.hpp:19
void cal_ylm(int lmax, int npw, const FPTYPE *gk_in, FPTYPE *ylm)
calculate the real spherical harmonic functions on cpu (and optionally send to gpu,...
Definition nonlocal_maths.hpp:155
ModuleBase::matrix nhtol_
Definition nonlocal_maths.hpp:35
std::vector< FPTYPE > cal_gk(int ik, const ModulePW::PW_Basis_K *pw_basis)
this function prepares all the q (G+k) information in one contiguous memory block including the x,...
Definition nonlocal_maths.hpp:131
void cal_vkb_deri(int it, int ia, int npw, int ipol, int jpol, const FPTYPE *vq_in, const FPTYPE *vq_deri_in, const FPTYPE *ylm_in, const FPTYPE *ylm_deri_in, const std::complex< FPTYPE > *sk_in, const std::complex< FPTYPE > *pref_in, const FPTYPE *gk_in, std::complex< FPTYPE > *vkb_out)
calculate the dvkb matrix for this atom
Definition nonlocal_maths.hpp:261
Definition VNL_in_pw.h:21
int lmaxkb
Definition VNL_in_pw.h:35
ModuleBase::matrix nhtol
Definition VNL_in_pw.h:68
void ZEROS(std::complex< T > *u, const TI n)
Definition global_function.h:109
AbacusDevice_t
Definition types.h:12
@ GpuDevice
Definition types.h:15
Definition memory_op.h:45
double norm(const Vec3 &v)
Definition test_partition.cpp:25