ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
Public Member Functions | Public Attributes | Protected Member Functions | Protected Attributes | List of all members
ML_Base Class Reference

#include <ml_base.h>

Inheritance diagram for ML_Base:
Collaboration diagram for ML_Base:

Public Member Functions

 ML_Base ()
 
 ~ML_Base ()
 
void set_device (std::string device_inpt)
 
void loadVector (std::string filename, std::vector< double > &data)
 
void dumpVector (std::string filename, const std::vector< double > &data)
 
void dumpTensor (std::string filename, const torch::Tensor &data)
 
void dumpMatrix (std::string filename, const ModuleBase::matrix &data)
 
torch::Tensor get_data (std::string parameter, const int ikernel) const
 

Public Attributes

int nx_tot = 0
 

Protected Member Functions

void updateInput (const double *const *prho, const ModulePW::PW_Basis *pw_rho)
 
void NN_forward (const double *const *prho, const ModulePW::PW_Basis *pw_rho, bool cal_grad)
 
void get_potential_ (const double *const *prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential)
 
double potGammaTerm (int ir)
 
double potPTerm1 (int ir)
 
double potQTerm1 (int ir)
 
double potXiTerm1 (int ir)
 
double potTanhxiTerm1 (int ir)
 
double potTanhpTerm1 (int ir)
 
double potTanhqTerm1 (int ir)
 
void potGammanlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rGammanlTerm)
 
void potXinlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rXinlTerm)
 
void potTanhxinlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhxinlTerm)
 
void potTanhxi_nlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhxi_nlTerm)
 
void potPPnlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rPPnlTerm)
 
void potQQnlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rQQnlTerm)
 
void potTanhpTanh_pnlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhpTanh_pnlTerm)
 
void potTanhqTanh_qnlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhqTanh_qnlTerm)
 
void potTanhpTanhp_nlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhpTanhp_nlTerm)
 
void potTanhqTanhq_nlTerm (const double *const *prho, const std::vector< double > &tau_lda, const ModulePW::PW_Basis *pw_rho, std::vector< double > &rTanhqTanhq_nlTerm)
 

Protected Attributes

ModuleIO::Cal_MLKEDF_Descriptors * cal_tool = nullptr
 
int nx = 0
 
double dV = 0.
 
double pqcoef = 1.0 / (4.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0))
 
double feg_net_F = 0.0
 
double feg3_correct = 0.541324854612918
 
double energy_prefactor = 0.0
 
double energy_exponent = 0.0
 
int ninput = 0
 
std::vector< double > gamma
 
std::vector< double > p
 
std::vector< double > q
 
std::vector< std::vector< double > > gammanl
 
std::vector< std::vector< double > > pnl
 
std::vector< std::vector< double > > qnl
 
std::vector< std::vector< double > > nablaRho
 
std::vector< double > chi_xi
 
double chi_p = 1.0
 
double chi_q = 1.0
 
std::vector< std::vector< double > > xi
 
std::vector< std::vector< double > > tanhxi
 
std::vector< std::vector< double > > tanhxi_nl
 
std::vector< double > tanhp
 
std::vector< double > tanhq
 
std::vector< double > chi_pnl
 
std::vector< double > chi_qnl
 
std::vector< std::vector< double > > tanh_pnl
 
std::vector< std::vector< double > > tanh_qnl
 
std::vector< std::vector< double > > tanhp_nl
 
std::vector< std::vector< double > > tanhq_nl
 
torch::DeviceType device_type = torch::kCPU
 
torch::Device device = torch::Device(torch::kCPU)
 
torch::Device device_CPU = torch::Device(torch::kCPU)
 
std::shared_ptr< NN_OFImpl > nn
 
double * enhancement_cpu_ptr = nullptr
 
double * gradient_cpu_ptr = nullptr
 
int nkernel = 1
 
bool ml_gamma = false
 
bool ml_p = false
 
bool ml_q = false
 
bool ml_tanhp = false
 
bool ml_tanhq = false
 
bool ml_gammanl = false
 
bool ml_pnl = false
 
bool ml_qnl = false
 
bool ml_xi = false
 
bool ml_tanhxi = false
 
bool ml_tanhxi_nl = false
 
bool ml_tanh_pnl = false
 
bool ml_tanh_qnl = false
 
bool ml_tanhp_nl = false
 
bool ml_tanhq_nl = false
 
std::vector< std::string > descriptor_type
 
