ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
nn_of.h
Go to the documentation of this file.
1#ifndef NN_OF_H
2#define NN_OF_H
3
4#include <torch/torch.h>
5
6struct NN_OFImpl:torch::nn::Module{
7 // three hidden layers and one output layer
9 int nrxx,
10 int nrxx_vali,
11 int ninpt,
12 int nnode,
13 int nlayer,
14 torch::Device device
15 );
17 {
18 // delete[] this->fcs;
19 };
20
21
22 template <class T>
24 T *data,
25 const std::vector<std::string> &descriptor_type,
26 const std::vector<int> &kernel_index,
27 torch::Tensor &nn_input
28 )
29 {
30 if (data->nx_tot <= 0) return;
31 for (int i = 0; i < descriptor_type.size(); ++i)
32 {
33 nn_input.index({"...", i}) = data->get_data(descriptor_type[i], kernel_index[i]);
34 }
35 }
36
37 torch::Tensor forward(torch::Tensor inpt);
38
39 // torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}, fc4{nullptr}, fc5{nullptr};
40 // torch::nn::Linear fcs[5] = {fc1, fc2, fc3, fc4, fc5};
41
42 torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}, fc4{nullptr};
43
44 torch::Tensor inputs;
45 torch::Tensor input_vali;
46 torch::Tensor F; // enhancement factor, output of NN
47
48 int nrxx = 10;
49 int nrxx_vali = 0;
50 int ninpt = 6;
51 int nnode = 10;
52 int nlayer = 3;
53 int nfc = 4;
54};
56
57#endif
#define T
Definition exp.cpp:237
TORCH_MODULE(NN_OF)
Definition nn_of.h:6
torch::Tensor inputs
Definition nn_of.h:44
torch::nn::Linear fc4
Definition nn_of.h:42
torch::nn::Linear fc1
Definition nn_of.h:42
int nnode
Definition nn_of.h:51
torch::Tensor input_vali
Definition nn_of.h:45
int nlayer
Definition nn_of.h:52
~NN_OFImpl()
Definition nn_of.h:16
int nrxx
Definition nn_of.h:48
torch::nn::Linear fc3
Definition nn_of.h:42
torch::Tensor F
Definition nn_of.h:46
void set_data(T *data, const std::vector< std::string > &descriptor_type, const std::vector< int > &kernel_index, torch::Tensor &nn_input)
Definition nn_of.h:23
torch::nn::Linear fc2
Definition nn_of.h:42
int ninpt
Definition nn_of.h:50
int nfc
Definition nn_of.h:53
torch::Tensor forward(torch::Tensor inpt)
Definition nn_of.cpp:27
int nrxx_vali
Definition nn_of.h:49