ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
train_kedf.h
Go to the documentation of this file.
1#ifndef TRAIN_KEDF_H
2#define TRAIN_KEDF_H
3
4#include "./data.h"
5#include "./grid.h"
6#include "./input.h"
7#include "./kernel.h"
8#include "./nn_of.h"
9#include "./pauli_potential.h"
10
11#include <torch/torch.h>
12
14{
15 public:
18
19 std::shared_ptr<NN_OFImpl> nn;
23 Kernel *kernel_train = nullptr;
24 Kernel *kernel_vali = nullptr;
26 //----------- training set -----------
28 double *train_volume = nullptr;
29 //--------- validation set -----------
31 double *vali_volume = nullptr;
32 // ------------------------------------
33
34 torch::Device device = torch::Device(torch::kCUDA); // initialized to the CUDA device
35 int ninput = 0;
36 std::vector<std::string> descriptor_type = {};
37 std::vector<int> kernel_index = {};
38
39 // -------- free electron gas ---------
40 torch::Tensor feg_inpt;
41 torch::Tensor feg_predict;
42 torch::Tensor feg_dFdgamma;
43
44 // ----------- constants ---------------
45 double feg3_correct = 0.541324854612918; // ln(e - 1)
46 const double cTF
47 = 3.0 / 10.0 * std::pow(3 * std::pow(M_PI, 2.0), 2.0 / 3.0)
48 * 2; // 3/10*(3*pi^2)^{2/3}, multiply by 2 to convert unit from Hartree to Ry, finally in Ry*Bohr^(-2)
49 const double pqcoef = 1.0 / (4.0 * std::pow(3 * std::pow(M_PI, 2.0), 2.0 / 3.0)); // coefficient of p and q: 1/(4*(3*pi^2)^{2/3})
50
51 void train();
52 void potTest();
53 void setUpFFT();
54 void set_device();
55 void init_input_index();
56 void init();
57
58 torch::Tensor lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef = torch::ones(1)); // coef defaults to a one-element tensor of ones
59 torch::Tensor lossFunction_new(torch::Tensor enhancement,
60 torch::Tensor target,
61 torch::Tensor weight,
62 torch::Tensor coef = torch::ones(1)); // coef defaults to a one-element tensor of ones
63};
64
65// class OF_data : public torch::data::Dataset<OF_data>
66// {
67// private:
68// torch::Tensor input;
69// torch::Tensor target;
70
71// public:
72// explicit OF_data(torch::Tensor &input, torch::Tensor &target)
73// {
74// this->input = input.clone();
75// this->target = target.clone();
76// }
77
78// torch::data::Example<> get(size_t index) override
79// {
80// return {this->input[index], this->target[index]};
81// }
82
83// torch::optional<size_t> size() const override
84// {
85// return this->input.size(0);
86// }
87// };
88
89#endif
Definition data.h:9
Definition input.h:7
Definition kernel.h:7
Definition pauli_potential.h:10
Definition train_kedf.h:14
std::vector< std::string > descriptor_type
Definition train_kedf.h:36
std::vector< int > kernel_index
Definition train_kedf.h:37
Data data_train
Definition train_kedf.h:27
torch::Device device
Definition train_kedf.h:34
Kernel * kernel_train
Definition train_kedf.h:23
Input input
Definition train_kedf.h:20
void train()
Definition train_kedf.cpp:228
Kernel * kernel_vali
Definition train_kedf.h:24
const double cTF
Definition train_kedf.h:47
void init()
Definition train_kedf.cpp:198
const double pqcoef
Definition train_kedf.h:49
Grid grid_train
Definition train_kedf.h:21
torch::Tensor feg_predict
Definition train_kedf.h:41
std::shared_ptr< NN_OFImpl > nn
Definition train_kedf.h:19
Train_KEDF()
Definition train_kedf.h:16
Data data_vali
Definition train_kedf.h:30
void potTest()
Definition train_kedf.cpp:430
Grid grid_vali
Definition train_kedf.h:22
torch::Tensor feg_dFdgamma
Definition train_kedf.h:42
~Train_KEDF()
Definition train_kedf.cpp:6
void set_device()
Definition train_kedf.cpp:69
int ninput
Definition train_kedf.h:35
void init_input_index()
Definition train_kedf.cpp:92
torch::Tensor lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef=torch::ones(1))
Definition train_kedf.cpp:217
PauliPotential potential
Definition train_kedf.h:25
torch::Tensor feg_inpt
Definition train_kedf.h:40
double feg3_correct
Definition train_kedf.h:45
void setUpFFT()
Definition train_kedf.cpp:14
double * train_volume
Definition train_kedf.h:28
double * vali_volume
Definition train_kedf.h:31
torch::Tensor lossFunction_new(torch::Tensor enhancement, torch::Tensor target, torch::Tensor weight, torch::Tensor coef=torch::ones(1))
Definition train_kedf.cpp:222
Definition batch.h:6