ABACUS develop: Atomic-orbital Based Ab-initio Computation at UStc
parallel_device.h
#ifndef __PARALLEL_DEVICE_H__
#define __PARALLEL_DEVICE_H__
#ifdef __MPI
#include "mpi.h"
#include <complex>
namespace Parallel_Common
{
void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
void isend_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
void isend_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm);
void send_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm);
void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm);
void send_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm);
void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void recv_data(std::complex<double>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
void reduce_data(float* object, const int& n, const MPI_Comm& comm);
void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
void gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::complex<double>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);

#ifndef __CUDA_MPI
/// Helper that stages a Device buffer in host (CPU) memory when MPI is not CUDA-aware.
template <typename T, typename Device>
struct object_cpu_point
{
    bool alloc = false;
    T* get(const T* object, const int& n, T* tmp_space = nullptr);
    void del(T* object);
    void sync_d2h(T* object_cpu, const T* object, const int& n);
    void sync_h2d(T* object, const T* object_cpu, const int& n);
};
#endif

/**
 * @brief send data in Device
 */
template <typename T, typename Device>
void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr)
{
#ifdef __CUDA_MPI
    send_data(object, count, dest, tag, comm);
#else
    object_cpu_point<T, Device> o;
    T* object_cpu = o.get(object, count, tmp_space);
    o.sync_d2h(object_cpu, object, count);
    send_data(object_cpu, count, dest, tag, comm);
    o.del(object_cpu);
#endif
    return;
}

/**
 * @brief isend data in Device
 */
template <typename T, typename Device>
void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space)
{
#ifdef __CUDA_MPI
    isend_data(object, count, dest, tag, comm, request);
#else
    object_cpu_point<T, Device> o;
    T* object_cpu = o.get(object, count, send_space);
    o.sync_d2h(object_cpu, object, count);
    isend_data(object_cpu, count, dest, tag, comm, request);
    o.del(object_cpu);
#endif
    return;
}

/**
 * @brief recv data in Device
 */
template <typename T, typename Device>
void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status, T* tmp_space = nullptr)
{
#ifdef __CUDA_MPI
    recv_data(object, count, source, tag, comm, status);
#else
    object_cpu_point<T, Device> o;
    T* object_cpu = o.get(object, count, tmp_space);
    recv_data(object_cpu, count, source, tag, comm, status);
    o.sync_h2d(object, object_cpu, count);
    o.del(object_cpu);
#endif
    return;
}

/**
 * @brief bcast data in Device
 */
template <typename T, typename Device>
void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
#ifdef __CUDA_MPI
    bcast_data(object, n, comm);
#else
    object_cpu_point<T, Device> o;
    T* object_cpu = o.get(object, n, tmp_space);
    o.sync_d2h(object_cpu, object, n);
    bcast_data(object_cpu, n, comm);
    o.sync_h2d(object, object_cpu, n);
    o.del(object_cpu);
#endif
    return;
}

template <typename T, typename Device>
void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
{
#ifdef __CUDA_MPI
    reduce_data(object, n, comm);
#else
    object_cpu_point<T, Device> o;
    T* object_cpu = o.get(object, n, tmp_space);
    o.sync_d2h(object_cpu, object, n);
    reduce_data(object_cpu, n, comm);
    o.sync_h2d(object, object_cpu, n);
    o.del(object_cpu);
#endif
    return;
}

template <typename T, typename Device>
void gatherv_dev(const T* sendbuf,
                 int sendcount,
                 T* recvbuf,
                 const int* recvcounts,
                 const int* displs,
                 MPI_Comm& comm,
                 T* tmp_sspace = nullptr,
                 T* tmp_rspace = nullptr)
{
#ifdef __CUDA_MPI
    gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
#else
    object_cpu_point<T, Device> o1, o2;
    int size = 0;
    MPI_Comm_size(comm, &size);
    int gather_space = displs[size - 1] + recvcounts[size - 1];
    T* sendbuf_cpu = o1.get(sendbuf, sendcount, tmp_sspace);
    T* recvbuf_cpu = o2.get(recvbuf, gather_space, tmp_rspace);
    o1.sync_d2h(sendbuf_cpu, sendbuf, sendcount);
    gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm);
    o2.sync_h2d(recvbuf, recvbuf_cpu, gather_space);
    o1.del(sendbuf_cpu);
    o2.del(recvbuf_cpu);
#endif
    return;
}

} // namespace Parallel_Common

#endif
#endif
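
How the wrappers behave: when the build defines __CUDA_MPI, each *_dev function hands the Device pointer straight to the corresponding *_data call (CUDA-aware MPI); otherwise the data is staged through a temporary host buffer around the MPI call. Below is a minimal usage sketch, not taken from the repository: the device tag DEVICE_GPU, the caller names, and the way the device buffers were allocated are assumptions about the surrounding codebase.

// Usage sketch (assumption: DEVICE_GPU is the project's device tag type and the
// d_* pointers reference buffers already resident on the device).
#include <mpi.h>
#include <vector>
// #include "parallel_device.h"   // this header

void broadcast_device_buffer(double* d_buf, int n, MPI_Comm comm)
{
    // Broadcast the buffer across the communicator. With __CUDA_MPI the device
    // pointer goes directly to bcast_data(); otherwise object_cpu_point stages
    // the data through host memory before and after the call.
    Parallel_Common::bcast_dev<double, DEVICE_GPU>(d_buf, n, comm);
}

void gather_device_buffer(const double* d_send, int nlocal, double* d_recv, MPI_Comm comm)
{
    int size = 0;
    MPI_Comm_size(comm, &size);

    // Build recvcounts/displs from each rank's local count.
    std::vector<int> recvcounts(size), displs(size);
    MPI_Allgather(&nlocal, 1, MPI_INT, recvcounts.data(), 1, MPI_INT, comm);
    displs[0] = 0;
    for (int i = 1; i < size; ++i)
    {
        displs[i] = displs[i - 1] + recvcounts[i - 1];
    }

    // d_recv must hold displs[size-1] + recvcounts[size-1] elements; this is the
    // size the host staging path computes for its temporary receive buffer.
    Parallel_Common::gatherv_dev<double, DEVICE_GPU>(d_send, nlocal, d_recv,
                                                     recvcounts.data(), displs.data(), comm);
}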