11 #include <torch/torch.h>
19 std::shared_ptr<NN_OFImpl> nn;
34 torch::Device device = torch::Device(torch::kCUDA);
47 const double cTF = 3.0 / 10.0 * std::pow(3 * std::pow(M_PI, 2.0), 2.0 / 3.0)
49 const double pqcoef = 1.0 / (4.0 * std::pow(3 * std::pow(M_PI, 2.0), 2.0 / 3.0));
58 torch::Tensor lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef = torch::ones(1));
62 torch::Tensor coef = torch::ones(1));
Definition pauli_potential.h:10
Definition train_kedf.h:14
std::vector< std::string > descriptor_type
Definition train_kedf.h:36
std::vector< int > kernel_index
Definition train_kedf.h:37
Data data_train
Definition train_kedf.h:27
torch::Device device
Definition train_kedf.h:34
Kernel * kernel_train
Definition train_kedf.h:23
Input input
Definition train_kedf.h:20
void train()
Definition train_kedf.cpp:228
Kernel * kernel_vali
Definition train_kedf.h:24
const double cTF
Definition train_kedf.h:47
void init()
Definition train_kedf.cpp:198
const double pqcoef
Definition train_kedf.h:49
Grid grid_train
Definition train_kedf.h:21
torch::Tensor feg_predict
Definition train_kedf.h:41
std::shared_ptr< NN_OFImpl > nn
Definition train_kedf.h:19
Train_KEDF()
Definition train_kedf.h:16
Data data_vali
Definition train_kedf.h:30
void potTest()
Definition train_kedf.cpp:430
Grid grid_vali
Definition train_kedf.h:22
torch::Tensor feg_dFdgamma
Definition train_kedf.h:42
~Train_KEDF()
Definition train_kedf.cpp:6
void set_device()
Definition train_kedf.cpp:69
int ninput
Definition train_kedf.h:35
void init_input_index()
Definition train_kedf.cpp:92
torch::Tensor lossFunction(torch::Tensor enhancement, torch::Tensor target, torch::Tensor coef=torch::ones(1))
Definition train_kedf.cpp:217
PauliPotential potential
Definition train_kedf.h:25
torch::Tensor feg_inpt
Definition train_kedf.h:40
double feg3_correct
Definition train_kedf.h:45
void setUpFFT()
Definition train_kedf.cpp:14
double * train_volume
Definition train_kedf.h:28
double * vali_volume
Definition train_kedf.h:31
torch::Tensor lossFunction_new(torch::Tensor enhancement, torch::Tensor target, torch::Tensor weight, torch::Tensor coef=torch::ones(1))
Definition train_kedf.cpp:222