ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
para_gemm.h
Go to the documentation of this file.
1#ifndef PARA_GEMM_H
2#define PARA_GEMM_H
5
6#include <vector>
7#ifdef __MPI
8#include "mpi.h"
9#endif
10
11namespace ModuleBase
12{
23template <typename T, typename Device = base_device::DEVICE_CPU>
25{
26 public:
27 PGemmCN();
28 ~PGemmCN();
29
41 void set_dimension(
42#ifdef __MPI
43 MPI_Comm comm_col,
44 MPI_Comm comm_row,
45#endif
46 const int ncolA,
47 const int LDA,
48 const int ncolB,
49 const int LDB,
50 const int nrow,
51 const int LDC,
52 const int mode = 1);
53
58 void multiply(const T alpha, const T* A, const T* B, const T beta, T* C);
59#ifdef __MPI
60 MPI_Comm col_world = MPI_COMM_NULL;
61 MPI_Comm row_world = MPI_COMM_NULL;
62
63 int col_rank = 0;
64 int col_nproc = 1;
65 int row_rank = 0;
66 int row_nproc = 1;
67
68 std::vector<int> colA_loc;
69 int max_colA = 0;
70 std::vector<int> colB_loc;
71 int max_colB = 0;
72
73 std::vector<MPI_Request> requests;
74 std::vector<int> recv_counts;
75 std::vector<int> displs;
76 int size_C_local = 0;
77 int size_C_global = 0;
78 bool gatherC = true;
79 bool divideCrow = false;
80#endif
81 int ncolA = 0;
82 int ncolB = 0;
83 int nrow = 0;
84 int LDA = 0;
85 int LDB = 0;
86 int LDC = 0;
87 private:
89 void multiply_single(const T alpha, const T* A, const T* B, const T beta, T* C);
90#ifdef __MPI
92 void multiply_col(const T alpha, const T* A, const T* B, const T beta, T* C);
94 void multiply_row(const T alpha, const T* A, const T* B, const T beta, T* C);
95#endif
101
102#ifdef __MPI
103 private:
104 std::vector<T> isend_tmp_;
105 std::vector<T> A_tmp_;
106 std::vector<T> B_tmp_;
107 std::vector<T> C_tmp_;
108 std::vector<T> C_global_tmp_;
109 T* C_local_tmp_ = nullptr;
110 T* A_tmp_device_ = nullptr;
111 T* B_tmp_device_ = nullptr;
112#endif
113
114
115};
116} // namespace ModuleBase
117#endif
This class performs the parallel matrix multiplication C = alpha * A^H * B + beta * C. Here,...
Definition para_gemm.h:25
int row_rank
rank in row_world
Definition para_gemm.h:65
int nrow
number of rows of A or B
Definition para_gemm.h:83
std::vector< int > colA_loc
[col_nproc] number of columns of A matrix in each proc
Definition para_gemm.h:68
int col_nproc
number of procs in col_world
Definition para_gemm.h:64
T * B_tmp_device_
temporary memory for B
Definition para_gemm.h:111
T * C_local_tmp_
temporary memory for C_local
Definition para_gemm.h:109
std::vector< T > isend_tmp_
temporary memory for sending data
Definition para_gemm.h:104
~PGemmCN()
Definition para_gemm.cpp:13
std::vector< MPI_Request > requests
MPI requests.
Definition para_gemm.h:73
std::vector< T > A_tmp_
temporary memory for A
Definition para_gemm.h:105
void multiply_row(const T alpha, const T *A, const T *B, const T beta, T *C)
for mode = 3
Definition para_gemm.cpp:323
void multiply(const T alpha, const T *A, const T *B, const T beta, T *C)
calculate C = alpha * A^H * B + beta * C
Definition para_gemm.cpp:147
bool divideCrow
whether to divide C_global into C_local
Definition para_gemm.h:79
int max_colB
maximum number of columns of B matrix in all procs
Definition para_gemm.h:71
bool gatherC
whether to gather C_local into C_global
Definition para_gemm.h:78
MPI_Comm row_world
row communicator world
Definition para_gemm.h:61
void set_dimension(MPI_Comm comm_col, MPI_Comm comm_row, const int ncolA, const int LDA, const int ncolB, const int LDB, const int nrow, const int LDC, const int mode=1)
Set the dimensions of A, B, and C.
Definition para_gemm.cpp:23
std::vector< T > B_tmp_
temporary memory for B
Definition para_gemm.h:106
int col_rank
rank in col_world
Definition para_gemm.h:63
int size_C_global
size of C_global, which is the global C matrix gathered from all procs
Definition para_gemm.h:77
PGemmCN()
Definition para_gemm.cpp:9
MPI_Comm col_world
column communicator world
Definition para_gemm.h:60
int ncolB
number of columns of B, which is a local matrix in each proc
Definition para_gemm.h:82
std::vector< T > C_tmp_
temporary memory for C
Definition para_gemm.h:107
void multiply_col(const T alpha, const T *A, const T *B, const T beta, T *C)
for mode = 1 or 2
Definition para_gemm.cpp:191
T * A_tmp_device_
temporary memory for A
Definition para_gemm.h:110
void multiply_single(const T alpha, const T *A, const T *B, const T beta, T *C)
for col_nproc == 1
Definition para_gemm.cpp:171
int max_colA
maximum number of columns of A matrix in all procs
Definition para_gemm.h:69
int LDC
leading dimension of C, which can be C_local or C_global
Definition para_gemm.h:86
std::vector< int > recv_counts
receive counts for gathering C_local to C_global
Definition para_gemm.h:74
int LDA
leading dimension of A in each proc
Definition para_gemm.h:84
std::vector< int > displs
displacements for gathering C_local to C_global
Definition para_gemm.h:75
std::vector< int > colB_loc
[col_nproc] number of columns of B matrix in each proc
Definition para_gemm.h:70
std::vector< T > C_global_tmp_
temporary memory for C_global
Definition para_gemm.h:108
int ncolA
number of columns of A, which is a local matrix in each proc
Definition para_gemm.h:81
int LDB
leading dimension of B in each proc
Definition para_gemm.h:85
int row_nproc
number of procs in row_world
Definition para_gemm.h:66
int size_C_local
size of C_local, which is a local matrix in each proc
Definition para_gemm.h:76
#define T
Definition exp.cpp:237
#define __MPI
Definition array_pool.h:6
base device SOURCES math_dngvd_test cpp endif() if(ENABLE_GOOGLEBENCH) AddTest(TARGET PERF_MODULE_HSOLVER_KERNELS LIBS parameter $
Definition CMakeLists.txt:10
Definition memory_op.h:77
Definition memory_op.h:17