ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
input.h
Go to the documentation of this file.
1#ifndef INPUT_H
2#define INPUT_H
3
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <torch/torch.h>
5
/**
 * @brief Settings read from the nnINPUT file.
 *
 * Scalar settings carry in-class defaults. Every pointer member starts as
 * nullptr and is released with delete[] in the destructor, so anything
 * assigned to one of them must be allocated with new[] (or left nullptr).
 * NOTE(review): the arrays are presumably allocated by readInput() in
 * input.cpp — confirm their lengths there before indexing them.
 *
 * Because the object owns those raw arrays, copying is disabled: a
 * member-wise copy would leave two objects that delete[] the same memory.
 */
class Input
{
    // ---------- read in the settings from nnINPUT --------
  public:
    Input() = default;

    /// Releases every owned array (delete[] on a nullptr member is a no-op).
    ~Input()
    {
        delete[] this->train_dir;
        delete[] this->train_cell;
        delete[] this->train_a;
        delete[] this->validation_dir;
        delete[] this->validation_cell;
        delete[] this->validation_a;

        delete[] this->ml_gammanl;
        delete[] this->ml_pnl;
        delete[] this->ml_qnl;
        delete[] this->ml_xi;
        delete[] this->ml_tanhxi;
        delete[] this->ml_tanhxi_nl;
        delete[] this->ml_tanh_pnl;
        delete[] this->ml_tanh_qnl;
        delete[] this->ml_tanhp_nl;
        delete[] this->ml_tanhq_nl;
        delete[] this->chi_xi;
        delete[] this->chi_pnl;
        delete[] this->chi_qnl;
        delete[] this->kernel_type;
        delete[] this->kernel_scaling;
        delete[] this->yukawa_alpha;
        delete[] this->kernel_file;
    }

    // The raw pointer members are owning; an implicit copy would alias them
    // and both destructors would delete[] the same arrays.
    Input(const Input &) = delete;
    Input &operator=(const Input &) = delete;

    /// Reads the settings into this object; defined in input.cpp.
    void readInput();

    /**
     * @brief Reads one value from @p ifs into @p var, then discards the rest
     *        of the current line (e.g. an inline comment).
     *
     * The whole remainder of the line is skipped regardless of its length;
     * a fixed character cap would let the tail of a long comment be parsed
     * as the next value.
     */
    template <class T> static void read_value(std::ifstream &ifs, T &var)
    {
        ifs >> var;
        ifs.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    }

    /**
     * @brief Reads @p length consecutive values from @p ifs into
     *        var[0] .. var[length - 1], then discards the rest of the line.
     * @param ifs    stream positioned at the first value.
     * @param length number of values to read.
     * @param var    destination; must point to at least @p length elements.
     */
    template <class T> static void read_values(std::ifstream &ifs, const int length, T *var)
    {
        for (int i = 0; i < length; ++i)
        {
            ifs >> var[i];
        }
        ifs.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    }

    // training
    int fftdim = 0;
    int nbatch = 0;
    int ntrain = 1;
    int nvalidation = 0;
    // NOTE(review): presumably train_* hold ntrain entries and validation_*
    // hold nvalidation entries — confirm in input.cpp.
    std::string *train_dir = nullptr;
    std::string *train_cell = nullptr;
    double *train_a = nullptr;
    std::string *validation_dir = nullptr;
    std::string *validation_cell = nullptr;
    double *validation_a = nullptr;
    std::string loss = "both";
    int nepoch = 1000;
    double lr_start = 0.01; // learning rate 2023-02-24
    double lr_end = 1e-4;
    int lr_fre = 5000;
    double exponent = 5.; // exponent of weight rho^{exponent/3.}

    // output
    int dump_fre = 1;
    int print_fre = 1;

