ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_vdpre.h
Go to the documentation of this file.
#ifndef DEEPKS_VDPRE_H
#define DEEPKS_VDPRE_H

#ifdef __MLALGO

#include "deepks_param.h"
// NOTE(review): the listing this header was recovered from skipped its original
// lines 7-9 and 11-13 (numbering jumps 6 -> 10 -> 14). They presumably held the
// #include directives for the project types used below (UnitCell, LCAO_Orbitals,
// Parallel_Orbitals, Grid_Driver, hamilt::HContainer, ModuleBase::Vector3) --
// restore them from the repository.
#include "source_base/timer.h"

#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

namespace DeePKS_domain
{
//------------------------
// deepks_vdpre.cpp
//------------------------

// This header declares the three routines implemented in deepks_vdpre.cpp that
// prepare v_delta-related training data:
// 1. cal_v_delta_precalc : v_delta_precalc is used for training with the v_delta
//    label; it equals gevdm * v_delta_pdm, where v_delta_pdm = overlap * overlap.
// 2. prepare_phialpha    : prepares phialpha for writing to an npy file.
// 3. prepare_gevdm       : prepares gevdm for writing to an npy file.
//
// Shared parameter names (meanings inferred from names -- confirm against
// deepks_vdpre.cpp):
//   TK        matrix element type; presumably double (gamma-only) or
//             std::complex<double> (multi-k) -- TODO confirm instantiations
//   nlocal    presumably the total number of atomic-orbital basis functions
//   nat       presumably the number of atoms
//   nks       presumably the number of k-points described by kvec_d
//   kvec_d    k-point vectors, presumably in direct coordinates
//   phialpha  overlap matrices held in HContainers (the "overlap" factors of
//             v_delta_pdm above) -- TODO confirm layout
//
// NOTE(review): phialpha (and gevdm in cal_v_delta_precalc) are taken by value,
// so every call copies the std::vector, whereas prepare_gevdm already takes its
// vector by const reference. Switching the others to const& has to be done
// together with the definitions in deepks_vdpre.cpp; changing only this header
// would declare different function templates and break linking.

/// Calculates v_delta_precalc; used for deepks_v_delta = 1.
/// @param gevdm                  per-call gevdm tensors (left factor of the product above)
/// @param[out] v_delta_precalc   result tensor, gevdm * v_delta_pdm
template <typename TK>
void cal_v_delta_precalc(const int nlocal,
                         const int nat,
                         const int nks,
                         const DeePKS_Param& deepks_param,
                         const std::vector<ModuleBase::Vector3<double>>& kvec_d,
                         const std::vector<hamilt::HContainer<double>*> phialpha,
                         const std::vector<torch::Tensor> gevdm,
                         const UnitCell& ucell,
                         const LCAO_Orbitals& orb,
                         const Parallel_Orbitals& pv,
                         const Grid_Driver& GridD,
                         torch::Tensor& v_delta_precalc);

/// Prepares phialpha for writing to an npy file; used for deepks_v_delta = 2.
/// @param[out] phialpha_out  tensor filled with the phialpha data to be written
template <typename TK>
void prepare_phialpha(const int nlocal,
                      const int nat,
                      const int nks,
                      const DeePKS_Param& deepks_param,
                      const std::vector<ModuleBase::Vector3<double>>& kvec_d,
                      const std::vector<hamilt::HContainer<double>*> phialpha,
                      const UnitCell& ucell,
                      const LCAO_Orbitals& orb,
                      const Parallel_Orbitals& pv,
                      const Grid_Driver& GridD,
                      torch::Tensor& phialpha_out);

/// Prepares gevdm for writing to an npy file.
/// @param gevdm_in        input gevdm tensors (taken by const reference)
/// @param[out] gevdm_out  tensor filled with the gevdm data to be written
void prepare_gevdm(const int nat,
                   const DeePKS_Param& deepks_param,
                   const LCAO_Orbitals& orb,
                   const std::vector<torch::Tensor>& gevdm_in,
                   torch::Tensor& gevdm_out);
} // namespace DeePKS_domain
#endif // __MLALGO
#endif // DEEPKS_VDPRE_H
Definition sltk_grid_driver.h:40
Definition ORB_read.h:18
3-element vector
Definition vector3.h:24
Definition parallel_orbitals.h:9
Definition unitcell.h:15
Definition hcontainer.h:144
Definition deepks_basic.h:12
void cal_v_delta_precalc(const int nlocal, const int nat, const int nks, const DeePKS_Param &deepks_param, const std::vector< ModuleBase::Vector3< double > > &kvec_d, const std::vector< hamilt::HContainer< double > * > phialpha, const std::vector< torch::Tensor > gevdm, const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD, torch::Tensor &v_delta_precalc)
cal_v_delta_precalc: v_delta_precalc is used for training with the v_delta label.
Definition deepks_vdpre.cpp:24
void prepare_gevdm(const int nat, const DeePKS_Param &deepks_param, const LCAO_Orbitals &orb, const std::vector< torch::Tensor > &gevdm_in, torch::Tensor &gevdm_out)
Definition deepks_vdpre.cpp:268
void prepare_phialpha(const int nlocal, const int nat, const int nks, const DeePKS_Param &deepks_param, const std::vector< ModuleBase::Vector3< double > > &kvec_d, const std::vector< hamilt::HContainer< double > * > phialpha, const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD, torch::Tensor &phialpha_out)
Definition deepks_vdpre.cpp:174
Definition deepks_param.h:11