1#ifndef CUBLASMP_CONTEXT_H
2#define CUBLASMP_CONTEXT_H
9#include <cuda_runtime.h>
17#include <cusolverMp.h>
26#define LOG_DEBUG(msg) \
29 if (g_EnableDebugLog) \
31 std::cerr << "[DEBUG] " << msg << " (at " << __func__ << ")" << std::endl; \
46 cudaStream_t stream =
nullptr;
50 ncclComm_t nccl_comm =
nullptr;
52 cublasMpHandle_t cublasmp_handle =
nullptr;
53 cublasMpGrid_t cublasmp_grid =
nullptr;
55 cusolverMpHandle_t cusolvermp_handle =
nullptr;
56 cusolverMpGrid_t cusolvermp_grid =
nullptr;
63inline void init_cublasmp_resources(
CublasMpResources& res, MPI_Comm mpi_comm,
const int* desc)
74 int cblacs_ctxt = desc[1];
89 cudaSetDevice(device_id);
90 cudaStreamCreate(&res.stream);
99 MPI_Bcast((
void*)&
id,
sizeof(
id), MPI_BYTE, 0, res.
mpi_comm);
101 ncclCommInitRank(&res.nccl_comm, size,
id, rank);
104 cublasMpCreate(&res.cublasmp_handle, res.stream);
105 cublasMpGridCreate(nprows, npcols, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, res.nccl_comm, &res.cublasmp_grid);
108 cusolverMpCreate(&res.cusolvermp_handle, device_id, res.stream);
109 cusolverMpCreateDeviceGrid(res.cusolvermp_handle,
110 &res.cusolvermp_grid,
114 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR);
128 cudaStreamSynchronize(res.stream);
132 if (res.cublasmp_grid)
134 cublasMpGridDestroy(res.cublasmp_grid);
136 if (res.cublasmp_handle)
138 cublasMpDestroy(res.cublasmp_handle);
142 if (res.cusolvermp_grid)
144 cusolverMpDestroyGrid(res.cusolvermp_grid);
146 if (res.cusolvermp_handle)
148 cusolverMpDestroy(res.cusolvermp_handle);
154 ncclCommDestroy(res.nccl_comm);
159 cudaStreamDestroy(res.stream);
void Cblacs_gridinfo(int icontxt, int *nprow, int *npcol, int *myprow, int *mypcol)
int get_device_id() const
Get the bound GPU device ID.
Definition device.h:124
static DeviceContext & instance()
Get the singleton instance of DeviceContext.
Definition device.cpp:112
std::ofstream ofs_running
Definition global_variable.cpp:38
Definition cublasmp_context.h:38
bool is_initialized
Definition cublasmp_context.h:39
MPI_Comm mpi_comm
Definition cublasmp_context.h:42
int mypcol
Definition tddft_test.cpp:14
int myprow
Definition tddft_test.cpp:14