    // descriptors
    // semi-local descriptors
    bool ml_gamma = false;
    bool ml_p = false;
    bool ml_q = false;
    bool ml_tanhp = false;
    bool ml_tanhq = false;
    double chi_p = 1.;
    double chi_q = 1.;
    // non-local descriptors
    // NOTE(review): presumably one entry per kernel (length nkernel) — confirm.
    bool *ml_gammanl = nullptr;
    bool *ml_pnl = nullptr;
    bool *ml_qnl = nullptr;
    bool *ml_xi = nullptr;
    bool *ml_tanhxi = nullptr;
    bool *ml_tanhxi_nl = nullptr;
    bool *ml_tanh_pnl = nullptr;
    bool *ml_tanh_qnl = nullptr;
    bool *ml_tanhp_nl = nullptr;
    bool *ml_tanhq_nl = nullptr;
    double *chi_xi = nullptr;
    double *chi_pnl = nullptr;
    double *chi_qnl = nullptr;

    int feg_limit = 0;   // Free Electron Gas
    int change_step = 0; // when feg_limit=3, change the output of net after change_step

    // coefficients in loss function
    double coef_e = 1.;
    double coef_p = 1.;
    double coef_feg_e = 1.;
    double coef_feg_p = 1.;

    // size of nn
    int nnode = 10;
    int nlayer = 3;

    // kernel
    // NOTE(review): the kernel_* / yukawa_alpha arrays presumably hold
    // nkernel entries — confirm in input.cpp.
    int nkernel = 1;
    int *kernel_type = nullptr;
    double *kernel_scaling = nullptr;
    double *yukawa_alpha = nullptr;
    std::string *kernel_file = nullptr;

    // GPU
    std::string device_type = "gpu";
    bool check_pot = false;

    /// Writes @p message followed by a newline to stdout; std::endl flushes,
    /// so the message is visible immediately.
    static void print(const std::string &message)
    {
        std::cout << message << std::endl;
    }
};
134#endif
Definition input.h:7
std::string * validation_cell
Definition input.h:68
double chi_q
Definition input.h:89
bool * ml_pnl
Definition input.h:92
double * train_a
Definition input.h:66
double exponent
Definition input.h:75
int fftdim
Definition input.h:60
static void print(std::string message)
Definition input.h:129
double coef_e
Definition input.h:109
bool ml_q
Definition input.h:85
bool * ml_tanhxi
Definition input.h:95
int * kernel_type
Definition input.h:120
bool ml_tanhq
Definition input.h:87
double * chi_xi
Definition input.h:101
int nvalidation
Definition input.h:63
bool check_pot
Definition input.h:127
bool ml_tanhp
Definition input.h:86
static void read_values(std::ifstream &ifs, const int length, T *var)
Definition input.h:49
std::string loss
Definition input.h:70
bool * ml_gammanl
Definition input.h:91
std::string * validation_dir
Definition input.h:67
int ntrain
Definition input.h:62
double * kernel_scaling
Definition input.h:121
double * validation_a
Definition input.h:69
double coef_feg_e
Definition input.h:111
int nnode
Definition input.h:115
double * yukawa_alpha
Definition input.h:122
int nkernel
Definition input.h:119
int nlayer
Definition input.h:116
double lr_end
Definition input.h:73
double coef_feg_p
Definition input.h:112
int print_fre
Definition input.h:79
static void read_value(std::ifstream &ifs, T &var)
Definition input.h:42
std::string * train_cell
Definition input.h:65
int change_step
Definition input.h:106
bool * ml_xi
Definition input.h:94
double lr_start
Definition input.h:72
bool * ml_tanh_qnl
Definition input.h:98
bool ml_gamma
Definition input.h:83
double coef_p
Definition input.h:110
int nepoch
Definition input.h:71
bool * ml_tanhq_nl
Definition input.h:100
Input()
Definition input.h:10
bool * ml_qnl
Definition input.h:93
void readInput()
Definition input.cpp:3
int dump_fre
Definition input.h:78
double chi_p
Definition input.h:88
bool ml_p
Definition input.h:84
int lr_fre
Definition input.h:74
std::string * train_dir
Definition input.h:64
int feg_limit
Definition input.h:105
~Input()
Definition input.h:11
std::string device_type
Definition input.h:126
double * chi_pnl
Definition input.h:102
bool * ml_tanhp_nl
Definition input.h:99
int nbatch
Definition input.h:61
bool * ml_tanh_pnl
Definition input.h:97
double * chi_qnl
Definition input.h:103
std::string * kernel_file
Definition input.h:123
bool * ml_tanhxi_nl
Definition input.h:96
#define T
Definition exp.cpp:237