1#ifndef __PARALLEL_DEVICE_H__
2#define __PARALLEL_DEVICE_H__
/**
 * @brief Overload set of isend_data for the four supported element types:
 *        double, std::complex<double>, float and std::complex<float>.
 *
 * Only the declarations live here; the definitions are outside this header.
 * NOTE(review): the name and the MPI_Request out-parameter suggest a
 * non-blocking MPI send -- confirm against the definitions.
 *
 * @param buf     read-only buffer to send from
 * @param count   element count (presumably in units of the element type of
 *                buf, not bytes -- TODO confirm)
 * @param dest    destination rank
 * @param tag     message tag
 * @param comm    communicator (taken by non-const reference)
 * @param request request handle filled in for the caller
 */
10void isend_data(
const double* buf,
int count,
int dest,
int tag, MPI_Comm& comm, MPI_Request* request);
11void isend_data(
const std::complex<double>* buf,
int count,
int dest,
int tag, MPI_Comm& comm, MPI_Request* request);
12void isend_data(
const float* buf,
int count,
int dest,
int tag, MPI_Comm& comm, MPI_Request* request);
13void isend_data(
const std::complex<float>* buf,
int count,
int dest,
int tag, MPI_Comm& comm, MPI_Request* request);
/**
 * @brief Overload set of send_data for double, std::complex<double>, float
 *        and std::complex<float> buffers.
 *
 * Same parameter list as isend_data minus the MPI_Request argument.
 * NOTE(review): presumably the blocking counterpart of isend_data -- confirm
 * against the definitions, which are not part of this header.
 *
 * @param buf   read-only buffer to send from
 * @param count element count (presumably elements, not bytes -- TODO confirm)
 * @param dest  destination rank
 * @param tag   message tag
 * @param comm  communicator (taken by non-const reference)
 */
14void send_data(
const double* buf,
int count,
int dest,
int tag, MPI_Comm& comm);
15void send_data(
const std::complex<double>* buf,
int count,
int dest,
int tag, MPI_Comm& comm);
16void send_data(
const float* buf,
int count,
int dest,
int tag, MPI_Comm& comm);
17void send_data(
const std::complex<float>* buf,
int count,
int dest,
int tag, MPI_Comm& comm);
/**
 * @brief Overload set of recv_data for double, std::complex<double>, float
 *        and std::complex<float> buffers.
 *
 * Declarations only; the definitions are outside this header.
 *
 * @param buf    writable buffer that receives the data
 * @param count  element count (presumably elements, not bytes -- TODO confirm)
 * @param source rank to receive from
 * @param tag    message tag
 * @param comm   communicator (taken by non-const reference)
 * @param status status object pointer passed through by the caller
 */
18void recv_data(
double* buf,
int count,
int source,
int tag, MPI_Comm& comm, MPI_Status* status);
19void recv_data(std::complex<double>* buf,
int count,
int source,
int tag, MPI_Comm& comm, MPI_Status* status);
20void recv_data(
float* buf,
int count,
int source,
int tag, MPI_Comm& comm, MPI_Status* status);
21void recv_data(std::complex<float>* buf,
int count,
int source,
int tag, MPI_Comm& comm, MPI_Status* status);
/**
 * @brief Overload set of bcast_data for std::complex<double>,
 *        std::complex<float>, double and float buffers.
 *
 * Unlike the point-to-point helpers, these take the communicator as
 * const MPI_Comm&.
 *
 * @param object buffer operated on in place (non-const pointer)
 * @param n      element count (presumably elements, not bytes -- TODO confirm)
 * @param comm   communicator
 * @param root   root rank; defaults to 0
 */
22void bcast_data(std::complex<double>*
object,
const int& n,
const MPI_Comm& comm,
int root = 0);
23void bcast_data(std::complex<float>*
object,
const int& n,
const MPI_Comm& comm,
int root = 0);
24void bcast_data(
double*
object,
const int& n,
const MPI_Comm& comm,
int root = 0);
25void bcast_data(
float*
object,
const int& n,
const MPI_Comm& comm,
int root = 0);
/**
 * @brief Overload set of reduce_data for std::complex<double>,
 *        std::complex<float>, double and float buffers.
 *
 * There is no root parameter and a single in/out buffer.
 * NOTE(review): that signature looks like an in-place all-reduce; the
 * reduction operator is not visible from this header -- confirm against the
 * definitions.
 *
 * @param object buffer operated on in place (non-const pointer)
 * @param n      element count (presumably elements, not bytes -- TODO confirm)
 * @param comm   communicator
 */
