ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
LCAO_deepks.h
Go to the documentation of this file.
1#ifndef LCAO_DEEPKS_H
2#define LCAO_DEEPKS_H
3
4#ifdef __MLALGO
5
6#include "deepks_basic.h"
7#include "deepks_check.h"
8#include "deepks_descriptor.h"
9#include "deepks_force.h"
10#include "deepks_fpre.h"
11#include "deepks_orbital.h"
12#include "deepks_orbpre.h"
13#include "deepks_pdm.h"
14#include "deepks_phialpha.h"
15#include "deepks_spre.h"
16#include "deepks_vdelta.h"
17#include "deepks_vdpre.h"
18#include "deepks_vdrpre.h"
21#include "source_base/matrix.h"
22#include "source_base/timer.h"
27
28#include <torch/script.h>
29#include <torch/torch.h>
30
46// caoyu add 2021-03-29
47// wenfei modified 2022-1-5
48//
49template <typename T>
51{
52
53 //-------------------
54 // public variables
55 //-------------------
56 public:
58 double E_delta = 0.0;
60 double e_delta_band = 0.0;
61
64 std::vector<std::vector<T>> V_delta;
65
66 //-------------------
67 // private variables
68 //-------------------
69 // private:
70 public: // change to public to reconstuct the code, 2024-07-22 by mohan
71 int lmaxd = 0; // max l of descirptors
72 int nmaxd = 0; //#. descriptors per l
73 int inlmax = 0; // tot. number {i,n,l} - atom, n, l
74 int n_descriptor; // natoms * des_per_atom, size of descriptor(projector) basis set
75 int des_per_atom; // \sum_L{Nchi(L)*(2L+1)}
76 std::vector<int> inl2l; // inl2l[inl] = inl2l[nl] = l (not related to iat) of descriptor with inl_index
77 ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
78
79 bool init_pdm = false; // for DeePKS NSCF calculation, set init_pdm to skip the calculation of pdm in SCF iteration
80
81 // deep neural network module that provides corrected Hamiltonian term and
82 // related derivatives. Used in cal_edelta_gedm.
83 torch::jit::script::Module model_deepks;
84
85 // saves <phi(0)|alpha(R)> and its derivatives
86 // index 0 for itself and index 1-3 for derivatives over x,y,z
87 std::vector<hamilt::HContainer<double>*> phialpha;
88
89 // density matrix in real space
91
92 // projected density matrix
93 // [tot_Inl][2l+1][2l+1], here l is corresponding to inl;
94 // [nat][nlm*nlm] for equivariant version
95 std::vector<torch::Tensor> pdm;
96
98 double** gedm; //[tot_Inl][(2l+1)*(2l+1)]
99
100 // functions for hr status: 1. get value; 2. set value;
102 {
103 return this->hr_cal;
104 }
105 void set_hr_cal(bool cal)
106 {
107 this->hr_cal = cal;
108 }
109
110 //-------------------
111 // LCAO_deepks.cpp
112 //-------------------
113
114 // This file contains constructor and destructor of the class LCAO_deepks,
115 // as well as subroutines for initializing and releasing relevant data structures
116
117 // Other than the constructor and the destructor, it contains 3 types of subroutines:
118 // 1. subroutines that are related to calculating descriptors:
119 // - init : allocates some arrays
120 // - init_index : records the index (inl)
121 // 2. subroutines that are related to V_delta:
122 // - allocate_V_delta : allocates V_delta; if calculating force, it also allocates F_delta
123
124 public:
125 explicit LCAO_Deepks();
126 ~LCAO_Deepks();
127
130 void init(const LCAO_Orbitals& orb,
131 const int nat,
132 const int ntype,
133 const int nks,
134 const Parallel_Orbitals& pv_in,
135 std::vector<int> na,
136 std::ofstream& ofs);
137
139 void allocate_V_delta(const int nat, const int nks = 1);
140
142 void init_DMR(const UnitCell& ucell,
143 const LCAO_Orbitals& orb,
144 const Parallel_Orbitals& pv,
145 const Grid_Driver& GridD);
146
148 void dpks_cal_e_delta_band(const std::vector<std::vector<T>>& dm, const int nks);
149
150 private:
151 // flag of HR status,
152 // true : HR should be calculated
153 // false : HR has been calculated
154 bool hr_cal = true;
155
156 // arrange index of descriptor in all atoms
157 void init_index(const int ntype,
158 const int nat,
159 std::vector<int> na,
160 const int tot_inl,
161 const LCAO_Orbitals& orb,
162 std::ofstream& ofs);
163
165};
166
167#endif
168#endif
Definition sltk_grid_driver.h:43
Definition LCAO_deepks.h:51
int n_descriptor
Definition LCAO_deepks.h:74
double E_delta
(Unit: Ry) Correction energy provided by NN
Definition LCAO_deepks.h:58
void allocate_V_delta(const int nat, const int nks=1)
Allocate memory for correction to Hamiltonian.
Definition LCAO_deepks.cpp:175
void init_index(const int ntype, const int nat, std::vector< int > na, const int tot_inl, const LCAO_Orbitals &orb, std::ofstream &ofs)
Definition LCAO_deepks.cpp:135
void set_hr_cal(bool cal)
Definition LCAO_deepks.h:105
int nmaxd
Definition LCAO_deepks.h:72
int inlmax
Definition LCAO_deepks.h:73
torch::jit::script::Module model_deepks
Definition LCAO_deepks.h:83
void init(const LCAO_Orbitals &orb, const int nat, const int ntype, const int nks, const Parallel_Orbitals &pv_in, std::vector< int > na, std::ofstream &ofs)
Definition LCAO_deepks.cpp:50
LCAO_Deepks()
Definition LCAO_deepks.cpp:21
double ** gedm
dE/dD, computed by autograd from the loaded model (E: Ry)
Definition LCAO_deepks.h:98
void dpks_cal_e_delta_band(const std::vector< std::vector< T > > &dm, const int nks)
A temporary interface for cal_e_delta_band.
Definition LCAO_deepks.cpp:259
std::vector< hamilt::HContainer< double > * > phialpha
Definition LCAO_deepks.h:87
std::vector< torch::Tensor > pdm
Definition LCAO_deepks.h:95
int des_per_atom
Definition LCAO_deepks.h:75
hamilt::HContainer< double > * dm_r
Definition LCAO_deepks.h:90
~LCAO_Deepks()
Definition LCAO_deepks.cpp:30
double e_delta_band
(Unit: Ry)
Definition LCAO_deepks.h:60
int lmaxd
Definition LCAO_deepks.h:71
int get_hr_cal()
Definition LCAO_deepks.h:101
bool init_pdm
Definition LCAO_deepks.h:79
ModuleBase::IntArray * inl_index
Definition LCAO_deepks.h:77
std::vector< int > inl2l
Definition LCAO_deepks.h:76
void init_DMR(const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD)
Initialize the dm_r container.
Definition LCAO_deepks.cpp:211
const Parallel_Orbitals * pv
Definition LCAO_deepks.h:164
std::vector< std::vector< T > > V_delta
Definition LCAO_deepks.h:64
bool hr_cal
Definition LCAO_deepks.h:154
Definition ORB_read.h:19
Integer array.
Definition intarray.h:20
Definition parallel_orbitals.h:9
Definition unitcell.h:17
Definition hcontainer.h:144