ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
device.h
Go to the documentation of this file.
1#ifndef MODULE_DEVICE_H_
2#define MODULE_DEVICE_H_
3
4#include "types.h"
5#include "device_helpers.h"
6#include <fstream>
7#include <mutex>
8
9#ifdef __MPI
10#include "mpi.h"
11#endif
12
13namespace base_device
14{
15
16namespace information
17{
18
/**
 * @brief Get the device name for source_esolver.
 * @param device_flag device flag string
 * @return device name
 */
std::string get_device_name(std::string device_flag);

/**
 * @brief Get the device number for source_esolver.
 * @param device_flag device flag string
 * @return device number
 */
int get_device_num(std::string device_flag);

/**
 * @brief Output the device information for source_esolver.
 * @param output destination stream for the information
 * @param device device identifier string
 */
void output_device_info(std::ostream& output, const std::string& device);

/**
 * @brief Get the device flag string for the source_io setting PARAM.inp.device.
 * @param device     requested device string (PARAM.inp.device)
 * @param basis_type basis type string
 * @return device flag string
 */
std::string get_device_flag(const std::string& device, const std::string& basis_type);
51
// Use #ifdef (not #if) to match the guard around #include "mpi.h" at the top
// of this header; `#if __MPI` fails to preprocess when __MPI is defined with
// an empty value.
#ifdef __MPI
/**
 * @brief Get the local rank within the node using MPI_COMM_TYPE_SHARED.
 * @param mpi_comm communicator to query (defaults to MPI_COMM_WORLD)
 * @return local rank of the calling process within its node
 */
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm = MPI_COMM_WORLD);
#endif
60
/**
 * @brief Print information about a device to a log stream.
 *
 * Generic fallback: intentionally does nothing. When built with __CUDA or
 * __ROCM, an explicit specialization for base_device::DEVICE_GPU is declared
 * further down in this header.
 *
 * @param dev        device pointer (unused by this fallback)
 * @param ofs_device output log stream (unused by this fallback)
 */
template <typename Device>
void print_device_info(const Device* dev, std::ofstream& ofs_device)
{
    // Explicitly discard the parameters so every translation unit that
    // instantiates this no-op does not emit -Wunused-parameter warnings.
    (void)dev;
    (void)ofs_device;
}
66
/**
 * @brief Record a device memory entry in a log stream.
 *
 * Generic fallback: intentionally does nothing. When built with __CUDA or
 * __ROCM, an explicit specialization for base_device::DEVICE_GPU is declared
 * further down in this header. The parameter list must stay exactly as is,
 * because that specialization has to match it.
 *
 * @param dev        device pointer (unused by this fallback)
 * @param ofs_device output log stream (unused by this fallback)
 * @param str        label for the record (unused by this fallback)
 * @param size       size value for the record; units are defined by the
 *                   GPU specialization (unused by this fallback)
 */
template <typename Device>
void record_device_memory(const Device* dev, std::ofstream& ofs_device, std::string str, size_t size)
{
    // Explicitly discard the parameters so every translation unit that
    // instantiates this no-op does not emit -Wunused-parameter warnings.
    (void)dev;
    (void)ofs_device;
    (void)str;
    (void)size;
}
72
#if defined(__CUDA) || defined(__ROCM)
// GPU builds only: explicit specializations of the generic no-op templates
// above for base_device::DEVICE_GPU. Only the declarations live in this
// header; the definitions are provided elsewhere.

/// @brief GPU specialization of print_device_info.
template <>
void print_device_info<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
                                                std::ofstream& ofs_device);

/// @brief GPU specialization of record_device_memory.
template <>
void record_device_memory<base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* dev,
                                                   std::ofstream& ofs_device,
                                                   std::string str,
                                                   size_t size);
#endif
80
81} // end of namespace information
82
99public:
104 static DeviceContext& instance();
105
118 void init();
119
124 int get_device_id() const { return device_id_; }
125
130 int get_device_count() const { return device_count_; }
131
136 int get_local_rank() const { return local_rank_; }
137
138 // Disable copy and assignment
139 DeviceContext(const DeviceContext&) = delete;
141
142private:
143 DeviceContext() = default;
144 ~DeviceContext() = default;
145
146 bool initialized_ = false;
147 bool gpu_enabled_ = false;
148 int device_id_ = -1;
150 int local_rank_ = 0;
151 std::mutex init_mutex_;
152};
153
154} // end of namespace base_device
155
156#endif // MODULE_DEVICE_H_
Singleton class to manage GPU device context and initialization.
Definition device.h:98
void init()
Initialize GPU device binding.
Definition device.cpp:117
DeviceContext(const DeviceContext &)=delete
bool initialized_
Definition device.h:146
bool gpu_enabled_
Definition device.h:147
int get_local_rank() const
Get the local MPI rank within the node.
Definition device.h:136
int get_device_id() const
Get the bound GPU device ID.
Definition device.h:124
int get_device_count() const
Get the total number of GPU devices on this node.
Definition device.h:130
int device_count_
Definition device.h:149
int device_id_
Definition device.h:148
DeviceContext & operator=(const DeviceContext &)=delete
std::mutex init_mutex_
Definition device.h:151
static DeviceContext & instance()
Get the singleton instance of DeviceContext.
Definition device.cpp:112
int local_rank_
Definition device.h:150
Definition output.h:13
Type trait templates for device and precision detection.
bool probe_gpu_availability()
Safely probes for GPU availability without exiting on error.
Definition device.cpp:41
std::string get_device_flag(const std::string &device, const std::string &basis_type)
Get the device flag string for the source_io setting PARAM.inp.device.
Definition device.cpp:63
int get_device_num(std::string device_flag)
Get the device number for source_esolver.
Definition output_device.cpp:56
void record_device_memory(const Device *dev, std::ofstream &ofs_device, std::string str, size_t size)
Definition device.h:68
void print_device_info(const Device *dev, std::ofstream &ofs_device)
Definition device.h:62
std::string get_device_name(std::string device_flag)
Get the device name for source_esolver.
Definition output_device.cpp:23
void output_device_info(std::ostream &output, const std::string &device)
Output the device information for source_esolver.
Definition output_device.cpp:94
int get_node_rank_with_mpi_shared(const MPI_Comm mpi_comm)
Get the local rank within the node using MPI_COMM_TYPE_SHARED.
Definition device.cpp:26
Definition device.cpp:21
string device_flag
Definition pw_test.cpp:13