ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
gather_mat.h
Go to the documentation of this file.
1#ifndef GATHER_MAT_H
2#define GATHER_MAT_H
#include <memory>
#include <stdexcept>

// NOTE(review): project include paths were stripped by the documentation
// extraction — the headers below are named in the Doxygen cross-references
// (matrixblock.h, parallel_orbitals.h, scalapack_connector.h); confirm the
// exact repository-relative paths before merging.
#include "matrixblock.h"
#include "parallel_orbitals.h"
#include "scalapack_connector.h"
7namespace module_rt
8{
9//------------------------ MPI gathering and distributing functions ------------------------//
10// This struct is used for collecting matrices from all processes to root process
template <typename T>
struct Matrix_g
{
    std::shared_ptr<T> p;      // global matrix data, length row * col; install with an
                               // array deleter (std::default_delete<T[]>) since the
                               // buffer comes from new T[] — plain reset() would call
                               // delete instead of delete[]
    size_t row;                // global number of rows
    size_t col;                // global number of columns
    std::shared_ptr<int> desc; // 9-element ScaLAPACK/BLACS descriptor of the global
                               // matrix; same array-deleter caveat applies
};
19
20#ifdef __MPI
21// Collect matrices from all processes to root process
22template <typename T>
23void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock<T>& mat_l, Matrix_g<T>& mat_g)
24{
25 const int* desca = mat_l.desc; // Obtain the descriptor of the local matrix
26 int ctxt = desca[1]; // BLACS context
27 int nrows = desca[2]; // Global matrix row number
28 int ncols = desca[3]; // Global matrix column number
29
30 if (myid == root_proc)
31 {
32 mat_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
33 }
34 else
35 {
36 mat_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
37 }
38
39 // Set the descriptor of the global matrix
40 mat_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
41 mat_g.row = nrows;
42 mat_g.col = ncols;
43
44 // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
45 Cpxgemr2d(nrows, ncols, mat_l.p, 1, 1, const_cast<int*>(desca), mat_g.p.get(), 1, 1, mat_g.desc.get(), ctxt);
46}
47
48template <typename T>
50{
51 const int* desc_local = mat_l.desc; // Obtain the descriptor from Parallel_Orbitals
52 int ctxt = desc_local[1]; // BLACS context
53 int nrows = desc_local[2]; // Global matrix row number
54 int ncols = desc_local[3]; // Global matrix column number
55
56 // Check matrix size consistency
57 if (mat_g.row != static_cast<size_t>(nrows) || mat_g.col != static_cast<size_t>(ncols))
58 {
59 throw std::invalid_argument("module_rt::distributeMatrix: Global matrix size mismatch.");
60 }
61
62 // Call the Cpxgemr2d function in ScaLAPACK to distribute the matrix data
63 Cpxgemr2d(nrows, ncols, mat_g.p.get(), 1, 1, mat_g.desc.get(), mat_l.p, 1, 1, const_cast<int*>(desc_local), ctxt);
64}
65
66template <typename T>
67void gatherPsi(const int myid,
68 const int root_proc,
69 T* psi_l,
70 const Parallel_Orbitals& para_orb,
72{
73 const int* desc_psi = para_orb.desc_wfc; // Obtain the descriptor from Parallel_Orbitals
74 int ctxt = desc_psi[1]; // BLACS context
75 int nrows = desc_psi[2]; // Global matrix row number
76 int ncols = desc_psi[3]; // Global matrix column number
77
78 if (myid == root_proc)
79 {
80 psi_g.p.reset(new T[nrows * ncols]); // No need to delete[] since it is a shared_ptr
81 }
82 else
83 {
84 psi_g.p.reset(new T[nrows * ncols]); // Placeholder for non-root processes
85 }
86
87 // Set the descriptor of the global psi
88 psi_g.desc.reset(new int[9]{1, ctxt, nrows, ncols, nrows, ncols, 0, 0, nrows});
89 psi_g.row = nrows;
90 psi_g.col = ncols;
91
92 // Call the Cpxgemr2d function in ScaLAPACK to collect the matrix data
93 Cpxgemr2d(nrows, ncols, psi_l, 1, 1, const_cast<int*>(desc_psi), psi_g.p.get(), 1, 1, psi_g.desc.get(), ctxt);
94}
95
96template <typename T>
97void distributePsi(const Parallel_Orbitals& para_orb, T* psi_l, const module_rt::Matrix_g<T>& psi_g)
98{
99 const int* desc_psi = para_orb.desc_wfc; // Obtain the descriptor from Parallel_Orbitals
100 int ctxt = desc_psi[1]; // BLACS context
101 int nrows = desc_psi[2]; // Global matrix row number
102 int ncols = desc_psi[3]; // Global matrix column number
103
104 // Call the Cpxgemr2d function in ScaLAPACK to distribute the matrix data
105 Cpxgemr2d(nrows, ncols, psi_g.p.get(), 1, 1, psi_g.desc.get(), psi_l, 1, 1, const_cast<int*>(desc_psi), ctxt);
106}
107//------------------------ MPI gathering and distributing functions ------------------------//
108
109#endif // __MPI
110} // namespace module_rt
111#endif // GATHER_MAT_H
Definition parallel_orbitals.h:9
int desc_wfc[9]
Definition parallel_orbitals.h:37
#define T
Definition exp.cpp:237
Definition band_energy.cpp:11
void distributePsi(const Parallel_Orbitals &para_orb, T *psi_l, const module_rt::Matrix_g< T > &psi_g)
Definition gather_mat.h:97
void gatherPsi(const int myid, const int root_proc, T *psi_l, const Parallel_Orbitals &para_orb, module_rt::Matrix_g< T > &psi_g)
Definition gather_mat.h:67
void distributeMatrix(hamilt::MatrixBlock< T > &mat_l, const module_rt::Matrix_g< T > &mat_g)
Definition gather_mat.h:49
void gatherMatrix(const int myid, const int root_proc, const hamilt::MatrixBlock< T > &mat_l, Matrix_g< T > &mat_g)
Definition gather_mat.h:23
std::enable_if< block2d_data_type< T >::value, void >::type Cpxgemr2d(int M, int N, T *A, int IA, int JA, int *DESCA, T *B, int IB, int JB, int *DESCB, int ICTXT)
Definition scalapack_connector.h:186
Definition matrixblock.h:9
T * p
Definition matrixblock.h:12
const int * desc
Definition matrixblock.h:15
Definition gather_mat.h:13
size_t col
Definition gather_mat.h:16
size_t row
Definition gather_mat.h:15
std::shared_ptr< T > p
Definition gather_mat.h:14
std::shared_ptr< int > desc
Definition gather_mat.h:17