ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_basic.h
Go to the documentation of this file.
1#ifndef DEEPKS_BASIC_H
2#define DEEPKS_BASIC_H
3
4#ifdef __MLALGO
5#include "LCAO_deepks_io.h"
6#include "deepks_param.h"
10
11#include <torch/script.h>
12#include <torch/torch.h>
13
15{
16//------------------------
17// deepks_basic.cpp
18//------------------------
19
20// The file contains 7 subroutines:
21// 1. load_model : loads model for applying V_delta
22// 2. cal_gevdm : d(des)/d(pdm), calculated using torch::autograd::grad
23// 3. cal_edelta_gedm : calculates E_delta and d(E_delta)/d(pdm)
24// this is the term V(D) that enters the expression V_delta = |alpha>V(D)<alpha|
25// calculated using torch::autograd::grad
26// 4. check_gedm : prints gedm for checking
27// 5. cal_edelta_gedm_equiv : calculates E_delta and d(E_delta)/d(pdm) for equivariant version
28// 6. prepare_atom : prepares atom tensor for output as deepks_out_labels = 2
29// 7. prepare_box : prepares box tensor for output as deepks_out_labels = 2
30
31// load the trained neural network models
/// @brief Load a trained neural-network model (used for applying V_delta).
/// @param model_file Path of the model file to load.
/// @param model      Output: receives the loaded TorchScript module (taken by
///                   non-const reference and filled in by the callee).
32void load_model(const std::string& model_file, torch::jit::script::Module& model);
33
34// calculate gevdm
/// @brief Compute gevdm = d(descriptor)/d(pdm) using torch::autograd::grad.
/// @param nat          Presumably the number of atoms (TODO confirm against caller).
/// @param deepks_param DeePKS settings used to interpret the tensors.
/// @param pdm          Input tensors; "pdm" is presumably the projected density
///                     matrix, one tensor per entry (layout not visible here — confirm).
/// @param gevdm        Output: vector of gradient tensors written by the callee.
35void cal_gevdm(const int nat,
36 const DeePKS_Param& deepks_param,
37 const std::vector<torch::Tensor>& pdm,
38 std::vector<torch::Tensor>& gevdm);
39
/// @brief Compute the energy correction E_delta and gedm = d(E_delta)/d(pdm).
///
/// gedm is the term V(D) entering V_delta = |alpha>V(D)<alpha|; the derivative
/// is obtained with torch::autograd::grad.
/// @param nat          Presumably the number of atoms (TODO confirm).
/// @param deepks_param DeePKS settings.
/// @param descriptor   Descriptor tensors fed to the model.
/// @param pdm          Tensors the derivative is taken with respect to.
/// @param model_deepks Loaded TorchScript model (passed by non-const reference).
/// @param gedm         Output: written through a double**; allocation and row
///                     layout are presumably owned by the caller — confirm.
/// @param E_delta      Output: the energy correction.
41void cal_edelta_gedm(const int nat,
42 const DeePKS_Param& deepks_param,
43 const std::vector<torch::Tensor>& descriptor,
44 const std::vector<torch::Tensor>& pdm,
45 torch::jit::script::Module& model_deepks,
46 double** gedm,
47 double& E_delta);
/// @brief Print gedm for checking (debug aid).
/// @param deepks_param DeePKS settings used to interpret gedm's layout.
/// @param gedm         Values to print; output destination/format is defined
///                     in deepks_basic.cpp and not visible here.
48void check_gedm(const DeePKS_Param& deepks_param, double** gedm);
/// @brief Equivariant-version counterpart of cal_edelta_gedm: computes E_delta
///        and d(E_delta)/d(pdm).
///
/// Unlike cal_edelta_gedm, this signature takes neither pdm nor a model module.
/// @param nat          Presumably the number of atoms (TODO confirm).
/// @param deepks_param DeePKS settings.
/// @param descriptor   Descriptor tensors.
/// @param gedm         Output: written through a double** (caller-owned storage — confirm).
/// @param E_delta      Output: the energy correction.
/// @param rank         NOTE(review): presumably the MPI rank of the caller — confirm.
49void cal_edelta_gedm_equiv(const int nat,
50 const DeePKS_Param& deepks_param,
51 const std::vector<torch::Tensor>& descriptor,
52 double** gedm,
53 double& E_delta,
54 const int rank);
55
/// @brief Build the atom tensor used for output when deepks_out_labels = 2.
/// @param ucell    Unit cell the atom data is read from.
/// @param atom_out Output: tensor filled in by the callee.
56void prepare_atom(const UnitCell& ucell, torch::Tensor& atom_out);
/// @brief Build the box tensor used for output when deepks_out_labels = 2.
/// @param ucell   Unit cell the box data is read from.
/// @param box_out Output: tensor filled in by the callee.
57void prepare_box(const UnitCell& ucell, torch::Tensor& box_out);
58} // namespace DeePKS_domain
59#endif
60#endif
Definition unitcell.h:17
Definition deepks_basic.h:15
void cal_gevdm(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &gevdm)
Definition deepks_basic.cpp:16
void load_model(const std::string &model_file, torch::jit::script::Module &model)
Definition deepks_basic.cpp:60
void cal_edelta_gedm(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &descriptor, const std::vector< torch::Tensor > &pdm, torch::jit::script::Module &model_deepks, double **gedm, double &E_delta)
calculate the energy correction E_delta and its partial derivative with respect to pdm (gedm)
Definition deepks_basic.cpp:171
void check_gedm(const DeePKS_Param &deepks_param, double **gedm)
Definition deepks_basic.cpp:239
void prepare_box(const UnitCell &ucell, torch::Tensor &box_out)
Definition deepks_basic.cpp:281
void prepare_atom(const UnitCell &ucell, torch::Tensor &atom_out)
Definition deepks_basic.cpp:259
void cal_edelta_gedm_equiv(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &descriptor, double **gedm, double &E_delta, const int rank)
Definition deepks_basic.cpp:136
Definition deepks_param.h:11