ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
deepks_fpre.h
Go to the documentation of this file.
1#ifndef DEEPKS_FPRE_H
2#define DEEPKS_FPRE_H
3
4#ifdef __MLALGO
5
6#include "deepks_param.h"
10#include "source_base/timer.h"
15
16#include <torch/script.h>
17#include <torch/torch.h>
18
19namespace DeePKS_domain
20{
21//------------------------
22// deepks_fpre.cpp
23//------------------------
24
25// This file contains 2 subroutines for calculating,
26// 1. cal_gdmx, calculating gdmx
27// 2. cal_gvx : gvx is used for training with force label, which is gradient of descriptors,
28// calculated by d(des)/dX = d(pdm)/dX * d(des)/d(pdm) = gdmx * gvdm
29// using einsum
30
31// calculate the gradient of pdm with regard to atomic positions
32// d/dX D_{Inl,mm'}
33template <typename TK>
34void cal_gdmx(const int nks,
35 const DeePKS_Param& deepks_param,
36 const std::vector<ModuleBase::Vector3<double>>& kvec_d,
37 std::vector<hamilt::HContainer<double>*> phialpha,
39 const UnitCell& ucell,
40 const LCAO_Orbitals& orb,
41 const Parallel_Orbitals& pv,
42 const Grid_Driver& GridD,
43 torch::Tensor& gdmx);
44
54void cal_gvx(const int nat,
55 const DeePKS_Param& deepks_param,
56 const std::vector<torch::Tensor>& gevdm,
57 const torch::Tensor& gdmx,
58 torch::Tensor& gvx,
59 const int rank);
60
61} // namespace DeePKS_domain
62#endif
63#endif
Definition sltk_grid_driver.h:43
Definition ORB_read.h:19
3-element vector
Definition vector3.h:22
Definition parallel_orbitals.h:9
Definition unitcell.h:17
Definition hcontainer.h:144
Definition deepks_basic.h:15
void cal_gvx(const int nat, const DeePKS_Param &deepks_param, const std::vector< torch::Tensor > &gevdm, const torch::Tensor &gdmx, torch::Tensor &gvx, const int rank)
Definition deepks_fpre.cpp:144
void cal_gdmx(const int nks, const DeePKS_Param &deepks_param, const std::vector< ModuleBase::Vector3< double > > &kvec_d, std::vector< hamilt::HContainer< double > * > phialpha, const hamilt::HContainer< double > *dmr, const UnitCell &ucell, const LCAO_Orbitals &orb, const Parallel_Orbitals &pv, const Grid_Driver &GridD, torch::Tensor &gdmx)
Definition deepks_fpre.cpp:17
Definition deepks_param.h:11