ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
parallel_device.h
Go to the documentation of this file.
1#ifndef __PARALLEL_DEVICE_H__
2#define __PARALLEL_DEVICE_H__
3#ifdef __MPI
4#include "mpi.h"
5#include <complex>
6#include <type_traits>
// Parallel_Common: MPI communication helpers whose *_dev templates accept buffers that
// live on a `Device` (CPU or GPU) and, when needed, stage them through host memory.
// NOTE(review): this text is a Doxygen source listing, not the compilable header. The
// listing's line numbers are fused onto each code line, and these original lines are
// absent here: 52 (helper struct name), 73, 92, 110, 143, 170, 199 (declarations of the
// staging objects `o` / `o1` / `o2`), plus doc-comment lines 63-66, 81-85, 100-103 and
// 119-129. Restore them from the real header before relying on this code.
8namespace Parallel_Common
9{
// Host-side wrappers, overloaded for double, std::complex<double>, float and
// std::complex<float>. Per this page's cross-references they are defined in
// parallel_device.cpp. NOTE(review): which MPI routine each one wraps (e.g. whether
// reduce_data is an allreduce-sum) is not visible here — check parallel_device.cpp.
10void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
11void isend_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
12void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
13void isend_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
14void send_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm);
15void send_data(const std::complex<double>* buf, int count, int dest, int tag, MPI_Comm& comm);
16void send_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm);
17void send_data(const std::complex<float>* buf, int count, int dest, int tag, MPI_Comm& comm);
18void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
19void recv_data(std::complex<double>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
20void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
21void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
22void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm, int root = 0);
23void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm, int root = 0);
24void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root = 0);
25void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root = 0);
26void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
27void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
28void reduce_data(double* object, const int& n, const MPI_Comm& comm);
29void reduce_data(float* object, const int& n, const MPI_Comm& comm);
30void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
31void gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::complex<double>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
32void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
33void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
34
// NCCL-backed variants, compiled only when __NCCL_PARALLEL_DEVICE is defined. The
// bcast/reduce/gatherv *_dev templates below call them when Device is
// base_device::DEVICE_GPU. Note they take a non-const MPI_Comm&.
35#if defined(__NCCL_PARALLEL_DEVICE)
36void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root = 0);
37void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm, int root = 0);
38void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root = 0);
39void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm, int root = 0);
40void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm);
41void nccl_reduce_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
42void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm);
43void nccl_reduce_data(std::complex<float>* object, const int& n, MPI_Comm& comm);
44void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
45void nccl_gatherv_data(const std::complex<double>* sendbuf, int sendcount, std::complex<double>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
46void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
47void nccl_gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
48#endif
49
// Staging helper used by the *_dev templates when __CUDA_MPI is not defined, i.e. when
// MPI is not handed device pointers directly.
// NOTE(review): the struct's name line (listing line 52) is missing here; presumably
// `struct object_cpu_point` — confirm against the header.
//   alloc      : flag; presumably records whether get()/get_buffer() allocated host
//                storage that del() must free — TODO confirm in parallel_device.cpp.
//   get_buffer : returns host storage for n elements; the templates use it only where
//                that storage is about to be overwritten (recv, non-root bcast, gatherv
//                receive side). tmp_space, if non-null, is offered as that storage.
//   get        : returns host storage for n elements of `object`; used on the sending
//                side, so presumably it holds a copy of the device data.
//   del        : releases what get()/get_buffer() returned.
//   sync_d2h / sync_h2d : copy n elements device->host / host->device.
50#ifndef __CUDA_MPI
51template<typename T, typename Device>
53{
54 bool alloc = false;
55 T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr);
56 T* get(const T* object, const int& n, T* tmp_space = nullptr);
57 void del(T* object);
58 void sync_d2h(T* object_cpu, const T* object, const int& n);
59 void sync_h2d(T* object, const T* object_cpu, const int& n);
60};
61#endif
62
// Blocking send of `count` elements of `object` (memory of `Device`) to rank `dest`
// with `tag` on `comm`.
// With __CUDA_MPI the pointer goes straight to send_data. Otherwise a host view is
// obtained with get() (tmp_space optionally supplies the storage), sent, and released
// with del().
67template <typename T, typename Device>
68void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* tmp_space = nullptr)
69{
70#ifdef __CUDA_MPI
71 send_data(object, count, dest, tag, comm);
72#else
// NOTE(review): listing line 73 (declaration of the staging object `o`) is missing here.
74 T* object_cpu = o.get(object, count, tmp_space);
75 send_data(object_cpu, count, dest, tag, comm);
76 o.del(object_cpu);
77#endif
78 return;
79}
80
// Non-blocking send of `count` elements of `object` to rank `dest`; the MPI request
// handle is written to `*request`. Unlike send_dev, send_space has no default.
// NOTE(review): on the staging path o.del(object_cpu) runs immediately after
// isend_data returns, before the request is known to have completed. MPI forbids
// freeing or reusing a non-blocking send buffer until completion. So if get()
// allocated fresh storage (e.g. send_space == nullptr), this is a use-after-free.
// Callers presumably must pass a send_space that outlives the matching MPI_Wait —
// confirm, and consider deferring del() until after completion.
86template <typename T, typename Device>
87void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* send_space)
88{
89#ifdef __CUDA_MPI
90 isend_data(object, count, dest, tag, comm, request);
91#else
// NOTE(review): listing line 92 (declaration of the staging object `o`) is missing here.
93 T* object_cpu = o.get(object, count, send_space);
94 isend_data(object_cpu, count, dest, tag, comm, request);
95 o.del(object_cpu);
96#endif
97 return;
98}
99
// Blocking receive of `count` elements from rank `source` into `object` (memory of
// `Device`); `status` is passed through to recv_data.
// Staging path: take host storage with get_buffer(), receive into it, copy it to
// `object` with sync_h2d(), then release it with del().
104template <typename T, typename Device>
105void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status, T* tmp_space = nullptr)
106{
107#ifdef __CUDA_MPI
108 recv_data(object, count, source, tag, comm, status);
109#else
// NOTE(review): listing line 110 (declaration of the staging object `o`) is missing here.
111 T* object_cpu = o.get_buffer(object, count, tmp_space);
112 recv_data(object_cpu, count, source, tag, comm, status);
113 o.sync_h2d(object, object_cpu, count);
114 o.del(object_cpu);
115#endif
116 return;
117}
118
// Broadcast `n` elements of `object` from rank `root` to all ranks of `comm`.
// With __NCCL_PARALLEL_DEVICE and Device == base_device::DEVICE_GPU, the call is
// delegated to nccl_bcast_data. That function takes a non-const communicator, hence
// the const_cast. The is_same test is an ordinary runtime `if`, so the NCCL call is
// compiled for every Device.
// Otherwise, with __CUDA_MPI the pointer goes straight to bcast_data. Without it:
//   - root stages its data with get(), while the other ranks only take receive storage
//     via get_buffer();
//   - only non-root ranks copy the result back to `object` with sync_h2d().
130template <typename T, typename Device>
131void bcast_dev(T* object, const int& n, const MPI_Comm& comm, int root = 0, T* tmp_space = nullptr)
132{
133#if defined(__NCCL_PARALLEL_DEVICE)
134 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
135 {
136 nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm), root);
137 return;
138 }
139#endif
140#ifdef __CUDA_MPI
141 bcast_data(object, n, comm, root);
142#else
// NOTE(review): listing line 143 (declaration of the staging object `o`) is missing here.
144 int rank = 0;
145 MPI_Comm_rank(comm, &rank);
146 T* object_cpu = rank == root ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
147 bcast_data(object_cpu, n, comm, root);
148 if (rank != root)
149 {
150 o.sync_h2d(object, object_cpu, n);
151 }
152 o.del(object_cpu);
153#endif
154 return;
155}
156
// Reduce `n` elements of `object` over `comm` via reduce_data (or nccl_reduce_data for
// DEVICE_GPU in an NCCL build). On the staging path every calling rank copies the
// reduced host buffer back into `object` with sync_h2d().
// NOTE(review): the reduction operator and whether every rank receives the result
// (allreduce) are decided inside reduce_data, not shown here.
157template <typename T, typename Device>
158void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
159{
160#if defined(__NCCL_PARALLEL_DEVICE)
161 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
162 {
163 nccl_reduce_data(object, n, const_cast<MPI_Comm&>(comm));
164 return;
165 }
166#endif
167#ifdef __CUDA_MPI
168 reduce_data(object, n, comm);
169#else
// NOTE(review): listing line 170 (declaration of the staging object `o`) is missing here.
171 T* object_cpu = o.get(object, n, tmp_space);
172 reduce_data(object_cpu, n, comm);
173 o.sync_h2d(object, object_cpu, n);
174 o.del(object_cpu);
175#endif
176 return;
177}
178
// Variable-count gather. Each rank contributes `sendcount` elements of `sendbuf`;
// `recvcounts[i]` / `displs[i]` give the count and offset of rank i's block in
// `recvbuf`. tmp_sspace / tmp_rspace optionally supply host storage for the send and
// receive sides of the staging path.
// The staging path sizes the host receive buffer as displs[size-1] + recvcounts[size-1]
// and copies that many elements back into `recvbuf` on every rank.
// NOTE(review): that extent is only right when the last rank's block ends last (e.g.
// non-decreasing displs). MPI permits arbitrary displacements, for which
// max_i(displs[i] + recvcounts[i]) is needed. recvcounts/displs are also dereferenced
// on every rank here — confirm callers provide valid arrays everywhere.
179template <typename T, typename Device>
180void gatherv_dev(const T* sendbuf,
181 int sendcount,
182 T* recvbuf,
183 const int* recvcounts,
184 const int* displs,
185 MPI_Comm& comm,
186 T* tmp_sspace = nullptr,
187 T* tmp_rspace = nullptr)
188{
189#if defined(__NCCL_PARALLEL_DEVICE)
190 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
191 {
192 nccl_gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
193 return;
194 }
195#endif
196#ifdef __CUDA_MPI
197 gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
198#else
// NOTE(review): listing line 199 (declaration of the staging objects `o1`, `o2`) is
// missing here.
200 int size = 0;
201 MPI_Comm_size(comm, &size);
202 int gather_space = displs[size - 1] + recvcounts[size - 1];
203 T* sendbuf_cpu = o1.get(sendbuf, sendcount, tmp_sspace);
204 T* recvbuf_cpu = o2.get_buffer(recvbuf, gather_space, tmp_rspace);
205 gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm);
206 o2.sync_h2d(recvbuf, recvbuf_cpu, gather_space);
207 o1.del(sendbuf_cpu);
208 o2.del(recvbuf_cpu);
209#endif
210 return;
211}
212
213}
214
215
216#endif
217#endif
#define T
Definition exp.cpp:237
Definition parallel_common.h:11
void send_data(const double *buf, int count, int dest, int tag, MPI_Comm &comm)
Definition parallel_device.cpp:273
void bcast_data(std::complex< double > *object, const int &n, const MPI_Comm &comm, int root)
Definition parallel_device.cpp:305
void send_dev(const T *object, int count, int dest, int tag, MPI_Comm &comm, T *tmp_space=nullptr)
send data in Device
Definition parallel_device.h:68
void isend_data(const double *buf, int count, int dest, int tag, MPI_Comm &comm, MPI_Request *request)
Definition parallel_device.cpp:257
void isend_dev(const T *object, int count, int dest, int tag, MPI_Comm &comm, MPI_Request *request, T *send_space)
isend data in Device
Definition parallel_device.h:87
void recv_dev(T *object, int count, int source, int tag, MPI_Comm &comm, MPI_Status *status, T *tmp_space=nullptr)
recv data in Device
Definition parallel_device.h:105
void gatherv_dev(const T *sendbuf, int sendcount, T *recvbuf, const int *recvcounts, const int *displs, MPI_Comm &comm, T *tmp_sspace=nullptr, T *tmp_rspace=nullptr)
Definition parallel_device.h:180
void bcast_dev(T *object, const int &n, const MPI_Comm &comm, int root=0, T *tmp_space=nullptr)
broadcast data in Device
Definition parallel_device.h:131
void gatherv_data(const double *sendbuf, int sendcount, double *recvbuf, const int *recvcounts, const int *displs, MPI_Comm &comm)
Definition parallel_device.cpp:337
void recv_data(double *buf, int count, int source, int tag, MPI_Comm &comm, MPI_Status *status)
Definition parallel_device.cpp:289
void reduce_data(std::complex< double > *object, const int &n, const MPI_Comm &comm)
Definition parallel_device.cpp:321
void reduce_dev(T *object, const int &n, const MPI_Comm &comm, T *tmp_space=nullptr)
Definition parallel_device.h:158
Definition parallel_device.h:53
bool alloc
Definition parallel_device.h:54
T * get_buffer(const T *object, const int &n, T *tmp_space=nullptr)
void sync_h2d(T *object, const T *object_cpu, const int &n)
T * get(const T *object, const int &n, T *tmp_space=nullptr)
void sync_d2h(T *object_cpu, const T *object, const int &n)