ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_basic.h
Go to the documentation of this file.
1#ifndef DEEPKS_BASIC_H
2#define DEEPKS_BASIC_H
3
4#ifdef __MLALGO
5#include "deepks_param.h"
7
8#include <torch/script.h>
9#include <torch/torch.h>
10
12{
13//------------------------
14// deepks_basic.cpp
15//------------------------
16
17// The file contains 7 subroutines:
18// 1. load_model : loads model for applying V_delta
19// 2. cal_gevdm : d(des)/d(pdm), calculated using torch::autograd::grad
20// 3. cal_edelta_gedm : calculates E_delta and d(E_delta)/d(pdm)
21// this is the term V(D) that enters the expression V_delta = |alpha>V(D)<alpha|
22// calculated using torch::autograd::grad
23// 4. check_gedm : prints gedm for checking
24// 5. cal_edelta_gedm_equiv : calculates E_delta and d(E_delta)/d(pdm) for equivariant version
25// 6. prepare_atom : prepares atom tensor for output as deepks_out_labels = 2
26// 7. prepare_box : prepares box tensor for output as deepks_out_labels = 2
27
28// load the trained neural network models
29void load_model(const std::string& model_file, torch::jit::script::Module& model);
30
31// calculate gevdm
32void cal_gevdm(const int nat,
33 const DeePKS_Param& deepks_param,
34 const std::vector<torch::Tensor>& pdm,
35 std::vector<torch::Tensor>& gevdm);
36
38void cal_edelta_gedm(const int nat,
39 const DeePKS_Param& deepks_param,
40 const std::vector<torch::Tensor>& descriptor,
41 const std::vector<torch::Tensor>& pdm,
42 torch::jit::script::Module& model_deepks,
43 double** gedm,
44 double& E_delta);
45void check_gedm(const DeePKS_Param& deepks_param, double** gedm);
46void cal_edelta_gedm_equiv(const int nat,
47 const DeePKS_Param& deepks_param,
48 const std::vector<torch::Tensor>& descriptor,
49 torch::jit::script::Module& model_deepks,
50 double** gedm,
51 double& E_delta,
52 const int rank);
53
54void prepare_atom(const UnitCell& ucell, torch::Tensor& atom_out);
55void prepare_box(const UnitCell& ucell, torch::Tensor& box_out);
56} // namespace DeePKS_domain
57#endif
58#endif
Definition unitcell.h:15
Definition deepks_basic.h:12
void cal_gevdm(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &gevdm)
Definition deepks_basic.cpp:20
void load_model(const std::string &model_file, torch::jit::script::Module &model)
Definition deepks_basic.cpp:64
void cal_edelta_gedm(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &descriptor, const std::vector< torch::Tensor > &pdm, torch::jit::script::Module &model_deepks, double **gedm, double &E_delta)
calculate the partial derivative of the energy correction with respect to the descriptors
Definition deepks_basic.cpp:218
void cal_edelta_gedm_equiv(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &descriptor, torch::jit::script::Module &model_deepks, double **gedm, double &E_delta, const int rank)
Definition deepks_basic.cpp:92
void check_gedm(const DeePKS_Param &deepks_param, double **gedm)
Definition deepks_basic.cpp:278
void prepare_box(const UnitCell &ucell, torch::Tensor &box_out)
Definition deepks_basic.cpp:320
void prepare_atom(const UnitCell &ucell, torch::Tensor &atom_out)
Definition deepks_basic.cpp:298
Definition deepks_param.h:11