ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_vdpre.h
Go to the documentation of this file.
1#ifndef DEEPKS_VDPRE_H
2#define DEEPKS_VDPRE_H
3
4#ifdef __MLALGO
5
6#include "deepks_param.h"
10#include "source_base/timer.h"
15
16#include <torch/script.h>
17#include <torch/torch.h>
18
19namespace DeePKS_domain
20{
21//------------------------
22// deepks_vdpre.cpp
23//------------------------
24
25// This file contains 3 subroutines for calculating v_delta,
26// 1. cal_v_delta_precalc : v_delta_precalc is used for training with v_delta label,
27// which equals gevdm * v_delta_pdm,
28// v_delta_pdm = overlap * overlap
29// 2. prepare_phialpha : prepare phialpha for outputting npy file
30// 3. prepare_gevdm : prepare gevdm for outputting npy file
31
32// for deepks_v_delta = 1
33// calculates v_delta_precalc
34template <typename TK>
35void cal_v_delta_precalc(const int nlocal,
36 const int nat,
37 const int nks,
38 const DeePKS_Param& deepks_param,
39 const std::vector<ModuleBase::Vector3<double>>& kvec_d,
40 const std::vector<hamilt::HContainer<double>*> phialpha,
41 const std::vector<torch::Tensor> gevdm,
42 const UnitCell& ucell,
43 const LCAO_Orbitals& orb,
44 const Parallel_Orbitals& pv,
45 const Grid_Driver& GridD,
46 torch::Tensor& v_delta_precalc);
47
48// for deepks_v_delta = 2
49// prepare phialpha for outputting npy file
50template <typename TK>
51void prepare_phialpha(const int nlocal,
52 const int nat,
53 const int nks,
54 const DeePKS_Param& deepks_param,
55 const std::vector<ModuleBase::Vector3<double>>& kvec_d,
56 const std::vector<hamilt::HContainer<double>*> phialpha,
57 const UnitCell& ucell,
58 const LCAO_Orbitals& orb,
59 const Parallel_Orbitals& pv,
60 const Grid_Driver& GridD,
61 torch::Tensor& phialpha_out);
62
63// prepare gevdm for outputting npy file
64void prepare_gevdm(const int nat,
65 const DeePKS_Param& deepks_param,
66 const LCAO_Orbitals& orb,
67 const std::vector<torch::Tensor>& gevdm_in,
68 torch::Tensor& gevdm_out);
69} // namespace DeePKS_domain
70#endif
71#endif
Definition sltk_grid_driver.h:43
Definition ORB_read.h:19
3-element vector
Definition vector3.h:22
Definition parallel_orbitals.h:9
Definition unitcell.h:17
Definition hcontainer.h:144
Definition deepks_basic.h:15
void cal_v_delta_precalc(const int nlocal, const int nat, const int nks, const DeePKS_Param &deepks_param, const std::vector< ModuleBase::Vector3< double > > &kvec_d, const std::vector< hamilt::HContainer< double > * > phialpha, const std::vector< torch::Tensor > gevdm, const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD, torch::Tensor &v_delta_precalc)
cal_v_delta_precalc : v_delta_precalc is used for training with the v_delta label and equals gevdm * v_delta_pdm.
Definition deepks_vdpre.cpp:24
void prepare_gevdm(const int nat, const DeePKS_Param &deepks_param, const LCAO_Orbitals &orb, const std::vector< torch::Tensor > &gevdm_in, torch::Tensor &gevdm_out)
Definition deepks_vdpre.cpp:268
void prepare_phialpha(const int nlocal, const int nat, const int nks, const DeePKS_Param &deepks_param, const std::vector< ModuleBase::Vector3< double > > &kvec_d, const std::vector< hamilt::HContainer< double > * > phialpha, const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD, torch::Tensor &phialpha_out)
Definition deepks_vdpre.cpp:174
Definition deepks_param.h:11