ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
input.h
Go to the documentation of this file.
1#ifndef INPUT_H
2#define INPUT_H
3
#include <fstream>
#include <iostream>
#include <limits>
#include <string>

#include <torch/torch.h>
5
/**
 * @brief Settings of the neural-network training, read from the nnINPUT file.
 *
 * readInput() (defined in input.cpp) fills the members. The pointer members
 * are raw owning arrays that the destructor releases with delete[]; they are
 * presumably allocated with new[] by readInput() (TODO confirm in input.cpp).
 * Because the object owns those arrays, Input is non-copyable.
 */
class Input
{
    // ---------- read in the settings from nnINPUT --------
  public:
    Input() {}

    /// Release every owned array. delete[] on a nullptr is a no-op, so
    /// members that were never allocated are handled safely.
    ~Input()
    {
        delete[] this->train_dir;
        delete[] this->train_cell;
        delete[] this->train_a;
        delete[] this->validation_dir;
        delete[] this->validation_cell;
        delete[] this->validation_a;

        delete[] this->ml_gammanl;
        delete[] this->ml_pnl;
        delete[] this->ml_qnl;
        delete[] this->ml_xi;
        delete[] this->ml_tanhxi;
        delete[] this->ml_tanhxi_nl;
        delete[] this->ml_tanh_pnl;
        delete[] this->ml_tanh_qnl;
        delete[] this->ml_tanhp_nl;
        delete[] this->ml_tanhq_nl;
        delete[] this->chi_xi;
        delete[] this->chi_pnl;
        delete[] this->chi_qnl;
        delete[] this->kernel_type;
        delete[] this->kernel_scaling;
        delete[] this->yukawa_alpha;
        delete[] this->kernel_file;
    }

    // A member-wise copy would leave two objects owning the same arrays, and
    // both destructors would delete[] them (double free). Copying is
    // therefore disabled; declaring these also suppresses the implicit moves.
    Input(const Input &) = delete;
    Input &operator=(const Input &) = delete;

    /// Read the settings from nnINPUT (defined in input.cpp).
    void readInput();

    /**
     * @brief Read one value into @p var with operator>>, then discard the rest
     *        of the current line (e.g. a trailing comment).
     * @param ifs stream positioned at the value to read
     * @param var destination of the parsed value
     */
    template <class T> static void read_value(std::ifstream &ifs, T &var)
    {
        ifs >> var;
        // Unbounded count: a fixed limit would leave the tail of a long
        // comment in the stream, to be misparsed as the next value.
        ifs.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    }

    /**
     * @brief Read @p length values into var[0..length-1] with operator>>, then
     *        discard the rest of the current line.
     * @param ifs    stream positioned at the first value
     * @param length number of values to read (nothing is read if <= 0)
     * @param var    destination array with room for at least @p length values
     */
    template <class T> static void read_values(std::ifstream &ifs, const int length, T *var)
    {
        for (int i = 0; i < length; ++i)
        {
            ifs >> var[i];
        }
        // Unbounded count, for the same reason as in read_value().
        ifs.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
    }

    // training
    int fftdim = 0;
    int nbatch = 0;
    int ntrain = 1;
    int nvalidation = 0;
    std::string *train_dir = nullptr;       // owned array, released in ~Input()
    std::string *train_cell = nullptr;      // owned array, released in ~Input()
    double *train_a = nullptr;              // owned array, released in ~Input()
    std::string *validation_dir = nullptr;  // owned array, released in ~Input()
    std::string *validation_cell = nullptr; // owned array, released in ~Input()
    double *validation_a = nullptr;         // owned array, released in ~Input()
    std::string loss = "both";
    int nepoch = 1000;
    double lr_start = 0.01; // learning rate 2023-02-24
    double lr_end = 1e-4;
    int lr_fre = 5000;
    double exponent = 5.;             // exponent of weight rho^{exponent/3.}
    std::string energy_type = "kedf"; // kedf or exx

    // output
    int dump_fre = 1;
    int print_fre = 1;