std::vector< int > kernel_index
 
std::map< std::string, std::vector< int > > descriptor2kernel
 
std::map< std::string, std::vector< int > > descriptor2index
 
std::map< std::string, std::vector< bool > > gene_data_label
 

Constructor & Destructor Documentation

◆ ML_Base()

ML_Base::ML_Base ( )

◆ ~ML_Base()

ML_Base::~ML_Base ( )

Member Function Documentation

◆ dumpMatrix()

void ML_Base::dumpMatrix ( std::string  filename,
const ModuleBase::matrix &  data 
)
Here is the call graph for this function:
Here is the caller graph for this function:

◆ dumpTensor()

void ML_Base::dumpTensor ( std::string  filename,
const torch::Tensor &  data 
)
Here is the call graph for this function:
Here is the caller graph for this function:

◆ dumpVector()

void ML_Base::dumpVector ( std::string  filename,
const std::vector< double > &  data 
)
Here is the caller graph for this function:

◆ get_data()

torch::Tensor ML_Base::get_data ( std::string  parameter,
const int  ikernel 
) const

◆ get_potential_()

void ML_Base::get_potential_ ( const double *const *  prho,
const ModulePW::PW_Basis *  pw_rho,
ModuleBase::matrix &  rpotential 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ loadVector()

void ML_Base::loadVector ( std::string  filename,
std::vector< double > &  data 
)
Here is the caller graph for this function:

◆ NN_forward()

void ML_Base::NN_forward ( const double *const *  prho,
const ModulePW::PW_Basis *  pw_rho,
bool  cal_grad 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potGammanlTerm()

void ML_Base::potGammanlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rGammanlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potGammaTerm()

double ML_Base::potGammaTerm ( int  ir)
protected
Here is the caller graph for this function:

◆ potPPnlTerm()

void ML_Base::potPPnlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rPPnlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potPTerm1()

double ML_Base::potPTerm1 ( int  ir)
protected
Here is the caller graph for this function:

◆ potQQnlTerm()

void ML_Base::potQQnlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rQQnlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potQTerm1()

double ML_Base::potQTerm1 ( int  ir)
protected
Here is the caller graph for this function:

◆ potTanhpTanh_pnlTerm()

void ML_Base::potTanhpTanh_pnlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhpTanh_pnlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhpTanhp_nlTerm()

void ML_Base::potTanhpTanhp_nlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhpTanhp_nlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhpTerm1()

double ML_Base::potTanhpTerm1 ( int  ir)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhqTanh_qnlTerm()

void ML_Base::potTanhqTanh_qnlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhqTanh_qnlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhqTanhq_nlTerm()

void ML_Base::potTanhqTanhq_nlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhqTanhq_nlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhqTerm1()

double ML_Base::potTanhqTerm1 ( int  ir)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhxi_nlTerm()

void ML_Base::potTanhxi_nlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhxi_nlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhxinlTerm()

void ML_Base::potTanhxinlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rTanhxinlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potTanhxiTerm1()

double ML_Base::potTanhxiTerm1 ( int  ir)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potXinlTerm()

void ML_Base::potXinlTerm ( const double *const *  prho,
const std::vector< double > &  tau_lda,
const ModulePW::PW_Basis *  pw_rho,
std::vector< double > &  rXinlTerm 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

◆ potXiTerm1()

double ML_Base::potXiTerm1 ( int  ir)
protected
Here is the caller graph for this function:

◆ set_device()

void ML_Base::set_device ( std::string  device_inpt)
Here is the caller graph for this function:

◆ updateInput()

void ML_Base::updateInput ( const double *const *  prho,
const ModulePW::PW_Basis *  pw_rho 
)
protected
Here is the call graph for this function:
Here is the caller graph for this function:

Member Data Documentation

◆ cal_tool

ModuleIO::Cal_MLKEDF_Descriptors* ML_Base::cal_tool = nullptr
protected

◆ chi_p

double ML_Base::chi_p = 1.0
protected

◆ chi_pnl

std::vector<double> ML_Base::chi_pnl
protected

◆ chi_q

double ML_Base::chi_q = 1.0
protected

◆ chi_qnl

std::vector<double> ML_Base::chi_qnl
protected

◆ chi_xi

std::vector<double> ML_Base::chi_xi
protected

◆ descriptor2index

std::map<std::string, std::vector<int> > ML_Base::descriptor2index
protected

◆ descriptor2kernel

std::map<std::string, std::vector<int> > ML_Base::descriptor2kernel
protected

◆ descriptor_type

std::vector<std::string> ML_Base::descriptor_type
protected

◆ device

torch::Device ML_Base::device = torch::Device(torch::kCPU)
protected

◆ device_CPU

torch::Device ML_Base::device_CPU = torch::Device(torch::kCPU)
protected

◆ device_type

torch::DeviceType ML_Base::device_type = torch::kCPU
protected

◆ dV

double ML_Base::dV = 0.
protected

◆ energy_exponent

double ML_Base::energy_exponent = 0.0
protected

◆ energy_prefactor

double ML_Base::energy_prefactor = 0.0
protected

◆ enhancement_cpu_ptr

double* ML_Base::enhancement_cpu_ptr = nullptr
protected

◆ feg3_correct

double ML_Base::feg3_correct = 0.541324854612918
protected

◆ feg_net_F

double ML_Base::feg_net_F = 0.0
protected

◆ gamma

std::vector<double> ML_Base::gamma
protected

◆ gammanl

std::vector<std::vector<double> > ML_Base::gammanl
protected

◆ gene_data_label

std::map<std::string, std::vector<bool> > ML_Base::gene_data_label
protected

◆ gradient_cpu_ptr

double* ML_Base::gradient_cpu_ptr = nullptr
protected

◆ kernel_index

std::vector<int> ML_Base::kernel_index
protected

◆ ml_gamma

bool ML_Base::ml_gamma = false
protected

◆ ml_gammanl

bool ML_Base::ml_gammanl = false
protected

◆ ml_p

bool ML_Base::ml_p = false
protected

◆ ml_pnl

bool ML_Base::ml_pnl = false
protected

◆ ml_q

bool ML_Base::ml_q = false
protected

◆ ml_qnl

bool ML_Base::ml_qnl = false
protected

◆ ml_tanh_pnl

bool ML_Base::ml_tanh_pnl = false
protected

◆ ml_tanh_qnl

bool ML_Base::ml_tanh_qnl = false
protected

◆ ml_tanhp

bool ML_Base::ml_tanhp = false
protected

◆ ml_tanhp_nl

bool ML_Base::ml_tanhp_nl = false
protected

◆ ml_tanhq

bool ML_Base::ml_tanhq = false
protected

◆ ml_tanhq_nl

bool ML_Base::ml_tanhq_nl = false
protected

◆ ml_tanhxi

bool ML_Base::ml_tanhxi = false
protected

◆ ml_tanhxi_nl

bool ML_Base::ml_tanhxi_nl = false
protected

◆ ml_xi

bool ML_Base::ml_xi = false
protected

◆ nablaRho

std::vector<std::vector<double> > ML_Base::nablaRho
protected

◆ ninput

int ML_Base::ninput = 0
protected

◆ nkernel

int ML_Base::nkernel = 1
protected

◆ nn

std::shared_ptr<NN_OFImpl> ML_Base::nn
protected

◆ nx

int ML_Base::nx = 0
protected

◆ nx_tot

int ML_Base::nx_tot = 0

◆ p

std::vector<double> ML_Base::p
protected

◆ pnl

std::vector<std::vector<double> > ML_Base::pnl
protected

◆ pqcoef

double ML_Base::pqcoef = 1.0 / (4.0 * std::pow(3*std::pow(M_PI, 2.0), 2.0/3.0))
protected

◆ q

std::vector<double> ML_Base::q
protected

◆ qnl

std::vector<std::vector<double> > ML_Base::qnl
protected

◆ tanh_pnl

std::vector<std::vector<double> > ML_Base::tanh_pnl
protected

◆ tanh_qnl

std::vector<std::vector<double> > ML_Base::tanh_qnl
protected

◆ tanhp

std::vector<double> ML_Base::tanhp
protected

◆ tanhp_nl

std::vector<std::vector<double> > ML_Base::tanhp_nl
protected

◆ tanhq

std::vector<double> ML_Base::tanhq
protected

◆ tanhq_nl

std::vector<std::vector<double> > ML_Base::tanhq_nl
protected

◆ tanhxi

std::vector<std::vector<double> > ML_Base::tanhxi
protected

◆ tanhxi_nl

std::vector<std::vector<double> > ML_Base::tanhxi_nl
protected

◆ xi

std::vector<std::vector<double> > ML_Base::xi
protected

The documentation for this class was generated from the following files: