ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
gather_mat.h
Go to the documentation of this file.
1#ifndef GATHER_MAT_H
2#define GATHER_MAT_H
3
6
7namespace module_rt
8{
9//------------------------ MPI gathering and distributing functions ------------------------//
// Holder for a full (non-distributed) matrix and the descriptor that goes with it.
// The gather helpers in this file fill it; distributePsi reads it back.
//
// NOTE: `p` and `desc` are shared_ptr<T> / shared_ptr<int> that point at arrays.
// A shared_ptr to a non-array type releases with plain `delete` by default, so
// whoever assigns them must supply an array deleter, e.g.
//     g.p.reset(new T[n], std::default_delete<T[]>());
template <typename T>
struct Matrix_g
{
    std::shared_ptr<T> p;      // Matrix elements (row * col entries when filled by the gather helpers)
    size_t row;                // Number of rows of the stored matrix
    size_t col;                // Number of columns of the stored matrix
    std::shared_ptr<int> desc; // 9-element descriptor passed to Cpxgemr2d
};
19
20#ifdef __MPI
21// Collect matrices from all processes to root process
22template <typename T>
23void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock<T>& mat_l, Matrix_g<T>& mat_g)
24{
25 const int* desca = mat_l.desc; // Obtain the descriptor of the local matrix
26 int ctxt = desca[1]; // BLACS context
27 int nrows = desca[2]; // Global matrix row number
28 int ncols = desca[3]; // Global matrix column number
29
30 if (myid == root_proc)
31 {
32 mat_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
33 }
34 else
35 {
36 mat_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
37 }
38
39 // Set the descriptor of the global matrix
40 mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
41 mat_g.row = nrows;
42 mat_g.col = ncols;
43
44 // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
45 Cpxgemr2d(nrows, ncols, mat_l.p, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
46}
47
48template <typename T>
49void gatherPsi(const int myid,
50 const int root_proc,
51 T* psi_l,
52 const Parallel_Orbitals& para_orb,
54{
55 const int* desc_psi = para_orb.desc_wfc; // Obtain the descriptor from Parallel_Orbitals
56 int ctxt = desc_psi[1]; // BLACS context
57 int nrows = desc_psi[2]; // Global matrix row number
58 int ncols = desc_psi[3]; // Global matrix column number
59
60 if (myid == root_proc)
61 {
62 psi_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
63 }
64 else
65 {
66 psi_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
67 }
68
69 // Set the descriptor of the global psi
70 psi_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
71 psi_g.row = nrows;
72 psi_g.col = ncols;
73
74 // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
75 Cpxgemr2d(nrows, ncols, psi_l, 1, 1, const_cast<int*>(desc_psi), psi_g.p.get(), 1, 1, psi_g.desc.get(), ctxt);
76}
77
78template <typename T>
79void distributePsi(const Parallel_Orbitals& para_orb, T* psi_l, const module_rt::Matrix_g<T>& psi_g)
80{
81 const int* desc_psi = para_orb.desc_wfc; // Obtain the descriptor from Parallel_Orbitals
82 int ctxt = desc_psi[1]; // BLACS context
83 int nrows = desc_psi[2]; // Global matrix row number
84 int ncols = desc_psi[3]; // Global matrix column number
85
86 // Call the Cpxgemr2d function in ScaLAPACK to distribute the matrix data
87 Cpxgemr2d(nrows, ncols, psi_g.p.get(), 1, 1, psi_g.desc.get(), psi_l, 1, 1, const_cast<int*>(desc_psi), ctxt);
88}
89//------------------------ MPI gathering and distributing functions ------------------------//
90
91#endif // __MPI
92} // namespace module_rt
93#endif // GATHER_MAT_H
Definition parallel_orbitals.h:9
int desc_wfc[9]
Definition parallel_orbitals.h:37
#define T
Definition exp.cpp:237
Definition band_energy.cpp:11
void distributePsi(const Parallel_Orbitals &para_orb, T *psi_l, const module_rt::Matrix_g< T > &psi_g)
Definition gather_mat.h:79
void gatherPsi(const int myid, const int root_proc, T *psi_l, const Parallel_Orbitals &para_orb, module_rt::Matrix_g< T > &psi_g)
Definition gather_mat.h:49
void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock< T > &mat_l, Matrix_g< T > &mat_g)
Definition gather_mat.h:23
std::enable_if< block2d_data_type< T >::value, void >::type Cpxgemr2d(int M, int N, T *A, int IA, int JA, int *DESCA, T *B, int IB, int JB, int *DESCB, int ICTXT)
Definition scalapack_connector.h:186
Definition matrixblock.h:9
T * p
Definition matrixblock.h:12
const int * desc
Definition matrixblock.h:15
Definition gather_mat.h:13
size_t col
Definition gather_mat.h:16
size_t row
Definition gather_mat.h:15
std::shared_ptr< T > p
Definition gather_mat.h:14
std::shared_ptr< int > desc
Definition gather_mat.h:17