ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
data.h
Go to the documentation of this file.
1#ifndef DATA_H
2#define DATA_H
3
4#include "./input.h"
5
6#include <torch/torch.h>
7
8class Data // Holds the tensors (density, descriptors, targets) and the load_* flags declared below; the member functions are defined in data.cpp.
9{
10 public:
11 // --------- load the data from .npy files (entry point: load_data, declared below) ------
12 ~Data(); // defined in data.cpp. NOTE(review): presumably frees the bool* load_* arrays below; if so, copying a Data would share those pointers - confirm
13
14 int nx = 0; // presumably the number of grid points of one data set - TODO confirm
15 int nx_tot = 0; // presumably the number of grid points summed over all data sets - TODO confirm
16
17 // =========== data ===========
18 torch::Tensor rho; // presumably the electron density - confirm
19 torch::Tensor nablaRho; // presumably the gradient of rho - confirm
20 torch::Tensor tau_tf; // presumably the Thomas-Fermi kinetic energy density (cf. cTF below) - confirm
21 // semi-local descriptors
22 torch::Tensor gamma;
23 torch::Tensor p;
24 torch::Tensor q;
25 torch::Tensor tanhp; // presumably tanh applied to p - confirm
26 torch::Tensor tanhq; // presumably tanh applied to q - confirm
27 // non-local descriptors (each is a std::vector of tensors; presumably one entry per kernel, cf. ikernel in get_data - confirm)
28 std::vector<torch::Tensor> gammanl = {};
29 std::vector<torch::Tensor> pnl = {};
30 std::vector<torch::Tensor> qnl = {};
31 std::vector<torch::Tensor> xi = {};
32 std::vector<torch::Tensor> tanhxi = {};
33 std::vector<torch::Tensor> tanhxi_nl = {};
34 std::vector<torch::Tensor> tanh_pnl = {};
35 std::vector<torch::Tensor> tanh_qnl = {};
36 std::vector<torch::Tensor> tanhp_nl = {};
37 std::vector<torch::Tensor> tanhq_nl = {};
38 // target
39 torch::Tensor enhancement;
40 torch::Tensor pauli;
41 torch::Tensor enhancement_mean;
42 torch::Tensor tau_mean; // mean Pauli energy
43 torch::Tensor pauli_mean;
44
45 // =========== label =========== (load_* flags: the scalar flags default to false, the pointer flags to nullptr)
46 bool load_gamma = false;
47 bool load_p = false;
48 bool load_q = false;
49 bool load_tanhp = false;
50 bool load_tanhq = false;
51 bool* load_gammanl = nullptr; // bool* flags: presumably arrays with one entry per kernel, allocated in init_label - TODO confirm
52 bool* load_pnl = nullptr;
53 bool* load_qnl = nullptr;
54 bool* load_xi = nullptr;
55 bool* load_tanhxi = nullptr;
56 bool* load_tanhxi_nl = nullptr;
57 bool* load_tanh_pnl = nullptr;
58 bool* load_tanh_qnl = nullptr;
59 bool* load_tanhp_nl = nullptr;
60 bool* load_tanhq_nl = nullptr;
61
62 void load_data(Input &input, const int ndata, std::string *dir, const torch::Device device); // defined in data.cpp; dir is presumably an array of ndata directory paths - confirm
63 torch::Tensor get_data(std::string parameter, const int ikernel); // defined in data.cpp; presumably returns the member tensor named by parameter (entry ikernel for the per-kernel vectors) - confirm
64
65 private:
66 void init_label(Input &input); // defined in data.cpp; presumably sets the load_* flags from input - confirm
67 void init_data(const int nkernel, const int ndata, const int fftdim, const torch::Device device); // defined in data.cpp; presumably allocates the tensors above on device - confirm
68 void load_data_(Input &input, const int ndata, const int fftdim, std::string *dir); // defined in data.cpp; private counterpart of load_data
69 // NOTE(review): std::string, std::vector, std::pow and M_PI are used in this header without <string>, <vector> or <cmath> being included here; they presumably arrive through input.h or torch - confirm
70 const double cTF = 3.0/10.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0) * 2; // 3/10*(3*pi^2)^{2/3}, multiplied by 2 to convert the unit from Hartree to Ry, so the value is in Ry*Bohr^(-2)
71
72 public:
73 void loadTensor(std::string file, // defined in data.cpp. NOTE(review): presumably reads one .npy file into container and copies it into data at position index - confirm
74 std::vector<long unsigned int> cshape,
75 bool fortran_order,
76 std::vector<double> &container,
77 const int index,
78 const int fftdim,
79 torch::Tensor &data);
80 // -------- dump Tensor into .npy files ---------
81 void dumpTensor(const torch::Tensor &data, std::string filename, int nx); // defined in data.cpp
82 std::string file_name(std::string parameter, const int kernel_type, const double kernel_scaling); // defined in data.cpp; presumably builds the file name of a descriptor from its name, kernel type and kernel scaling - confirm
83};
84#endif
Definition data.h:9
std::vector< torch::Tensor > tanhxi
Definition data.h:32
torch::Tensor q
Definition data.h:24
bool * load_tanh_qnl
Definition data.h:58
torch::Tensor get_data(std::string parameter, const int ikernel)
Definition data.cpp:32
bool * load_tanhxi
Definition data.h:55
void loadTensor(std::string file, std::vector< long unsigned int > cshape, bool fortran_order, std::vector< double > &container, const int index, const int fftdim, torch::Tensor &data)
Definition data.cpp:326
int nx
Definition data.h:14
bool load_tanhp
Definition data.h:49
std::vector< torch::Tensor > gammanl
Definition data.h:28
torch::Tensor enhancement_mean
Definition data.h:41
std::string file_name(std::string parameter, const int kernel_type, const double kernel_scaling)
Definition data.cpp:352
int nx_tot
Definition data.h:15
torch::Tensor rho
Definition data.h:18
bool load_p
Definition data.h:47
void init_label(Input &input)
Definition data.cpp:81
torch::Tensor tau_mean
Definition data.h:42
torch::Tensor tanhq
Definition data.h:26
bool * load_gammanl
Definition data.h:51
std::vector< torch::Tensor > xi
Definition data.h:31
bool * load_pnl
Definition data.h:52
torch::Tensor pauli_mean
Definition data.h:43
const double cTF
Definition data.h:70
void load_data_(Input &input, const int ndata, const int fftdim, std::string *dir)
Definition data.cpp:231
std::vector< torch::Tensor > pnl
Definition data.h:29
bool * load_tanh_pnl
Definition data.h:57
std::vector< torch::Tensor > tanh_qnl
Definition data.h:35
void init_data(const int nkernel, const int ndata, const int fftdim, const torch::Device device)
Definition data.cpp:150
torch::Tensor tanhp
Definition data.h:25
std::vector< torch::Tensor > tanhp_nl
Definition data.h:36
void dumpTensor(const torch::Tensor &data, std::string filename, int nx)
Definition data.cpp:340
bool * load_tanhxi_nl
Definition data.h:56
bool load_tanhq
Definition data.h:50
torch::Tensor enhancement
Definition data.h:39
std::vector< torch::Tensor > tanh_pnl
Definition data.h:34
~Data()
Definition data.cpp:4
torch::Tensor gamma
Definition data.h:22
bool * load_xi
Definition data.h:54
bool * load_tanhp_nl
Definition data.h:59
torch::Tensor tau_tf
Definition data.h:20
bool load_gamma
Definition data.h:46
bool load_q
Definition data.h:48
std::vector< torch::Tensor > tanhq_nl
Definition data.h:37
torch::Tensor p
Definition data.h:23
std::vector< torch::Tensor > tanhxi_nl
Definition data.h:33
std::vector< torch::Tensor > qnl
Definition data.h:30
bool * load_tanhq_nl
Definition data.h:60
void load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
Definition data.cpp:18
torch::Tensor pauli
Definition data.h:40
bool * load_qnl
Definition data.h:53
torch::Tensor nablaRho
Definition data.h:19
Definition input.h:7
Definition tensor.cpp:8
file(GLOB ATen_CORE_SRCS "*.cpp") set(ATen_CPU_SRCS $
Definition CMakeLists.txt:1