26void reduce_data(std::complex<double>*
object,
const int& n,
const MPI_Comm& comm);
27void reduce_data(std::complex<float>*
object,
const int& n,
const MPI_Comm& comm);
28void reduce_data(
double*
object,
const int& n,
const MPI_Comm& comm);
29void reduce_data(
float*
object,
const int& n,
const MPI_Comm& comm);
/**
 * @brief Overload set of gatherv_data for double, std::complex<double>, float
 *        and std::complex<float> buffers.
 *
 * There is no root parameter.
 * NOTE(review): the absence of a root suggests every rank receives the
 * gathered result (all-gather-v semantics) -- confirm against the definitions.
 *
 * @param sendbuf    read-only local contribution
 * @param sendcount  number of elements contributed by this rank
 * @param recvbuf    buffer receiving the gathered data
 * @param recvcounts per-rank element counts (array indexed by rank)
 * @param displs     per-rank offsets into recvbuf (array indexed by rank)
 * @param comm       communicator (taken by non-const reference)
 */
30void gatherv_data(
const double* sendbuf,
int sendcount,
double* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
31void gatherv_data(
const std::complex<double>* sendbuf,
int sendcount, std::complex<double>* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
32void gatherv_data(
const float* sendbuf,
int sendcount,
float* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
33void gatherv_data(
const std::complex<float>* sendbuf,
int sendcount, std::complex<float>* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
35#if defined(__NCCL_PARALLEL_DEVICE)
/**
 * @brief Overload set of nccl_bcast_data for double, std::complex<double>,
 *        float and std::complex<float> buffers.
 *
 * Declared only when __NCCL_PARALLEL_DEVICE is defined. The parameter list
 * mirrors bcast_data, except that the communicator is a non-const MPI_Comm&.
 * NOTE(review): presumably `object` must be a device pointer here -- confirm
 * against the definitions.
 *
 * @param object buffer operated on in place
 * @param n      element count (presumably elements, not bytes -- TODO confirm)
 * @param comm   MPI communicator (non-const reference)
 * @param root   root rank; defaults to 0
 */
36void nccl_bcast_data(
double*
object,
const int& n, MPI_Comm& comm,
int root = 0);
37void nccl_bcast_data(std::complex<double>*
object,
const int& n, MPI_Comm& comm,
int root = 0);
38void nccl_bcast_data(
float*
object,
const int& n, MPI_Comm& comm,
int root = 0);
39void nccl_bcast_data(std::complex<float>*
object,
const int& n, MPI_Comm& comm,
int root = 0);
/**
 * @brief Overload set of nccl_reduce_data for double, std::complex<double>,
 *        float and std::complex<float> buffers.
 *
 * Follows the `#if defined(__NCCL_PARALLEL_DEVICE)` guard. The parameter list
 * mirrors reduce_data, except that the communicator is a non-const MPI_Comm&.
 * NOTE(review): reduction operator and in-place semantics are not visible
 * from this header -- confirm against the definitions.
 *
 * @param object buffer operated on in place
 * @param n      element count (presumably elements, not bytes -- TODO confirm)
 * @param comm   MPI communicator (non-const reference)
 */
40void nccl_reduce_data(
double*
object,
const int& n, MPI_Comm& comm);
41void nccl_reduce_data(std::complex<double>*
object,
const int& n, MPI_Comm& comm);
42void nccl_reduce_data(
float*
object,
const int& n, MPI_Comm& comm);
43void nccl_reduce_data(std::complex<float>*
object,
const int& n, MPI_Comm& comm);
/**
 * @brief Overload set of nccl_gatherv_data for double, std::complex<double>,
 *        float and std::complex<float> buffers.
 *
 * Follows the `#if defined(__NCCL_PARALLEL_DEVICE)` guard. The signature is
 * identical to gatherv_data.
 * NOTE(review): presumably sendbuf/recvbuf must be device pointers here --
 * confirm against the definitions.
 *
 * @param sendbuf    read-only local contribution
 * @param sendcount  number of elements contributed by this rank
 * @param recvbuf    buffer receiving the gathered data
 * @param recvcounts per-rank element counts (array indexed by rank)
 * @param displs     per-rank offsets into recvbuf (array indexed by rank)
 * @param comm       MPI communicator (non-const reference)
 */
44void nccl_gatherv_data(
const double* sendbuf,
int sendcount,
double* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
45void nccl_gatherv_data(
const std::complex<double>* sendbuf,
int sendcount, std::complex<double>* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
46void nccl_gatherv_data(
const float* sendbuf,
int sendcount,
float* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
47void nccl_gatherv_data(
const std::complex<float>* sendbuf,
int sendcount, std::complex<float>* recvbuf,
const int* recvcounts,
const int* displs, MPI_Comm& comm);
// Fragment of a class template parameterised on <T, Device>. The class head
// and some members are not visible in this listing; only three member
// declarations appear below.
// NOTE(review): the parameter names (object vs. object_cpu) suggest that this
// helper stages device data through host memory -- confirm against the
// member definitions.
51template<
typename T,
typename Device>
// get: takes a read-only `object` of `n` elements plus an optional
// caller-provided `tmp_space` (defaults to nullptr) and returns a T*.
// NOTE(review): presumably the returned pointer is a host-side copy, placed in
// tmp_space when one is supplied -- TODO confirm.
56 T*
get(
const T*
object,
const int& n,
T* tmp_space =
nullptr);
// sync_d2h: writes into `object_cpu`, reads from `object`, `n` elements.
// NOTE(review): "d2h" presumably means device-to-host -- confirm.
58 void sync_d2h(
T* object_cpu,
const T*
object,
const int& n);
// sync_h2d: writes into `object`, reads from `object_cpu`, `n` elements.
// NOTE(review): "h2d" presumably means host-to-device -- confirm.
59 void sync_h2d(
T*
object,
const T* object_cpu,
const int& n);
/**
 * @brief Send `count` elements of `object` to rank `dest`, for data that may
 *        live on a device selected by the `Device` template parameter.
 *
 * The listing shows two paths. One passes `object` straight to send_data.
 * The other obtains a pointer through `o.get(object, count, tmp_space)` and
 * sends from that pointer instead. The condition selecting between them, and
 * the declaration of `o`, are not visible in this listing.
 *
 * @param object    source buffer
 * @param count     number of elements
 * @param dest      destination rank
 * @param tag       message tag
 * @param comm      communicator
 * @param tmp_space optional scratch buffer forwarded to `o.get`; defaults to
 *                  nullptr
 */
67template <
typename T,
typename Device>
68void send_dev(
const T*
object,
int count,
int dest,
int tag, MPI_Comm& comm,
T* tmp_space =
nullptr)
// Direct path: hand the caller's pointer to send_data unchanged.
71 send_data(
object, count, dest, tag, comm);
// Staged path: send from the pointer returned by o.get(...).
// NOTE(review): presumably a host copy of device data -- confirm.
74 T* object_cpu = o.
get(
object, count, tmp_space);
75 send_data(object_cpu, count, dest, tag, comm);
/**
 * @brief Variant of send_dev that goes through isend_data and reports
 *        progress through `request`.
 *
 * The listing shows two paths. One passes `object` straight to isend_data.
 * The other obtains a pointer through `o.get(object, count, send_space)` and
 * sends from that pointer. The condition selecting between them is not
 * visible in this listing.
 *
 * Unlike send_dev's `tmp_space`, `send_space` has no default value.
 * NOTE(review): presumably the staging buffer has to be caller-owned because
 * it must stay valid until the request completes -- confirm.
 *
 * @param object     source buffer
 * @param count      number of elements
 * @param dest       destination rank
 * @param tag        message tag
 * @param comm       communicator
 * @param request    request handle forwarded to isend_data
 * @param send_space scratch buffer forwarded to `o.get`
 */
86template <
typename T,
typename Device>
87void isend_dev(
const T*
object,
int count,
int dest,
int tag, MPI_Comm& comm, MPI_Request* request,
T* send_space)
// Direct path: hand the caller's pointer to isend_data unchanged.
90 isend_data(
object, count, dest, tag, comm, request);
// Staged path: send from the pointer returned by o.get(...).
93 T* object_cpu = o.
get(
object, count, send_space);
94 isend_data(object_cpu, count, dest, tag, comm, request);
/**
 * @brief Receive `count` elements from rank `source` into `object`, for data
 *        that may live on a device selected by the `Device` template
 *        parameter.
 *
 * The listing shows two paths. One receives straight into `object`. The other
 * receives into a pointer obtained from `o.get_buffer(object, count,
 * tmp_space)` and then calls `o.sync_h2d(object, object_cpu, count)`. The
 * condition selecting between them is not visible in this listing.
 *
 * @param object    destination buffer
 * @param count     number of elements
 * @param source    rank to receive from
 * @param tag       message tag
 * @param comm      communicator
 * @param status    status pointer forwarded to recv_data
 * @param tmp_space optional scratch buffer forwarded to `o.get_buffer`;
 *                  defaults to nullptr
 */
104template <
typename T,
typename Device>
105void recv_dev(
T*
object,
int count,
int source,
int tag, MPI_Comm& comm, MPI_Status* status,
T* tmp_space =
nullptr)
// Direct path: receive into the caller's pointer.
108 recv_data(
object, count, source, tag, comm, status);
// Staged path: get_buffer (rather than get) is used here.
// NOTE(review): presumably because the old contents of `object` need not be
// copied before they are overwritten -- confirm.
111 T* object_cpu = o.
get_buffer(
object, count, tmp_space);
112 recv_data(object_cpu, count, source, tag, comm, status);
// Propagate the received data from the staging pointer back into `object`.
113 o.
sync_h2d(
object, object_cpu, count);
/**
 * @brief Broadcast `n` elements of `object` from rank `root` over `comm`, for
 *        data that may live on a device selected by the `Device` template
 *        parameter.
 *
 * When __NCCL_PARALLEL_DEVICE is defined and Device is
 * base_device::DEVICE_GPU, the call is routed to nccl_bcast_data. The
 * remaining visible code queries the caller's rank and obtains a staging
 * pointer. What happens after that is not visible in this listing.
 *
 * @param object    buffer operated on in place
 * @param n         number of elements
 * @param comm      communicator
 * @param root      root rank; defaults to 0
 * @param tmp_space optional scratch buffer forwarded to `o.get` or
 *                  `o.get_buffer`; defaults to nullptr
 */
130template <
typename T,
typename Device>
131void bcast_dev(
T*
object,
const int& n,
const MPI_Comm& comm,
int root = 0,
T* tmp_space =
nullptr)
133#if defined(__NCCL_PARALLEL_DEVICE)
134 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
// nccl_bcast_data is declared with a non-const MPI_Comm&, hence the const_cast.
136 nccl_bcast_data(
object, n,
const_cast<MPI_Comm&
>(comm), root);
145 MPI_Comm_rank(comm, &rank);
// The root rank uses o.get(...); every other rank uses o.get_buffer(...).
// NOTE(review): presumably only the root needs its current data copied into
// the staging area, while the others just need space to receive -- confirm.
146 T* object_cpu = rank == root ? o.
get(
object, n, tmp_space) : o.
get_buffer(
object, n, tmp_space);
/**
 * @brief Reduce `n` elements of `object` over `comm`, for data that may live
 *        on a device selected by the `Device` template parameter.
 *
 * When __NCCL_PARALLEL_DEVICE is defined and Device is
 * base_device::DEVICE_GPU, the call is routed to nccl_reduce_data. The
 * remaining visible code obtains a staging pointer through
 * `o.get(object, n, tmp_space)`. What happens after that is not visible in
 * this listing.
 *
 * @param object    buffer operated on in place
 * @param n         number of elements
 * @param comm      communicator
 * @param tmp_space optional scratch buffer forwarded to `o.get`; defaults to
 *                  nullptr
 */
157template <
typename T,
typename Device>
158void reduce_dev(
T*
object,
const int& n,
const MPI_Comm& comm,
T* tmp_space =
nullptr)
160#if defined(__NCCL_PARALLEL_DEVICE)
161 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
// nccl_reduce_data is declared with a non-const MPI_Comm&, hence the const_cast.
163 nccl_reduce_data(
object, n,
const_cast<MPI_Comm&
>(comm));
// Every rank uses o.get(...) here, since each one contributes its own data.
171 T* object_cpu = o.
get(
object, n, tmp_space);
// Fragment of a variable-count gather for data that may live on a device
// selected by the `Device` template parameter. The function-name line and
// several parameter lines are missing from this listing. The body uses
// sendbuf, sendcount, recvbuf, recvcounts, displs and comm, plus the two
// optional scratch buffers tmp_sspace and tmp_rspace (both default to nullptr).
//
// Visible paths:
//  - Under __NCCL_PARALLEL_DEVICE with Device == base_device::DEVICE_GPU, the
//    call goes to nccl_gatherv_data on the caller's pointers.
//  - One path calls gatherv_data directly on the caller's pointers.
//  - One path stages both buffers through o1/o2, calls gatherv_data on the
//    staged pointers, then calls sync_h2d on the receive side.
// The conditions selecting between the last two are not visible here.
179template <
typename T,
typename Device>
183 const int* recvcounts,
186 T* tmp_sspace =
nullptr,
187 T* tmp_rspace =
nullptr)
189#if defined(__NCCL_PARALLEL_DEVICE)
190 if (std::is_same<Device, base_device::DEVICE_GPU>::value)
192 nccl_gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
197 gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm);
201 MPI_Comm_size(comm, &size);
// Number of elements staged for the receive buffer: the end offset of the
// last rank's segment.
// NOTE(review): this assumes the last rank's segment is the one that extends
// furthest into recvbuf -- TODO confirm.
202 int gather_space = displs[size - 1] + recvcounts[size - 1];
203 T* sendbuf_cpu = o1.
get(sendbuf, sendcount, tmp_sspace);
204 T* recvbuf_cpu = o2.
get_buffer(recvbuf, gather_space, tmp_rspace);
205 gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm);
// Propagate the gathered data from the staging pointer back into recvbuf.
206 o2.
sync_h2d(recvbuf, recvbuf_cpu, gather_space);
#define T
Definition exp.cpp:237
Definition parallel_common.h:11
void send_data(const double *buf, int count, int dest, int tag, MPI_Comm &comm)
Definition parallel_device.cpp:273
void bcast_data(std::complex< double > *object, const int &n, const MPI_Comm &comm, int root)
Definition parallel_device.cpp:305
void send_dev(const T *object, int count, int dest, int tag, MPI_Comm &comm, T *tmp_space=nullptr)
send data in Device
Definition parallel_device.h:68
void isend_data(const double *buf, int count, int dest, int tag, MPI_Comm &comm, MPI_Request *request)
Definition parallel_device.cpp:257
void isend_dev(const T *object, int count, int dest, int tag, MPI_Comm &comm, MPI_Request *request, T *send_space)
isend data in Device
Definition parallel_device.h:87
void recv_dev(T *object, int count, int source, int tag, MPI_Comm &comm, MPI_Status *status, T *tmp_space=nullptr)
recv data in Device
Definition parallel_device.h:105
void gatherv_dev(const T *sendbuf, int sendcount, T *recvbuf, const int *recvcounts, const int *displs, MPI_Comm &comm, T *tmp_sspace=nullptr, T *tmp_rspace=nullptr)
Definition parallel_device.h:180
void bcast_dev(T *object, const int &n, const MPI_Comm &comm, int root=0, T *tmp_space=nullptr)
broadcast data in Device
Definition parallel_device.h:131
void gatherv_data(const double *sendbuf, int sendcount, double *recvbuf, const int *recvcounts, const int *displs, MPI_Comm &comm)
Definition parallel_device.cpp:337
void recv_data(double *buf, int count, int source, int tag, MPI_Comm &comm, MPI_Status *status)
Definition parallel_device.cpp:289
void reduce_data(std::complex< double > *object, const int &n, const MPI_Comm &comm)
Definition parallel_device.cpp:321
void reduce_dev(T *object, const int &n, const MPI_Comm &comm, T *tmp_space=nullptr)
Definition parallel_device.h:158
Definition parallel_device.h:53
bool alloc
Definition parallel_device.h:54
T * get_buffer(const T *object, const int &n, T *tmp_space=nullptr)
void sync_h2d(T *object, const T *object_cpu, const int &n)
T * get(const T *object, const int &n, T *tmp_space=nullptr)
void sync_d2h(T *object_cpu, const T *object, const int &n)