ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
RI_2D_Comm.hpp
Go to the documentation of this file.
1//=======================
2// AUTHOR : Peize Lin
3// DATE : 2022-08-17
4//=======================
5
6#ifndef RI_2D_COMM_HPP
7#define RI_2D_COMM_HPP
8
9#include "RI_2D_Comm.h"
10#include "RI_Util.h"
13#include "source_base/timer.h"
16#include <RI/global/Global_Func-2.h>
17
18#include <cmath>
19#include <string>
20#include <stdexcept>
21
22inline RI::Tensor<double> tensor_conj(const RI::Tensor<double>& t) { return t; }
23inline RI::Tensor<std::complex<double>> tensor_conj(const RI::Tensor<std::complex<double>>& t)
24{
25 RI::Tensor<std::complex<double>> r(t.shape);
26 for (int i = 0; i < t.data->size(); ++i) {
27 (*r.data)[i] = std::conj((*t.data)[i]);
28 }
29 return r;
30}
31template<typename Tdata, typename Tmatrix>
32auto RI_2D_Comm::split_m2D_ktoR(const UnitCell& ucell,
33 const K_Vectors & kv,
34 const std::vector<const Tmatrix*>&mks_2D,
35 const Parallel_2D & pv,
36 const int nspin,
37 const bool spgsym)
38-> std::vector<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>>
39{
40 ModuleBase::TITLE("RI_2D_Comm","split_m2D_ktoR");
41 ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR");
42
43 const TC period = RI_Util::get_Born_vonKarmen_period(kv);
44 const std::map<int,int> nspin_k = {{1,1}, {2,2}, {4,1}};
45 const double SPIN_multiple = std::map<int, double>{ {1,0.5}, {2,1}, {4,1} }.at(nspin); // why?
46
47 std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> mRs_a2D(nspin);
48 for (int is_k = 0; is_k < nspin_k.at(nspin); ++is_k)
49 {
50 const std::vector<int> ik_list = RI_2D_Comm::get_ik_list(kv, is_k);
51 for(const TC &cell : RI_Util::get_Born_von_Karmen_cells(period))
52 {
53 RI::Tensor<Tdata> mR_2D;
54 int ik_full = 0;
55 for (const int ik : ik_list)
56 {
57 auto set_mR_2D = [&mR_2D](auto&& mk_frac) {
58 if (mR_2D.empty()) {
59 mR_2D = RI::Global_Func::convert<Tdata>(mk_frac);
60 } else {
61 mR_2D
62 = mR_2D + RI::Global_Func::convert<Tdata>(mk_frac);
63 }
64 };
65 using Tdata_m = typename Tmatrix::value_type;
66 if (!spgsym)
67 {
68 RI::Tensor<Tdata_m> mk_2D = RI_Util::Vector_to_Tensor<Tdata_m>(*mks_2D[ik], pv.get_col_size(), pv.get_row_size());
69 const Tdata_m frac = SPIN_multiple
70 * RI::Global_Func::convert<Tdata_m>(std::exp(
71 -ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell) * ucell.latvec))));
72 if (static_cast<int>(std::round(SPIN_multiple * kv.wk[ik] * kv.get_nkstot_full())) == 2)
73 { set_mR_2D(mk_2D * (frac * 0.5) + tensor_conj(mk_2D * (frac * 0.5))); }
74 else { set_mR_2D(mk_2D * frac); }
75 }
76 else
77 { // traverse kstar, ik means ik_ibz
78 for (auto& isym_kvd : kv.kstars[ik % ik_list.size()])
79 {
80 RI::Tensor<Tdata_m> mk_2D = RI_Util::Vector_to_Tensor<Tdata_m>(*mks_2D[ik_full + is_k * kv.get_nkstot_full()], pv.get_col_size(), pv.get_row_size());
81 const Tdata_m frac = SPIN_multiple
82 * RI::Global_Func::convert<Tdata_m>(std::exp(
83 -ModuleBase::TWO_PI * ModuleBase::IMAG_UNIT * ((isym_kvd.second * ucell.G) * (RI_Util::array3_to_Vector3(cell) * ucell.latvec))));
84 set_mR_2D(mk_2D * frac);
85 ++ik_full;
86 }
87 }
88 }
89 for(int iwt0_2D=0; iwt0_2D!=mR_2D.shape[0]; ++iwt0_2D)
90 {
91 const int iwt0 =ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver)
92 ? pv.local2global_col(iwt0_2D)
93 : pv.local2global_row(iwt0_2D);
94 int iat0, iw0_b, is0_b;
95 std::tie(iat0,iw0_b,is0_b) = RI_2D_Comm::get_iat_iw_is_block(ucell,iwt0);
96 const int it0 = ucell.iat2it[iat0];
97 for(int iwt1_2D=0; iwt1_2D!=mR_2D.shape[1]; ++iwt1_2D)
98 {
99 const int iwt1 =ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver)
100 ? pv.local2global_row(iwt1_2D)
101 : pv.local2global_col(iwt1_2D);
102 int iat1, iw1_b, is1_b;
103 std::tie(iat1,iw1_b,is1_b) = RI_2D_Comm::get_iat_iw_is_block(ucell,iwt1);
104 const int it1 = ucell.iat2it[iat1];
105
106 const int is_b = RI_2D_Comm::get_is_block(is_k, is0_b, is1_b);
107 RI::Tensor<Tdata> &mR_a2D = mRs_a2D[is_b][iat0][{iat1,cell}];
108 if (mR_a2D.empty()) {
109 mR_a2D = RI::Tensor<Tdata>(
110 {static_cast<size_t>(ucell.atoms[it0].nw),
111 static_cast<size_t>(
112 ucell.atoms[it1].nw)});
113 }
114 mR_a2D(iw0_b,iw1_b) = mR_2D(iwt0_2D, iwt1_2D);
115 }
116 }
117 }
118 }
119 ModuleBase::timer::tick("RI_2D_Comm", "split_m2D_ktoR");
120 return mRs_a2D;
121}
122
123
124template<typename Tdata, typename TK>
126 const UnitCell &ucell,
127 const K_Vectors &kv,
128 const int ik,
129 const double alpha,
130 const std::vector<std::map<TA,std::map<TAC,RI::Tensor<Tdata>>>> &Hs,
131 const Parallel_Orbitals& pv,
132 TK* hk)
133{
134 ModuleBase::TITLE("RI_2D_Comm","add_Hexx");
135 ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx");
136
137 const std::map<int, std::vector<int>> is_list = {{1,{0}}, {2,{kv.isk[ik]}}, {4,{0,1,2,3}}};
138 for(const int is_b : is_list.at(PARAM.inp.nspin))
139 {
140 int is0_b, is1_b;
141 std::tie(is0_b,is1_b) = RI_2D_Comm::split_is_block(is_b);
142 for(const auto &Hs_tmpA : Hs[is_b])
143 {
144 const TA &iat0 = Hs_tmpA.first;
145 for(const auto &Hs_tmpB : Hs_tmpA.second)
146 {
147 const TA &iat1 = Hs_tmpB.first.first;
148 const TC &cell1 = Hs_tmpB.first.second;
149 const std::complex<double> frac = alpha
150 * std::exp( ModuleBase::TWO_PI*ModuleBase::IMAG_UNIT * (kv.kvec_c[ik] * (RI_Util::array3_to_Vector3(cell1)*ucell.latvec)) );
151 const RI::Tensor<Tdata> &H = Hs_tmpB.second;
152 for(size_t iw0_b=0; iw0_b<H.shape[0]; ++iw0_b)
153 {
154 const int iwt0 = RI_2D_Comm::get_iwt(ucell,iat0, iw0_b, is0_b);
155 if (pv.global2local_row(iwt0) < 0) {
156 continue;
157 }
158 for(size_t iw1_b=0; iw1_b<H.shape[1]; ++iw1_b)
159 {
160 const int iwt1 = RI_2D_Comm::get_iwt(ucell,iat1, iw1_b, is1_b);
161 if (pv.global2local_col(iwt1) < 0) {
162 continue;
163 }
164 LCAO_domain::set_mat2d(iwt0, iwt1, RI::Global_Func::convert<TK>(H(iw0_b, iw1_b)) * RI::Global_Func::convert<TK>(frac), pv, hk);
165 }
166 }
167 }
168 }
169 }
170 ModuleBase::timer::tick("RI_2D_Comm", "add_Hexx");
171}
172
173std::tuple<int,int,int>
174RI_2D_Comm::get_iat_iw_is_block(const UnitCell& ucell,const int& iwt)
175{
176 const int iat = ucell.iwt2iat[iwt];
177 const int iw = ucell.iwt2iw[iwt];
178 switch(PARAM.inp.nspin)
179 {
180 case 1: case 2:
181 return std::make_tuple(iat, iw, 0);
182 case 4:
183 return std::make_tuple(iat, iw/2, iw%2);
184 default:
185 throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
186 }
187}
188
189int RI_2D_Comm::get_is_block(const int is_k, const int is_row_b, const int is_col_b)
190{
191 switch(PARAM.inp.nspin)
192 {
193 case 1: return 0;
194 case 2: return is_k;
195 case 4: return is_row_b*2+is_col_b;
196 default: throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
197 }
198}
199
200std::tuple<int,int>
202{
203 switch(PARAM.inp.nspin)
204 {
205 case 1: case 2:
206 return std::make_tuple(0, 0);
207 case 4:
208 return std::make_tuple(is_b/2, is_b%2);
209 default:
210 throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
211 }
212}
213
214
215
217 const int iat,
218 const int iw_b,
219 const int is_b)
220{
221 const int it = ucell.iat2it[iat];
222 const int ia = ucell.iat2ia[iat];
223 int iw=-1;
224 switch(PARAM.inp.nspin)
225 {
226 case 1: case 2:
227 iw = iw_b; break;
228 case 4:
229 iw = iw_b*2+is_b; break;
230 default:
231 throw std::invalid_argument(std::string(__FILE__)+" line "+std::to_string(__LINE__));
232 }
233 const int iwt = ucell.itiaiw2iwt(it,ia,iw);
234 return iwt;
235}
236
237template<typename Tdata, typename TR>
239 const int current_spin,
240 const double alpha,
241 const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Hs,
242 const Parallel_Orbitals& pv,
243 const int npol,
245 const RI::Cell_Nearest<int, int, 3, double, 3>* const cell_nearest)
246{
247 ModuleBase::TITLE("RI_2D_Comm", "add_HexxR");
248 ModuleBase::timer::tick("RI_2D_Comm", "add_HexxR");
249 const std::map<int, std::vector<int>> is_list = { {1,{0}}, {2,{current_spin}}, {4,{0,1,2,3}} };
250 for (const int is_hs : is_list.at(PARAM.inp.nspin))
251 {
252 int is0_b = 0, is1_b = 0;
253 std::tie(is0_b, is1_b) = RI_2D_Comm::split_is_block(is_hs);
254 for (const auto& Hs_tmpA : Hs[is_hs])
255 {
256 const TA& iat0 = Hs_tmpA.first;
257 for (const auto& Hs_tmpB : Hs_tmpA.second)
258 {
259 const TA& iat1 = Hs_tmpB.first.first;
260 const TC& cell = Hs_tmpB.first.second;
262 (cell_nearest ?
263 cell_nearest->get_cell_nearest_discrete(iat0, iat1, cell)
264 : cell));
265 hamilt::BaseMatrix<TR>* HlocR = hR.find_matrix(iat0, iat1, R.x, R.y, R.z);
266 auto row_indexes = pv.get_indexes_row(iat0);
267 auto col_indexes = pv.get_indexes_col(iat1);
268 const RI::Tensor<Tdata>& HexxR = (Tdata)alpha * Hs_tmpB.second;
269 for (int lw0_b = 0;lw0_b < row_indexes.size();lw0_b += npol) // block
270 {
271 const int& gw0 = row_indexes[lw0_b] / npol;
272 const int& lw0 = (npol == 2) ? (lw0_b + is0_b) : lw0_b;
273 for (int lw1_b = 0;lw1_b < col_indexes.size();lw1_b += npol)
274 {
275 const int& gw1 = col_indexes[lw1_b] / npol;
276 const int& lw1 = (npol == 2) ? (lw1_b + is1_b) : lw1_b;
277 HlocR->add_element(lw0, lw1, RI::Global_Func::convert<TR>(HexxR(gw0, gw1)));
278 }
279 }
280 }
281 }
282 }
283
284 ModuleBase::timer::tick("RI_2D_Comm", "add_HexxR");
285}
286
287#endif
RI::Tensor< double > tensor_conj(const RI::Tensor< double > &t)
Definition RI_2D_Comm.hpp:22
Definition abfs-vector3_order.h:16
Definition klist.h:13
std::vector< int > isk
spin index of each k-point (selects the spin block when nspin == 2)
Definition klist.h:21
std::vector< ModuleBase::Vector3< double > > kvec_c
Definition klist.h:15
T x
Definition vector3.h:24
T y
Definition vector3.h:25
T z
Definition vector3.h:26
static void tick(const std::string &class_name_in, const std::string &name_in)
Call in pairs: the first call sets start_flag to false; the second call calculates the elapsed time duration...
Definition timer.cpp:57
This class packs the basic information of 2D-block-cyclic parallel distribution of an arbitrary matri...
Definition parallel_2d.h:12
int global2local_col(const int igc) const
get the local index of a global index (col)
Definition parallel_2d.h:51
int global2local_row(const int igr) const
get the local index of a global index (row)
Definition parallel_2d.h:45
Definition parallel_orbitals.h:9
std::vector< int > get_indexes_col() const
Definition parallel_orbitals.cpp:154
std::vector< int > get_indexes_row() const
gather global indexes of orbitals in this processor get_indexes_row() : global indexes (~NLOCAL) of r...
Definition parallel_orbitals.cpp:140
const Input_para & inp
Definition parameter.h:26
Definition unitcell.h:16
int *& iat2it
Definition unitcell.h:47
Tiait itiaiw2iwt(const Tiait &it, const Tiait &ia, const Tiait &iw) const
Definition unitcell.h:68
ModuleBase::Matrix3 & latvec
Definition unitcell.h:35
int *& iwt2iw
Definition unitcell.h:50
int *& iat2ia
Definition unitcell.h:48
int *& iwt2iat
Definition unitcell.h:49
Definition base_matrix.h:20
void add_element(int mu, int nu, const T &value)
add a single element to the matrix
Definition base_matrix.h:57
Definition hcontainer.h:144
BaseMatrix< T > * find_matrix(int i, int j, int rx, int ry, int rz)
find BaseMatrix with atom index atom_i and atom_j and R index (rx, ry, rz) This interface can be used...
Definition hcontainer.cpp:261
void set_mat2d(const int &global_ir, const int &global_ic, const T &v, const Parallel_Orbitals &pv, T *mat)
Definition LCAO_set_mat2d.cpp:13
const double TWO_PI
Definition constants.h:21
const std::complex< double > IMAG_UNIT(0.0, 1.0)
void TITLE(const std::string &class_name, const std::string &function_name, const bool disable)
Definition tool_title.cpp:18
std::vector< int > get_ik_list(const K_Vectors &kv, const int is_k)
Definition RI_2D_Comm.cpp:65
int get_is_block(const int is_k, const int is_row_b, const int is_col_b)
Definition RI_2D_Comm.hpp:189
std::vector< std::map< TA, std::map< TAC, RI::Tensor< Tdata > > > > split_m2D_ktoR(const UnitCell &ucell, const K_Vectors &kv, const std::vector< const Tmatrix * > &mks_2D, const Parallel_2D &pv, const int nspin, const bool spgsym=false)
int TA
Definition RI_2D_Comm.h:25
std::array< Tcell, Ndim > TC
Definition RI_2D_Comm.h:28
std::tuple< int, int > split_is_block(const int is_b)
Definition RI_2D_Comm.hpp:201
void add_HexxR(const int current_spin, const double alpha, const std::vector< std::map< TA, std::map< TAC, RI::Tensor< Tdata > > > > &Hs, const Parallel_Orbitals &pv, const int npol, hamilt::HContainer< TR > &HlocR, const RI::Cell_Nearest< int, int, 3, double, 3 > *const cell_nearest=nullptr)
Definition RI_2D_Comm.hpp:238
std::tuple< int, int, int > get_iat_iw_is_block(const UnitCell &ucell, const int &iwt)
Definition RI_2D_Comm.hpp:174
std::pair< TA, TC > TAC
Definition RI_2D_Comm.h:29
void add_Hexx(const UnitCell &ucell, const K_Vectors &kv, const int ik, const double alpha, const std::vector< std::map< TA, std::map< TAC, RI::Tensor< Tdata > > > > &Hs, const Parallel_Orbitals &pv, TK *hk)
Definition RI_2D_Comm.hpp:125
int get_iwt(const UnitCell &ucell, const int iat, const int iw_b, const int is_b)
Definition RI_2D_Comm.hpp:216
Definition RI_Util.h:22
std::array< int, 3 > get_Born_vonKarmen_period(const K_Vectors &kv)
Definition RI_Util.hpp:16
std::vector< std::array< Tcell, Ndim > > get_Born_von_Karmen_cells(const std::array< Tcell, Ndim > &Born_von_Karman_period)
Definition RI_Util.hpp:34
ModuleBase::Vector3< Tcell > array3_to_Vector3(const std::array< Tcell, 3 > &v)
Definition RI_Util.h:38
Parameter PARAM
Definition parameter.cpp:3
std::array< int, 3 > TC
Definition ri_cv_io_test.cpp:9
std::string ks_solver
xiaohui add 2013-09-01
Definition input_parameter.h:73
int nspin
1: LDA (spin-unpolarized); 2: LSDA (collinear spin-polarized); 4: non-collinear spin.
Definition input_parameter.h:84