ABACUS
develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
source
source_pw
module_ofdft
nn_of.h
Go to the documentation of this file.
1
#ifndef NN_OF_H
2
#define NN_OF_H
3
4
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include <torch/torch.h>
5
6
struct
NN_OFImpl
:torch::nn::Module{
7
// three hidden layers and one output layer
8
NN_OFImpl
(
9
int
nrxx
,
10
int
nrxx_vali
,
11
int
ninpt
,
12
int
nnode
,
13
int
nlayer
,
14
torch::Device device
15
);
16
~NN_OFImpl
()
17
{
18
// delete[] this->fcs;
19
};
20
21
22
template
<
class
T>
23
void
set_data
(
24
T
*data,
25
const
std::vector<std::string> &descriptor_type,
26
const
std::vector<int> &kernel_index,
27
torch::Tensor &nn_input
28
)
29
{
30
if
(data->nx_tot <= 0)
return
;
31
for
(
int
i
= 0;
i
< descriptor_type.size(); ++
i
)
32
{
33
nn_input.index({
"..."
,
i
}) = data->get_data(descriptor_type[
i
], kernel_index[
i
]);
34
}
35
}
36
37
torch::Tensor
forward
(torch::Tensor inpt);
38
39
// torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}, fc4{nullptr}, fc5{nullptr};
40
// torch::nn::Linear fcs[5] = {fc1, fc2, fc3, fc4, fc5};
41
42
torch::nn::Linear
fc1
{
nullptr
},
fc2
{
nullptr
},
fc3
{
nullptr
},
fc4
{
nullptr
};
43
44
torch::Tensor
inputs
;
45
torch::Tensor
input_vali
;
46
torch::Tensor
F
;
// enhancement factor, output of NN
47
48
int
nrxx
= 10;
49
int
nrxx_vali
= 0;
50
int
ninpt
= 6;
51
int
nnode
= 10;
52
int
nlayer
= 3;
53
int
nfc
= 4;
54
};
55
TORCH_MODULE
(NN_OF);
56
57
#endif
i
const std::complex< double > i
Definition
cal_pLpR.cpp:46
T
#define T
Definition
exp.cpp:237
TORCH_MODULE
TORCH_MODULE(NN_OF)
NN_OFImpl
Definition
nn_of.h:6
NN_OFImpl::inputs
torch::Tensor inputs
Definition
nn_of.h:44
NN_OFImpl::fc4
torch::nn::Linear fc4
Definition
nn_of.h:42
NN_OFImpl::fc1
torch::nn::Linear fc1
Definition
nn_of.h:42
NN_OFImpl::nnode
int nnode
Definition
nn_of.h:51
NN_OFImpl::input_vali
torch::Tensor input_vali
Definition
nn_of.h:45
NN_OFImpl::nlayer
int nlayer
Definition
nn_of.h:52
NN_OFImpl::~NN_OFImpl
~NN_OFImpl()
Definition
nn_of.h:16
NN_OFImpl::nrxx
int nrxx
Definition
nn_of.h:48
NN_OFImpl::fc3
torch::nn::Linear fc3
Definition
nn_of.h:42
NN_OFImpl::F
torch::Tensor F
Definition
nn_of.h:46
NN_OFImpl::set_data
void set_data(T *data, const std::vector< std::string > &descriptor_type, const std::vector< int > &kernel_index, torch::Tensor &nn_input)
Definition
nn_of.h:23
NN_OFImpl::fc2
torch::nn::Linear fc2
Definition
nn_of.h:42
NN_OFImpl::ninpt
int ninpt
Definition
nn_of.h:50
NN_OFImpl::nfc
int nfc
Definition
nn_of.h:53
NN_OFImpl::forward
torch::Tensor forward(torch::Tensor inpt)
Definition
nn_of.cpp:27
NN_OFImpl::nrxx_vali
int nrxx_vali
Definition
nn_of.h:49
Generated by
1.9.8