ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_descriptor.h
Go to the documentation of this file.
#ifndef DEEPKS_DESCRIPTOR_H
#define DEEPKS_DESCRIPTOR_H

#ifdef __MLALGO

#include "deepks_param.h"
#include "source_base/timer.h"

// Standard headers for types that appear directly in the declarations below
// (include-what-you-use; do not rely on transitive includes from libtorch).
#include <string>
#include <vector>

#include <torch/script.h>
#include <torch/torch.h>

// check_descriptor only takes UnitCell by const reference, so a forward
// declaration is sufficient here. The full definition lives in unitcell.h.
// NOTE(review): confirm UnitCell is declared with the `class` key there.
class UnitCell;

namespace DeePKS_domain
{
//------------------------
// deepks_descriptor.cpp
//------------------------

// This file contains interfaces with libtorch,
// including loading of model and calculating gradients,
// as well as subroutines that print the results for checking.
// NOTE(review): none of the functions declared below load a model or compute
// gradients; the sentence above may be stale. Confirm against the .cpp.

// The file contains 3 subroutines:
// 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
//    by calling torch::linalg::eigh
// 2. check_descriptor : prints descriptor for checking
// 3. cal_descriptor_equiv : calculates descriptor in equivalent version

/// @brief Obtain descriptors as the eigenvalues of the pdm tensors
///        (computed with torch::linalg::eigh).
/// @param nat          presumably the number of atoms (TODO confirm against caller)
/// @param deepks_param DeePKS parameter set (read-only)
/// @param pdm          input tensors (read-only)
/// @param descriptor   output tensors, written by this function
void cal_descriptor(const int nat,
                    const DeePKS_Param& deepks_param,
                    const std::vector<torch::Tensor>& pdm,
                    std::vector<torch::Tensor>& descriptor);

/// @brief Print descriptors based on LCAO basis, for checking.
/// @param deepks_param DeePKS parameter set (read-only)
/// @param ucell        unit cell (read-only)
/// @param out_dir      presumably the directory the output is written to (TODO confirm)
/// @param descriptor   descriptors to print (read-only)
/// @param rank         presumably the MPI rank of the caller (TODO confirm)
void check_descriptor(const DeePKS_Param& deepks_param,
                      const UnitCell& ucell,
                      const std::string& out_dir,
                      const std::vector<torch::Tensor>& descriptor,
                      const int rank);

/// @brief Calculate descriptors in the "equivalent" version.
///        NOTE(review): the exact meaning of "equivalent" is defined in the
///        .cpp implementation, not visible here.
/// @param nat          presumably the number of atoms (TODO confirm against caller)
/// @param deepks_param DeePKS parameter set (read-only)
/// @param pdm          input tensors (read-only)
/// @param descriptor   output tensors, written by this function
void cal_descriptor_equiv(const int nat,
                          const DeePKS_Param& deepks_param,
                          const std::vector<torch::Tensor>& pdm,
                          std::vector<torch::Tensor>& descriptor);
} // namespace DeePKS_domain
#endif // __MLALGO
#endif // DEEPKS_DESCRIPTOR_H
Definition unitcell.h:17
Definition deepks_basic.h:15
void cal_descriptor(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &descriptor)
Definition deepks_descriptor.cpp:38
void cal_descriptor_equiv(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &descriptor)
Definition deepks_descriptor.cpp:18
void check_descriptor(const DeePKS_Param &deepks_param, const UnitCell &ucell, const std::string &out_dir, const std::vector< torch::Tensor > &descriptor, const int rank)
print descriptors based on LCAO basis
Definition deepks_descriptor.cpp:73
Definition deepks_param.h:11