ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
ml_base.h
Go to the documentation of this file.
1#ifndef ML_BASE_H
2#define ML_BASE_H
3
4#include <vector>
5#include <string>
6#include <map>
7#include <memory>
8#include <cmath>
9
10#ifdef __MLALGO
13
14// ML_Base collects the functionality shared by the machine-learning based functionals
15// used in OFDFT (the ML kinetic energy density functional, KEDF) and in EXX.
// NOTE(review): line 16 of ml_base.h (the `class ML_Base` head) is missing from this
// rendering; the body below is that of class ML_Base.
17{
18public:
19 ML_Base();
// NOTE(review): the destructor is non-virtual and the class declares no virtual
// functions, yet it is meant to be derived from (see the notes below). Deleting a
// derived object through an ML_Base* would be undefined behavior — confirm no caller
// does that, or make this destructor virtual/protected.
20 ~ML_Base();
21
22 // Common Interface
// Takes the device name as a string; presumably updates device_type / device
// declared below (both default to CPU) — confirm accepted values in ml_base.cpp.
23 void set_device(std::string device_inpt);
24
25 // Tools: load `data` from, or dump `data` to, the file named `filename`
// (file format is defined in ml_base.cpp, not visible here).
26 void loadVector(std::string filename, std::vector<double> &data);
27 void dumpVector(std::string filename, const std::vector<double> &data);
28 void dumpTensor(std::string filename, const torch::Tensor &data);
29 void dumpMatrix(std::string filename, const ModuleBase::matrix &data);
30
31 int nx_tot = 0; // equal to nx (called by NN)
// Returns a tensor for the descriptor selected by `parameter` and kernel index
// `ikernel`; presumably `parameter` is one of the names in descriptor_type — confirm.
32 torch::Tensor get_data(std::string parameter, const int ikernel) const;
33
34protected:
// The three routines below receive the density through `prho` (a pointer to
// per-channel arrays; presumably one array of length nx per spin — confirm) and
// the plane-wave basis `pw_rho`. `cal_grad` selects whether NN_forward also
// produces gradients; get_potential_ writes its result into `rpotential`.
35 void updateInput(const double * const * prho, const ModulePW::PW_Basis *pw_rho);
36 void NN_forward(const double * const * prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad);
37 void get_potential_(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential);
38
39 // Local potential terms evaluated at one grid point `ir` (presumably 0 <= ir < nx);
// these were identical in the OFDFT and EXX classes, or are intended to be shared.
40 double potGammaTerm(int ir);
41 double potPTerm1(int ir);
42 double potQTerm1(int ir);
43 double potXiTerm1(int ir);
44 double potTanhxiTerm1(int ir);
45 double potTanhpTerm1(int ir);
46 double potTanhqTerm1(int ir);
47
48 // Nonlocal potential terms. The last, non-const vector argument receives the result.
49 // Every signature takes tau_lda explicitly (ML_EXX originally passed it for some of these).
50 // None of these functions is virtual: a derived class that redeclares one of them
51 // hides the base version rather than overriding it.
52 void potGammanlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rGammanlTerm);
53 void potXinlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rXinlTerm);
54 void potTanhxinlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhxinlTerm);
55 void potTanhxi_nlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhxi_nlTerm);
56 void potPPnlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rPPnlTerm);
57 void potQQnlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rQQnlTerm);
58 void potTanhpTanh_pnlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhpTanh_pnlTerm);
59 void potTanhqTanh_qnlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhqTanh_qnlTerm);
60 void potTanhpTanhp_nlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhpTanhp_nlTerm);
61 void potTanhqTanhq_nlTerm(const double * const *prho, const std::vector<double> &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector<double> &rTanhqTanhq_nlTerm);
62
63protected:
64 // --- Member Variables (Common) ---
65
// NOTE(review): line 66 (`ModuleIO::Cal_MLKEDF_Descriptors *cal_tool`, per this page's
// symbol index) is missing from this rendering.
67
68 int nx = 0; // number of grid points
69 double dV = 0.; // presumably the volume element per grid point — confirm
70
71 // Constants
// pqcoef = 1 / (4 * (3*pi^2)^(2/3)).
// NOTE(review): M_PI is not standard C++; <cmath> provides it on POSIX systems, and
// on MSVC only when _USE_MATH_DEFINES is defined before the include.
72 double pqcoef = 1.0 / (4.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0)); // coefficient of p and q
73 double feg_net_F = 0.0;
74 double feg3_correct = 0.541324854612918; // ln(e - 1)
75 double energy_prefactor = 0.0; // cTF for KEDF, cDirac for EXX
76 double energy_exponent = 0.0; // 5/3 for KEDF, 4/3 for EXX
77
78 // Descriptors and hyperparameters (nonlocal ones hold one inner vector per kernel, presumably)
79 int ninput = 0; // number of descriptors
80 std::vector<double> gamma;
81 std::vector<double> p;
82 std::vector<double> q;
83 std::vector<std::vector<double>> gammanl;
84 std::vector<std::vector<double>> pnl;
85 std::vector<std::vector<double>> qnl;
86 std::vector<std::vector<double>> nablaRho;
87
88 // Parameters (chi_* look like scaling factors applied before tanh, by naming — confirm)
89 std::vector<double> chi_xi;
90 double chi_p = 1.0;
91 double chi_q = 1.0;
92 std::vector<std::vector<double>> xi;
93 std::vector<std::vector<double>> tanhxi;
94 std::vector<std::vector<double>> tanhxi_nl;
95 std::vector<double> tanhp;
96 std::vector<double> tanhq;
97
98 // plan 1 (by naming: tanh applied to the nonlocal pnl/qnl, scaled by chi_pnl/chi_qnl — confirm)
99 std::vector<double> chi_pnl;
100 std::vector<double> chi_qnl;
101 std::vector<std::vector<double>> tanh_pnl;
102 std::vector<std::vector<double>> tanh_qnl;
103 // plan 2 (by naming: nonlocal transform applied to tanhp/tanhq — confirm)
104 std::vector<std::vector<double>> tanhp_nl;
105 std::vector<std::vector<double>> tanhq_nl;
106
107 // GPU / Device (CPU by default; presumably changed by set_device)
108 torch::DeviceType device_type = torch::kCPU;
109 torch::Device device = torch::Device(torch::kCPU);
110 torch::Device device_CPU = torch::Device(torch::kCPU);
111
112 // Neural Network
113 std::shared_ptr<NN_OFImpl> nn;
// NOTE(review): ownership of the memory behind these raw pointers is not visible in
// this header — confirm who allocates it and that it outlives every use.
114 double* enhancement_cpu_ptr = nullptr;
115 double* gradient_cpu_ptr = nullptr;
116 int nkernel = 1; // presumably the number of kernels indexed by `ikernel` — confirm
117
118 // Switch flags: ml_<name> presumably enables descriptor <name> as an NN input
119 bool ml_gamma = false;
120 bool ml_p = false;
121 bool ml_q = false;
122 bool ml_tanhp = false;
123 bool ml_tanhq = false;
124 bool ml_gammanl = false;
125 bool ml_pnl = false;
126 bool ml_qnl = false;
127 bool ml_xi = false;
128 bool ml_tanhxi = false;
129 bool ml_tanhxi_nl = false;
130 bool ml_tanh_pnl = false;
131 bool ml_tanh_qnl = false;
132 bool ml_tanhp_nl = false;
133 bool ml_tanhq_nl = false;
134
135 // Maps keyed by descriptor name (contents are filled in ml_base.cpp, not visible here)
136 std::vector<std::string> descriptor_type;
137 std::vector<int> kernel_index;
138 std::map<std::string, std::vector<int>> descriptor2kernel;
139 std::map<std::string, std::vector<int>> descriptor2index;
140 std::map<std::string, std::vector<bool>> gene_data_label;
141};
142
143#endif // __MLALGO
144#endif // ML_BASE_H
ML_Base
Definition ml_base.h:17
double pqcoef
Definition ml_base.h:72
std::vector< int > kernel_index
Definition ml_base.h:137
std::vector< std::vector< double > > tanhxi_nl
Definition ml_base.h:94
void potXinlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rXinlTerm)
Definition ml_base_pot.cpp:72
int nkernel
Definition ml_base.h:116
void potTanhqTanhq_nlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhqTanhq_nlTerm)
Definition ml_base_pot.cpp:420
std::vector< std::string > descriptor_type
Definition ml_base.h:136
bool ml_xi
Definition ml_base.h:127
double chi_q
Definition ml_base.h:91
double feg_net_F
Definition ml_base.h:73
ML_Base()
Definition ml_base.cpp:6
std::vector< std::vector< double > > tanh_qnl
Definition ml_base.h:102
torch::Device device_CPU
Definition ml_base.h:110
double potGammaTerm(int ir)
Definition ml_base_pot.cpp:5
bool ml_q
Definition ml_base.h:121
double feg3_correct
Definition ml_base.h:74
std::shared_ptr< NN_OFImpl > nn
Definition ml_base.h:113
bool ml_qnl
Definition ml_base.h:126
torch::DeviceType device_type
Definition ml_base.h:108
void potTanhqTanh_qnlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhqTanh_qnlTerm)
Definition ml_base_pot.cpp:311
std::map< std::string, std::vector< int > > descriptor2index
Definition ml_base.h:139
void NN_forward(const double *const *prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad)
Definition ml_base.cpp:98
double potTanhpTerm1(int ir)
Definition ml_base_pot.cpp:40
double dV
Definition ml_base.h:69
int nx
Definition ml_base.h:68
bool ml_tanhxi
Definition ml_base.h:128
bool ml_tanh_pnl
Definition ml_base.h:130
void potTanhxi_nlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhxi_nlTerm)
Definition ml_base_pot.cpp:113
bool ml_gamma
Definition ml_base.h:119
int nx_tot
Definition ml_base.h:31
std::vector< std::vector< double > > gammanl
Definition ml_base.h:83
void dumpTensor(std::string filename, const torch::Tensor &data)
Definition ml_base.cpp:204
double * enhancement_cpu_ptr
Definition ml_base.h:114
torch::Tensor get_data(std::string parameter, const int ikernel) const
Definition ml_base.cpp:135
std::vector< std::vector< double > > tanhp_nl
Definition ml_base.h:104
double potPTerm1(int ir)
Definition ml_base_pot.cpp:9
std::vector< std::vector< double > > tanhxi
Definition ml_base.h:93
std::vector< double > chi_pnl
Definition ml_base.h:99
~ML_Base()
Definition ml_base.cpp:8
std::vector< double > q
Definition ml_base.h:82
std::map< std::string, std::vector< int > > descriptor2kernel
Definition ml_base.h:138
double potTanhqTerm1(int ir)
Definition ml_base_pot.cpp:45
std::vector< double > p
Definition ml_base.h:81
void potPPnlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rPPnlTerm)
Definition ml_base_pot.cpp:146
double potQTerm1(int ir)
Definition ml_base_pot.cpp:13
void dumpVector(std::string filename, const std::vector< double > &data)
Definition ml_base.cpp:195
std::vector< std::vector< double > > tanhq_nl
Definition ml_base.h:105
std::vector< double > gamma
Definition ml_base.h:80
void potGammanlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rGammanlTerm)
Definition ml_base_pot.cpp:52
std::vector< std::vector< double > > tanh_pnl
Definition ml_base.h:101
void potTanhpTanh_pnlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhpTanh_pnlTerm)
Definition ml_base_pot.cpp:252
std::vector< double > chi_qnl
Definition ml_base.h:100
std::vector< std::vector< double > > nablaRho
Definition ml_base.h:86
double potTanhxiTerm1(int ir)
Definition ml_base_pot.cpp:28
std::vector< double > tanhq
Definition ml_base.h:96
double potXiTerm1(int ir)
Definition ml_base_pot.cpp:17
ModuleIO::Cal_MLKEDF_Descriptors * cal_tool
Definition ml_base.h:66
bool ml_tanh_qnl
Definition ml_base.h:131
std::vector< std::vector< double > > qnl
Definition ml_base.h:85
void dumpMatrix(std::string filename, const ModuleBase::matrix &data)
Definition ml_base.cpp:212
double energy_prefactor
Definition ml_base.h:75
bool ml_pnl
Definition ml_base.h:125
std::vector< double > tanhp
Definition ml_base.h:95
bool ml_tanhxi_nl
Definition ml_base.h:129
void potTanhxinlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhxinlTerm)
Definition ml_base_pot.cpp:92
double * gradient_cpu_ptr
Definition ml_base.h:115
void potTanhpTanhp_nlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhpTanhp_nlTerm)
Definition ml_base_pot.cpp:362
std::vector< std::vector< double > > pnl
Definition ml_base.h:84
std::vector< double > chi_xi
Definition ml_base.h:89
torch::Device device
Definition ml_base.h:109
bool ml_tanhq
Definition ml_base.h:123
bool ml_tanhq_nl
Definition ml_base.h:133
bool ml_p
Definition ml_base.h:120
void set_device(std::string device_inpt)
Definition ml_base.cpp:13
double energy_exponent
Definition ml_base.h:76
bool ml_gammanl
Definition ml_base.h:124
void updateInput(const double *const *prho, const ModulePW::PW_Basis *pw_rho)
Definition ml_base.cpp:37
void loadVector(std::string filename, std::vector< double > &data)
Definition ml_base.cpp:189
bool ml_tanhp_nl
Definition ml_base.h:132
int ninput
Definition ml_base.h:79
double chi_p
Definition ml_base.h:90
std::map< std::string, std::vector< bool > > gene_data_label
Definition ml_base.h:140
void get_potential_(const double *const *prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential)
Definition ml_base.cpp:155
bool ml_tanhp
Definition ml_base.h:122
std::vector< std::vector< double > > xi
Definition ml_base.h:92
void potQQnlTerm(const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rQQnlTerm)
Definition ml_base_pot.cpp:204
ModuleBase::matrix
Definition matrix.h:18
ModuleIO::Cal_MLKEDF_Descriptors
A class to calculate the descriptors for ML KEDF. Sun, Liang, and Mohan Chen. Physical Review B 109....
Definition cal_mlkedf_descriptors.h:21
ModulePW::PW_Basis
A class which can convert a function of "r" to the corresponding linear superposition of plane waves ...
Definition pw_basis.h:56