ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_descriptor.h
Go to the documentation of this file.
1#ifndef DEEPKS_DESCRIPTOR_H
2#define DEEPKS_DESCRIPTOR_H
3
4#ifdef __MLALGO
5
6#include "deepks_param.h"
8
9#include <torch/script.h>
10#include <torch/torch.h>
11
12namespace DeePKS_domain
13{
14//------------------------
15// deepks_descriptor.cpp
16//------------------------
17
18// This file contains interfaces with libtorch,
19// including loading of model and calculating gradients
20// as well as subroutines that print the results for checking
21
22// The file contains 3 subroutines:
23// 1. cal_descriptor : obtains descriptors which are eigenvalues of pdm
24// by calling torch::linalg::eigh
25// 2. check_descriptor : prints descriptor for checking
26// 3. cal_descriptor_equiv : calculates descriptor in equivalent version
27
// Obtains the descriptors, i.e. the eigenvalues of pdm; according to this
// file's overview comment this is done by calling torch::linalg::eigh.
//   nat          : presumably the number of atoms -- TODO confirm against the .cpp
//   deepks_param : DeePKS parameter set (read-only)
//   pdm          : input tensors (read-only); presumably the projected density
//                  matrices -- TODO confirm
//   descriptor   : passed by non-const reference; presumably receives the
//                  computed descriptors -- confirm against the implementation
30void cal_descriptor(const int nat,
31 const DeePKS_Param& deepks_param,
32 const std::vector<torch::Tensor>& pdm,
33 std::vector<torch::Tensor>& descriptor);
// Prints the descriptors (based on the LCAO basis) for checking.
//   deepks_param : DeePKS parameter set (read-only)
//   ucell        : unit cell description (read-only)
//   out_dir      : presumably the directory the output is written to -- TODO confirm
//   descriptor   : descriptors to print (read-only)
//   rank         : presumably the MPI rank of the calling process -- TODO confirm
35void check_descriptor(const DeePKS_Param& deepks_param,
36 const UnitCell& ucell,
37 const std::string& out_dir,
38 const std::vector<torch::Tensor>& descriptor,
39 const int rank);
40
// Calculates the descriptor in the "equivalent version" (see the overview
// comment of this file). Takes the same parameter list as cal_descriptor.
//   nat          : presumably the number of atoms -- TODO confirm against the .cpp
//   deepks_param : DeePKS parameter set (read-only)
//   pdm          : input tensors (read-only)
//   descriptor   : passed by non-const reference; presumably receives the
//                  computed descriptors -- confirm against the implementation
41void cal_descriptor_equiv(const int nat,
42 const DeePKS_Param& deepks_param,
43 const std::vector<torch::Tensor>& pdm,
44 std::vector<torch::Tensor>& descriptor);
45} // namespace DeePKS_domain
46#endif
47#endif
Definition unitcell.h:15
Definition deepks_basic.h:12
void cal_descriptor(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &descriptor)
Definition deepks_descriptor.cpp:38
void cal_descriptor_equiv(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &pdm, std::vector< torch::Tensor > &descriptor)
Definition deepks_descriptor.cpp:18
void check_descriptor(const DeePKS_Param &deepks_param, const UnitCell &ucell, const std::string &out_dir, const std::vector< torch::Tensor > &descriptor, const int rank)
print descriptors based on LCAO basis
Definition deepks_descriptor.cpp:73
Definition deepks_param.h:11