ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cublasmp_context.h
Go to the documentation of this file.
1#ifndef CUBLASMP_CONTEXT_H
2#define CUBLASMP_CONTEXT_H
3
4#ifdef __MPI
5#include <mpi.h>
6#endif
7
8#ifdef __CUDA
9#include <cuda_runtime.h>
10#endif
11
12#ifdef __CUBLASMP
15
#include <cublasmp.h>
#include <cusolverMp.h>
#include <iostream>
#include <nccl.h>
#include <stdexcept>
#include <string>
20
extern "C"
{
// BLACS query used by init_cublasmp_resources() to read the process-grid shape
// (nprow x npcol) and this process's coordinates (myprow, mypcol) for context icontxt.
// NOTE(review): this block appears empty in the listing although the function is
// called below; if the original pulled this prototype in through a project header,
// restore that include instead of this declaration.
void Cblacs_gridinfo(int icontxt, int* nprow, int* npcol, int* myprow, int* mypcol);
}
25
// Debug trace helper. When the flag g_EnableDebugLog (declared elsewhere) is true,
// writes "[DEBUG] <msg> (at <enclosing function>)" to std::cerr and flushes with
// std::endl; otherwise it does nothing and `msg` is not evaluated.
// `msg` is spliced directly into the stream expression, so it may chain values
// with <<. The do/while wrapper makes the macro usable as a single statement.
#define LOG_DEBUG(msg)                                                                  \
    do                                                                                  \
    {                                                                                   \
        if (!g_EnableDebugLog)                                                          \
        {                                                                               \
            break;                                                                      \
        }                                                                               \
        std::cerr << "[DEBUG] " << msg << " (at " << __func__ << ")" << std::endl;      \
    } while (false)
34#endif // __CUBLASMP
35
// The struct is ALWAYS available; only the members backed by an enabled
// feature macro (__MPI, __CUDA, __CUBLASMP) are compiled in.
//
// Bundle of the communicator, stream and library handles shared by the
// cuBLASMp / cuSOLVERMp code paths. All handles start out null and
// is_initialized starts out false.
struct CublasMpResources
{
    // True between a successful init_cublasmp_resources() and the matching
    // finalize_cublasmp_resources().
    bool is_initialized = false;

#ifdef __MPI
    // MPI communicator the resources were initialized on.
    MPI_Comm mpi_comm = MPI_COMM_NULL;
#endif

#ifdef __CUDA
    // CUDA stream handed to the cuBLASMp and cuSOLVERMp handles.
    cudaStream_t stream = nullptr;
#endif

#ifdef __CUBLASMP
    // NCCL communicator spanning the ranks of mpi_comm.
    ncclComm_t nccl_comm = nullptr;

    // cuBLASMp library handle and its process grid.
    cublasMpHandle_t cublasmp_handle = nullptr;
    cublasMpGrid_t cublasmp_grid = nullptr;

    // cuSOLVERMp library handle and its device grid.
    cusolverMpHandle_t cusolvermp_handle = nullptr;
    cusolverMpGrid_t cusolvermp_grid = nullptr;
#endif
};
59
60// API functions are only visible when cuBLASMp is enabled.
61#ifdef __CUBLASMP
62
63inline void init_cublasmp_resources(CublasMpResources& res, MPI_Comm mpi_comm, const int* desc)
64{
65 if (res.is_initialized)
66 {
67 return;
68 }
69
70 res.mpi_comm = mpi_comm;
71 MPI_Barrier(res.mpi_comm);
72
73 // 1. Get BLACS topology info
74 int cblacs_ctxt = desc[1];
75 int nprows, npcols, myprow, mypcol;
76 Cblacs_gridinfo(cblacs_ctxt, &nprows, &npcols, &myprow, &mypcol);
77
78 GlobalV::ofs_running << "nprows = " << nprows << std::endl;
79 GlobalV::ofs_running << "npcols = " << npcols << std::endl;
80 GlobalV::ofs_running << "myprow = " << myprow << std::endl;
81 GlobalV::ofs_running << "mypcol = " << mypcol << std::endl;
83
84 int rank, size;
85 MPI_Comm_rank(res.mpi_comm, &rank);
86 MPI_Comm_size(res.mpi_comm, &size);
87
89 cudaSetDevice(device_id);
90 cudaStreamCreate(&res.stream);
91
92 // 2. Initialize NCCL communicator
93 ncclUniqueId id;
94 if (rank == 0)
95 {
96 ncclGetUniqueId(&id);
97 }
98 // Broadcast the unique NCCL ID to all ranks
99 MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, res.mpi_comm);
100 // Initialize NCCL with the generated ID
101 ncclCommInitRank(&res.nccl_comm, size, id, rank);
102
103 // 3. Initialize cuBLASMp specific resources
104 cublasMpCreate(&res.cublasmp_handle, res.stream);
105 cublasMpGridCreate(nprows, npcols, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, res.nccl_comm, &res.cublasmp_grid);
106
107 // 4. Initialize cuSOLVERMp specific resources
108 cusolverMpCreate(&res.cusolvermp_handle, device_id, res.stream);
109 cusolverMpCreateDeviceGrid(res.cusolvermp_handle,
110 &res.cusolvermp_grid,
111 res.nccl_comm,
112 nprows,
113 npcols,
114 CUSOLVERMP_GRID_MAPPING_ROW_MAJOR);
115
116 res.is_initialized = true;
117}
118
119inline void finalize_cublasmp_resources(CublasMpResources& res)
120{
121 if (!res.is_initialized)
122 {
123 return;
124 }
125
126 if (res.stream)
127 {
128 cudaStreamSynchronize(res.stream);
129 }
130
131 // Destroy cuBLASMp resources
132 if (res.cublasmp_grid)
133 {
134 cublasMpGridDestroy(res.cublasmp_grid);
135 }
136 if (res.cublasmp_handle)
137 {
138 cublasMpDestroy(res.cublasmp_handle);
139 }
140
141 // Destroy cuSOLVERMp resources
142 if (res.cusolvermp_grid)
143 {
144 cusolverMpDestroyGrid(res.cusolvermp_grid);
145 }
146 if (res.cusolvermp_handle)
147 {
148 cusolverMpDestroy(res.cusolvermp_handle);
149 }
150
151 // Destroy NCCL communicator
152 if (res.nccl_comm)
153 {
154 ncclCommDestroy(res.nccl_comm);
155 }
156
157 if (res.stream)
158 {
159 cudaStreamDestroy(res.stream);
160 }
161
162 res.is_initialized = false;
163}
164
165#endif // __CUBLASMP
166
167#endif // CUBLASMP_CONTEXT_H
void Cblacs_gridinfo(int icontxt, int *nprow, int *npcol, int *myprow, int *mypcol)
int get_device_id() const
Get the bound GPU device ID.
Definition device.h:124
static DeviceContext & instance()
Get the singleton instance of DeviceContext.
Definition device.cpp:112
std::ofstream ofs_running
Definition global_variable.cpp:38
Definition cublasmp_context.h:38
bool is_initialized
Definition cublasmp_context.h:39
MPI_Comm mpi_comm
Definition cublasmp_context.h:42
int mypcol
Definition tddft_test.cpp:14
int myprow
Definition tddft_test.cpp:14