29#include <torch/script.h>
30#include <torch/torch.h>
88 std::vector<hamilt::HContainer<double>*>
phialpha;
96 std::vector<torch::Tensor>
pdm;
Definition sltk_grid_driver.h:43
Definition LCAO_deepks.h:52
int n_descriptor
Definition LCAO_deepks.h:75
double E_delta
(Unit: Ry) Correction energy provided by NN
Definition LCAO_deepks.h:59
void allocate_V_delta(const int nat, const int nks=1)
Allocate memory for correction to Hamiltonian.
Definition LCAO_deepks.cpp:175
void init_index(const int ntype, const int nat, std::vector< int > na, const int tot_inl, const LCAO_Orbitals &orb, std::ofstream &ofs)
Definition LCAO_deepks.cpp:135
void set_hr_cal(bool cal)
Definition LCAO_deepks.h:106
int nmaxd
Definition LCAO_deepks.h:73
int inlmax
Definition LCAO_deepks.h:74
torch::jit::script::Module model_deepks
Definition LCAO_deepks.h:84
void init(const LCAO_Orbitals &orb, const int nat, const int ntype, const int nks, const Parallel_Orbitals &pv_in, std::vector< int > na, std::ofstream &ofs)
Definition LCAO_deepks.cpp:50
LCAO_Deepks()
Definition LCAO_deepks.cpp:21
double ** gedm
dE/dD, autograd from loaded model(E: Ry)
Definition LCAO_deepks.h:99
void dpks_cal_e_delta_band(const std::vector< std::vector< T > > &dm, const int nks)
a temporary interface for cal_e_delta_band
Definition LCAO_deepks.cpp:259
std::vector< hamilt::HContainer< double > * > phialpha
Definition LCAO_deepks.h:88
std::vector< torch::Tensor > pdm
Definition LCAO_deepks.h:96
int des_per_atom
Definition LCAO_deepks.h:76
hamilt::HContainer< double > * dm_r
Definition LCAO_deepks.h:91
~LCAO_Deepks()
Definition LCAO_deepks.cpp:30
double e_delta_band
(Unit: Ry)
Definition LCAO_deepks.h:61
int lmaxd
Definition LCAO_deepks.h:72
int get_hr_cal()
Definition LCAO_deepks.h:102
bool init_pdm
Definition LCAO_deepks.h:80
ModuleBase::IntArray * inl_index
Definition LCAO_deepks.h:78
std::vector< int > inl2l
Definition LCAO_deepks.h:77
void init_DMR(const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD)
Initialize the dm_r container.
Definition LCAO_deepks.cpp:211
const Parallel_Orbitals * pv
Definition LCAO_deepks.h:165
std::vector< std::vector< T > > V_delta
Definition LCAO_deepks.h:65
bool hr_cal
Definition LCAO_deepks.h:155
Integer array.
Definition intarray.h:20
Definition parallel_orbitals.h:9
Definition hcontainer.h:144