ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
cuda_mem_wrapper.h
Go to the documentation of this file.
1#pragma once
2#include <cuda_runtime.h>
4#include "gint_helper.cuh"
5
6template <typename T>
8{
9 public:
10
11 CudaMemWrapper() = default;
12 CudaMemWrapper(const CudaMemWrapper& other) = delete;
13 CudaMemWrapper& operator=(const CudaMemWrapper& other) = delete;
15 {
16 this->device_ptr_ = other.device_ptr_;
17 this->host_ptr_ = other.host_ptr_;
18 this->size_ = other.size_;
19 this->malloc_host_ = other.malloc_host_;
20 this->stream_ = other.stream_;
21
22 other.device_ptr_ = nullptr;
23 other.host_ptr_ = nullptr;
24 other.size_ = 0;
25 other.malloc_host_ = false;
26 other.stream_ = 0;
27 }
28
30 {
31 if (this != &other)
32 {
33 this->device_ptr_ = other.device_ptr_;
34 this->host_ptr_ = other.host_ptr_;
35 this->size_ = other.size_;
36 this->malloc_host_ = other.malloc_host_;
37 this->stream_ = other.stream_;
38
39 other.device_ptr_ = nullptr;
40 other.host_ptr_ = nullptr;
41 other.size_ = 0;
42 other.malloc_host_ = false;
43 other.stream_ = 0;
44 }
45 return *this;
46 }
47
48 CudaMemWrapper(size_t size,
49 cudaStream_t stream = 0,
50 bool malloc_host = true)
51 {
52 size_ = size;
53 malloc_host_ = malloc_host;
54 stream_ = stream;
55
56 if (malloc_host)
57 {
58 checkCuda(cudaMallocHost((void**)&host_ptr_, size_* sizeof(T)));
59 memset(host_ptr_, 0, size_ * sizeof(T));
60 }
61 else
62 { host_ptr_ = nullptr; }
63
64 checkCuda(cudaMalloc((void**)&device_ptr_, size_ * sizeof(T)));
65 checkCuda(cudaMemset(device_ptr_, 0, size_ * sizeof(T)));
66 }
67
69 {
70 free();
71 }
72
73 void copy_host_to_device_sync(size_t size)
74 {
75 if (host_ptr_ == nullptr)
76 { ModuleBase::WARNING_QUIT("cuda_mem_wrapper", "Host pointer is null, cannot copy to device."); }
77 checkCuda(cudaMemcpy(device_ptr_, host_ptr_, size * sizeof(T), cudaMemcpyHostToDevice));
78 }
79
84
85 void copy_host_to_device_async(size_t size)
86 {
87 if (host_ptr_ == nullptr)
88 { ModuleBase::WARNING_QUIT("cuda_mem_wrapper", "Host pointer is null, cannot copy to device."); }
89 checkCuda(cudaMemcpyAsync(device_ptr_, host_ptr_, size * sizeof(T), cudaMemcpyHostToDevice, stream_));
90 }
91
96
97 void copy_device_to_host_sync(size_t size)
98 {
99 if (host_ptr_ == nullptr)
100 { ModuleBase::WARNING_QUIT("cuda_mem_wrapper", "Host pointer is null, cannot copy to host."); }
101 checkCuda(cudaMemcpy(host_ptr_, device_ptr_, size * sizeof(T), cudaMemcpyDeviceToHost));
102 }
103
108
110 {
111 if (host_ptr_ == nullptr)
112 { ModuleBase::WARNING_QUIT("cuda_mem_wrapper", "Host pointer is null, cannot copy to host."); }
113 checkCuda(cudaMemcpyAsync(host_ptr_, device_ptr_, size * sizeof(T), cudaMemcpyDeviceToHost, stream_));
114 }
115
120
121 void memset_device_sync(const size_t size, const int value = 0)
122 {
123 checkCuda(cudaMemset(device_ptr_, value, size * sizeof(T)));
124 }
125
126 void memset_device_sync(const int value = 0)
127 {
129 }
130
131 void memset_device_async(const size_t size, const int value = 0)
132 {
133 checkCuda(cudaMemsetAsync(device_ptr_, value, size * sizeof(T), stream_));
134 }
135
136 void memset_device_async(const int value = 0)
137 {
139 }
140
141 void memset_host(const size_t size, const int value = 0)
142 {
143 if (host_ptr_ == nullptr)
144 { ModuleBase::WARNING_QUIT("cuda_mem_wrapper", "Host pointer is null, cannot memset host."); }
145 checkCuda(cudaMemset(host_ptr_, value, size * sizeof(T)));
146 }
147
148 void memset_host(const int value = 0)
149 {
150 memset_host(size_, value);
151 }
152
153 void free()
154 {
155 checkCuda(cudaFree(device_ptr_));
156 checkCuda(cudaFreeHost(host_ptr_));
157 }
158
160 T* get_host_ptr() { return host_ptr_; }
161 const T* get_device_ptr() const { return device_ptr_; }
162 const T* get_host_ptr() const { return host_ptr_; }
163 size_t get_size() const { return size_; }
164
165 private:
166 T* device_ptr_ = nullptr;
167 T* host_ptr_ = nullptr;
168 size_t size_ = 0;
169 bool malloc_host_ = false;
170 cudaStream_t stream_ = 0;
171};
Definition cuda_mem_wrapper.h:8
T * get_host_ptr()
Definition cuda_mem_wrapper.h:160
void free()
Definition cuda_mem_wrapper.h:153
T * get_device_ptr()
Definition cuda_mem_wrapper.h:159
CudaMemWrapper & operator=(CudaMemWrapper &&other) noexcept
Definition cuda_mem_wrapper.h:29
void memset_device_sync(const int value=0)
Definition cuda_mem_wrapper.h:126
CudaMemWrapper(const CudaMemWrapper &other)=delete
cudaStream_t stream_
Definition cuda_mem_wrapper.h:170
void copy_device_to_host_sync(size_t size)
Definition cuda_mem_wrapper.h:97
T * host_ptr_
Definition cuda_mem_wrapper.h:167
void memset_device_sync(const size_t size, const int value=0)
Definition cuda_mem_wrapper.h:121
void copy_host_to_device_async(size_t size)
Definition cuda_mem_wrapper.h:85
size_t get_size() const
Definition cuda_mem_wrapper.h:163
CudaMemWrapper()=default
const T * get_host_ptr() const
Definition cuda_mem_wrapper.h:162
~CudaMemWrapper()
Definition cuda_mem_wrapper.h:68
void copy_device_to_host_async()
Definition cuda_mem_wrapper.h:116
void memset_device_async(const size_t size, const int value=0)
Definition cuda_mem_wrapper.h:131
void copy_device_to_host_sync()
Definition cuda_mem_wrapper.h:104
CudaMemWrapper(size_t size, cudaStream_t stream=0, bool malloc_host=true)
Definition cuda_mem_wrapper.h:48
void copy_host_to_device_async()
Definition cuda_mem_wrapper.h:92
bool malloc_host_
Definition cuda_mem_wrapper.h:169
size_t size_
Definition cuda_mem_wrapper.h:168
void memset_host(const int value=0)
Definition cuda_mem_wrapper.h:148
T * device_ptr_
Definition cuda_mem_wrapper.h:166
void copy_host_to_device_sync(size_t size)
Definition cuda_mem_wrapper.h:73
CudaMemWrapper & operator=(const CudaMemWrapper &other)=delete
void copy_host_to_device_sync()
Definition cuda_mem_wrapper.h:80
void memset_host(const size_t size, const int value=0)
Definition cuda_mem_wrapper.h:141
void copy_device_to_host_async(size_t size)
Definition cuda_mem_wrapper.h:109
CudaMemWrapper(CudaMemWrapper &&other) noexcept
Definition cuda_mem_wrapper.h:14
const T * get_device_ptr() const
Definition cuda_mem_wrapper.h:161
void memset_device_async(const int value=0)
Definition cuda_mem_wrapper.h:136
#define T
Definition exp.cpp:237
void WARNING_QUIT(const std::string &, const std::string &)
Combine the functions of WARNING and QUIT.
Definition test_delley.cpp:14