    // descriptors
    // semi-local descriptors
    bool ml_gamma = false;
    bool ml_p = false;
    bool ml_q = false;
    bool ml_tanhp = false;
    bool ml_tanhq = false;
    double chi_p = 1.;
    double chi_q = 1.;
    // non-local descriptors (owned arrays, released in ~Input();
    // presumably one entry per kernel -- TODO confirm in input.cpp)
    bool *ml_gammanl = nullptr;
    bool *ml_pnl = nullptr;
    bool *ml_qnl = nullptr;
    bool *ml_xi = nullptr;
    bool *ml_tanhxi = nullptr;
    bool *ml_tanhxi_nl = nullptr;
    bool *ml_tanh_pnl = nullptr;
    bool *ml_tanh_qnl = nullptr;
    bool *ml_tanhp_nl = nullptr;
    bool *ml_tanhq_nl = nullptr;
    double *chi_xi = nullptr;
    double *chi_pnl = nullptr;
    double *chi_qnl = nullptr;

    int feg_limit = 0;   // Free Electron Gas
    int change_step = 0; // when feg_limit=3, change the output of net after change_step

    // coefficients in loss function
    double coef_e = 1.;
    double coef_p = 1.;
    double coef_feg_e = 1.;
    double coef_feg_p = 1.;

    // size of nn
    int nnode = 10;
    int nlayer = 3;

    // kernel (array members are owned and released in ~Input())
    int nkernel = 1;
    int *kernel_type = nullptr;
    double *kernel_scaling = nullptr;
    double *yukawa_alpha = nullptr;
    std::string *kernel_file = nullptr;

    // GPU
    std::string device_type = "gpu";
    bool check_pot = false;

    /// Write @p message to standard output followed by std::endl (flushes).
    static void print(const std::string &message)
    {
        std::cout << message << std::endl;
    }
};
135#endif
Definition input.h:7
std::string * validation_cell
Definition input.h:68
double chi_q
Definition input.h:90
bool * ml_pnl
Definition input.h:93
double * train_a
Definition input.h:66
double exponent
Definition input.h:75
int fftdim
Definition input.h:60
static void print(std::string message)
Definition input.h:130
double coef_e
Definition input.h:110
bool ml_q
Definition input.h:86
bool * ml_tanhxi
Definition input.h:96
int * kernel_type
Definition input.h:121
std::string energy_type
Definition input.h:76
bool ml_tanhq
Definition input.h:88
double * chi_xi
Definition input.h:102
int nvalidation
Definition input.h:63
bool check_pot
Definition input.h:128
bool ml_tanhp
Definition input.h:87
static void read_values(std::ifstream &ifs, const int length, T *var)
Definition input.h:49
std::string loss
Definition input.h:70
bool * ml_gammanl
Definition input.h:92
std::string * validation_dir
Definition input.h:67
int ntrain
Definition input.h:62
double * kernel_scaling
Definition input.h:122
double * validation_a
Definition input.h:69
double coef_feg_e
Definition input.h:112
int nnode
Definition input.h:116
double * yukawa_alpha
Definition input.h:123
int nkernel
Definition input.h:120
int nlayer
Definition input.h:117
double lr_end
Definition input.h:73
double coef_feg_p
Definition input.h:113
int print_fre
Definition input.h:80
static void read_value(std::ifstream &ifs, T &var)
Definition input.h:42
std::string * train_cell
Definition input.h:65
int change_step
Definition input.h:107
bool * ml_xi
Definition input.h:95
double lr_start
Definition input.h:72
bool * ml_tanh_qnl
Definition input.h:99
bool ml_gamma
Definition input.h:84
double coef_p
Definition input.h:111
int nepoch
Definition input.h:71
bool * ml_tanhq_nl
Definition input.h:101
Input()
Definition input.h:10
bool * ml_qnl
Definition input.h:94
void readInput()
Definition input.cpp:3
int dump_fre
Definition input.h:79
double chi_p
Definition input.h:89
bool ml_p
Definition input.h:85
int lr_fre
Definition input.h:74
std::string * train_dir
Definition input.h:64
int feg_limit
Definition input.h:106
~Input()
Definition input.h:11
std::string device_type
Definition input.h:127
double * chi_pnl
Definition input.h:103
bool * ml_tanhp_nl
Definition input.h:100
int nbatch
Definition input.h:61
bool * ml_tanh_pnl
Definition input.h:98
double * chi_qnl
Definition input.h:104
std::string * kernel_file
Definition input.h:124
bool * ml_tanhxi_nl
Definition input.h:97
#define T
Definition exp.cpp:237