ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
LCAO_deepks.h
Go to the documentation of this file.
1#ifndef LCAO_DEEPKS_H
2#define LCAO_DEEPKS_H
3
4#ifdef __MLALGO
5
6#include "deepks_basic.h"
7#include "deepks_check.h"
8#include "deepks_descriptor.h"
9#include "deepks_force.h"
10#include "deepks_fpre.h"
11#include "deepks_orbital.h"
12#include "deepks_orbpre.h"
13#include "deepks_pdm.h"
14#include "deepks_phialpha.h"
15#include "deepks_spre.h"
16#include "deepks_vdelta.h"
17#include "deepks_vdpre.h"
18#include "deepks_vdrpre.h"
21#include "source_base/matrix.h"
22#include "source_base/timer.h"
27#include "source_io/winput.h"
28
29#include <torch/script.h>
30#include <torch/torch.h>
31
47// caoyu add 2021-03-29
48// wenfei modified 2022-1-5
49//
// LCAO_Deepks<T>: holds DeePKS data for LCAO calculations. Members visible here:
// - the correction energies E_delta and e_delta_band and the Hamiltonian correction V_delta;
// - the projected density matrix pdm and the gradient gedm;
// - <phi|alpha> overlaps (phialpha) and the TorchScript model (model_deepks).
// NOTE(review): T is presumably double (gamma-only) or std::complex<double> (multi-k) — confirm.
50template <typename T>
52{
53
54 //-------------------
55 // public variables
56 //-------------------
57 public:
59 double E_delta = 0.0;
61 double e_delta_band = 0.0;
62
65 std::vector<std::vector<T>> V_delta;
66
67 //-------------------
68 // private variables
69 //-------------------
70 // private:
71 public: // changed to public to allow refactoring of the code, 2024-07-22 by mohan
72 int lmaxd = 0; // max l of descriptors
73 int nmaxd = 0; // number of descriptors per l
74 int inlmax = 0; // total number of {i,n,l} - atom, n, l
75 int n_descriptor; // natoms * des_per_atom, size of descriptor (projector) basis set; NOTE(review): no in-class initializer — presumably set in init(), confirm
76 int des_per_atom; // \sum_L{Nchi(L)*(2L+1)}; NOTE(review): no in-class initializer
77 std::vector<int> inl2l; // inl2l[inl] = inl2l[nl] = l (not related to iat) of descriptor with inl_index
78 ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07; NOTE(review): raw pointer without in-class initializer, ownership not visible in this header
79
80 bool init_pdm = false; // for DeePKS NSCF calculation, set init_pdm to skip the calculation of pdm in SCF iteration
81
82 // deep neural network module that provides the corrected Hamiltonian term and
83 // related derivatives. Used in cal_edelta_gedm.
84 torch::jit::script::Module model_deepks;
85
86 // saves <phi(0)|alpha(R)> and its derivatives
87 // index 0 for itself and indices 1-3 for derivatives over x,y,z
88 std::vector<hamilt::HContainer<double>*> phialpha;
89
90 // density matrix in real space
92
93 // projected density matrix
94 // [tot_Inl][2l+1][2l+1], where l is the angular momentum corresponding to inl;
95 // [nat][nlm*nlm] for the equivariant version
96 std::vector<torch::Tensor> pdm;
97
99 double** gedm; // [tot_Inl][(2l+1)*(2l+1)]; dE/dD from the loaded model; NOTE(review): no in-class initializer
100
101 // accessors for the hr status: 1. get value; 2. set value; NOTE(review): get_hr_cal() is declared to return int although hr_cal is bool
103 {
104 return this->hr_cal;
105 }
106 void set_hr_cal(bool cal)
107 {
108 this->hr_cal = cal;
109 }
110
111 //-------------------
112 // LCAO_deepks.cpp
113 //-------------------
114
115 // This file contains the constructor and destructor of class LCAO_Deepks,
116 // as well as subroutines for initializing and releasing relevant data structures
117
118 // Other than the constructor and the destructor, it contains several types of subroutines, including:
119 // 1. subroutines that are related to calculating descriptors:
120 // - init : allocates some arrays
121 // - init_index : records the index (inl)
122 // 2. subroutines that are related to V_delta:
123 // - allocate_V_delta : allocates V_delta; if calculating force, it also allocates F_delta
124
125 public:
126 explicit LCAO_Deepks();
127 ~LCAO_Deepks();
128
131 void init(const LCAO_Orbitals& orb,
132 const int nat,
133 const int ntype,
134 const int nks,
135 const Parallel_Orbitals& pv_in,
136 std::vector<int> na,
137 std::ofstream& ofs);
138
140 void allocate_V_delta(const int nat, const int nks = 1);
141
143 void init_DMR(const UnitCell& ucell,
144 const LCAO_Orbitals& orb,
145 const Parallel_Orbitals& pv,
146 const Grid_Driver& GridD);
147
149 void dpks_cal_e_delta_band(const std::vector<std::vector<T>>& dm, const int nks);
150
151 private:
152 // flag of HR status,
153 // true : HR should be calculated
154 // false : HR has been calculated
155 bool hr_cal = true;
156
157 // arrange index of descriptor in all atoms
158 void init_index(const int ntype,
159 const int nat,
160 std::vector<int> na,
161 const int tot_inl,
162 const LCAO_Orbitals& orb,
163 std::ofstream& ofs);
164
166};
167
168#endif
169#endif
Definition sltk_grid_driver.h:43
Definition LCAO_deepks.h:52
int n_descriptor
Definition LCAO_deepks.h:75
double E_delta
(Unit: Ry) Correction energy provided by the neural network (NN)
Definition LCAO_deepks.h:59
void allocate_V_delta(const int nat, const int nks=1)
Allocate memory for correction to Hamiltonian.
Definition LCAO_deepks.cpp:175
void init_index(const int ntype, const int nat, std::vector< int > na, const int tot_inl, const LCAO_Orbitals &orb, std::ofstream &ofs)
Definition LCAO_deepks.cpp:135
void set_hr_cal(bool cal)
Definition LCAO_deepks.h:106
int nmaxd
Definition LCAO_deepks.h:73
int inlmax
Definition LCAO_deepks.h:74
torch::jit::script::Module model_deepks
Definition LCAO_deepks.h:84
void init(const LCAO_Orbitals &orb, const int nat, const int ntype, const int nks, const Parallel_Orbitals &pv_in, std::vector< int > na, std::ofstream &ofs)
Definition LCAO_deepks.cpp:50
LCAO_Deepks()
Definition LCAO_deepks.cpp:21
double ** gedm
dE/dD, computed by autograd from the loaded model (E in Ry)
Definition LCAO_deepks.h:99
void dpks_cal_e_delta_band(const std::vector< std::vector< T > > &dm, const int nks)
A temporary interface for cal_e_delta_band.
Definition LCAO_deepks.cpp:259
std::vector< hamilt::HContainer< double > * > phialpha
Definition LCAO_deepks.h:88
std::vector< torch::Tensor > pdm
Definition LCAO_deepks.h:96
int des_per_atom
Definition LCAO_deepks.h:76
hamilt::HContainer< double > * dm_r
Definition LCAO_deepks.h:91
~LCAO_Deepks()
Definition LCAO_deepks.cpp:30
double e_delta_band
(Unit: Ry)
Definition LCAO_deepks.h:61
int lmaxd
Definition LCAO_deepks.h:72
int get_hr_cal()
Definition LCAO_deepks.h:102
bool init_pdm
Definition LCAO_deepks.h:80
ModuleBase::IntArray * inl_index
Definition LCAO_deepks.h:78
std::vector< int > inl2l
Definition LCAO_deepks.h:77
void init_DMR(const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD)
Initialize the dm_r container.
Definition LCAO_deepks.cpp:211
const Parallel_Orbitals * pv
Definition LCAO_deepks.h:165
std::vector< std::vector< T > > V_delta
Definition LCAO_deepks.h:65
bool hr_cal
Definition LCAO_deepks.h:155
Definition ORB_read.h:19
Integer array.
Definition intarray.h:20
Definition parallel_orbitals.h:9
Definition unitcell.h:16
Definition hcontainer.h:144