ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
dsp_connector.h
Go to the documentation of this file.
1#ifndef DSP_CONNECTOR_H
2#define DSP_CONNECTOR_H
3#ifdef __DSP
4
7
8namespace mtfunc
9{
10// Base DSP functions (implemented in dsp_connector.cpp).
// dspInitHandle / dspDestoryHandle: presumably set up and tear down the DSP handle
// identified by `id` -- TODO confirm in dsp_connector.cpp.
// NOTE(review): "Destory" is a misspelling of "Destroy"; the name is left unchanged
// because renaming it would break existing callers.
11void dspInitHandle(int id);
12void dspDestoryHandle(int id);
// malloc_ht / free_ht: presumably allocate `bytes` bytes associated with DSP cluster
// `cluster_id`, and release a pointer obtained that way -- confirm in dsp_connector.cpp
// whether free_ht accepts only malloc_ht pointers.
13void* malloc_ht(size_t bytes, int cluster_id);
14void free_ht(void* ptr);
15
16// mtblas functions: BLAS-style GEMM/GEMV entry points.
// Every argument except the last follows the Fortran BLAS ?GEMM / ?GEMV calling
// convention: transpose flags, dimensions, scalars and strides are all passed by
// pointer, with leading dimensions lda/ldb/ldc and strides incx/incy.
// The extra trailing `cluster_id` presumably selects the DSP cluster that executes
// the call (assumed from the name -- confirm).
// Type prefixes match the parameter types: s = float, d = double,
// c = std::complex<float>, z = std::complex<double>.
// Presumably gemm computes C = alpha*op(A)*op(B) + beta*C and gemv computes
// y = alpha*op(A)*x + beta*y, as in reference BLAS -- confirm in dsp_connector.cpp.
17
18void sgemm_mt_(const char* transa,
19 const char* transb,
20 const int* m,
21 const int* n,
22 const int* k,
23 const float* alpha,
24 const float* a,
25 const int* lda,
26 const float* b,
27 const int* ldb,
28 const float* beta,
29 float* c,
30 const int* ldc,
31 int cluster_id);
32
33void dgemm_mt_(const char* transa,
34 const char* transb,
35 const int* m,
36 const int* n,
37 const int* k,
38 const double* alpha,
39 const double* a,
40 const int* lda,
41 const double* b,
42 const int* ldb,
43 const double* beta,
44 double* c,
45 const int* ldc,
46 int cluster_id);
47
48void zgemm_mt_(const char* transa,
49 const char* transb,
50 const int* m,
51 const int* n,
52 const int* k,
53 const std::complex<double>* alpha,
54 const std::complex<double>* a,
55 const int* lda,
56 const std::complex<double>* b,
57 const int* ldb,
58 const std::complex<double>* beta,
59 std::complex<double>* c,
60 const int* ldc,
61 int cluster_id);
62
63void cgemm_mt_(const char* transa,
64 const char* transb,
65 const int* m,
66 const int* n,
67 const int* k,
68 const std::complex<float>* alpha,
69 const std::complex<float>* a,
70 const int* lda,
71 const std::complex<float>* b,
72 const int* ldb,
73 const std::complex<float>* beta,
74 std::complex<float>* c,
75 const int* ldc,
76 int cluster_id);
77
78
79
// GEMV counterparts of the `_mt_` GEMM routines above (same trailing `cluster_id`).
80void sgemv_mt_(const char* transa,
81 const int* m,
82 const int* n,
83 const float* alpha,
84 const float* a,
85 const int* lda,
86 const float* x,
87 const int* incx,
88 const float* beta,
89 float* y,
90 const int* incy,
91 int cluster_id);
92
93void dgemv_mt_(const char* transa,
94 const int* m,
95 const int* n,
96 const double* alpha,
97 const double* a,
98 const int* lda,
99 const double* x,
100 const int* incx,
101 const double* beta,
102 double* y,
103 const int* incy,
104 int cluster_id);
105
106void zgemv_mt_(const char* transa,
107 const int* m,
108 const int* n,
109 const std::complex<double>* alpha,
110 const std::complex<double>* a,
111 const int* lda,
112 const std::complex<double>* x,
113 const int* incx,
114 const std::complex<double>* beta,
115 std::complex<double>* y,
116 const int* incy,
117 int cluster_id);
118
119void cgemv_mt_(const char* transa,
120 const int* m,
121 const int* n,
122 const std::complex<float>* alpha,
123 const std::complex<float>* a,
124 const int* lda,
125 const std::complex<float>* x,
126 const int* incx,
127 const std::complex<float>* beta,
128 std::complex<float>* y,
129 const int* incy,
130 int cluster_id);
131
// `_mth_` variants: same signatures as the `_mt_` routines above. How they differ
// (for example, where the buffers are expected to live) is not visible from this
// header -- TODO confirm in dsp_connector.cpp.
// The `_pack_` variants exist only for the complex GEMMs (zgemm_pack_mth_,
// cgemm_pack_mth_).
132void sgemm_mth_(const char* transa,
133 const char* transb,
134 const int* m,
135 const int* n,
136 const int* k,
137 const float* alpha,
138 const float* a,
139 const int* lda,
140 const float* b,
141 const int* ldb,
142 const float* beta,
143 float* c,
144 const int* ldc,
145 int cluster_id);
146
147void dgemm_mth_(const char* transa,
148 const char* transb,
149 const int* m,
150 const int* n,
151 const int* k,
152 const double* alpha,
153 const double* a,
154 const int* lda,
155 const double* b,
156 const int* ldb,
157 const double* beta,
158 double* c,
159 const int* ldc,
160 int cluster_id);
161
162void zgemm_mth_(const char* transa,
163 const char* transb,
164 const int* m,
165 const int* n,
166 const int* k,
167 const std::complex<double>* alpha,
168 const std::complex<double>* a,
169 const int* lda,
170 const std::complex<double>* b,
171 const int* ldb,
172 const std::complex<double>* beta,
173 std::complex<double>* c,
174 const int* ldc,
175 int cluster_id);
176
177void zgemm_pack_mth_(const char* transa,
178 const char* transb,
179 const int* m,
180 const int* n,
181 const int* k,
182 const std::complex<double>* alpha,
183 const std::complex<double>* a,
184 const int* lda,
185 const std::complex<double>* b,
186 const int* ldb,
187 const std::complex<double>* beta,
188 std::complex<double>* c,
189 const int* ldc,
190 int cluster_id);
191
192void cgemm_mth_(const char* transa,
193 const char* transb,
194 const int* m,
195 const int* n,
196 const int* k,
197 const std::complex<float>* alpha,
198 const std::complex<float>* a,
199 const int* lda,
200 const std::complex<float>* b,
201 const int* ldb,
202 const std::complex<float>* beta,
203 std::complex<float>* c,
204 const int* ldc,
205 int cluster_id);
206
207void cgemm_pack_mth_(const char* transa,
208 const char* transb,
209 const int* m,
210 const int* n,
211 const int* k,
212 const std::complex<float>* alpha,
213 const std::complex<float>* a,
214 const int* lda,
215 const std::complex<float>* b,
216 const int* ldb,
217 const std::complex<float>* beta,
218 std::complex<float>* c,
219 const int* ldc,
220 int cluster_id);
221
222void sgemv_mth_(const char* transa,
223 const int* m,
224 const int* n,
225 const float* alpha,
226 const float* a,
227 const int* lda,
228 const float* x,
229 const int* incx,
230 const float* beta,
231 float* y,
232 const int* incy,
233 int cluster_id);
234
235void dgemv_mth_(const char* transa,
236 const int* m,
237 const int* n,
238 const double* alpha,
239 const double* a,
240 const int* lda,
241 const double* x,
242 const int* incx,
243 const double* beta,
244 double* y,
245 const int* incy,
246 int cluster_id);
247
248void zgemv_mth_(const char* transa,
249 const int* m,
250 const int* n,
251 const std::complex<double>* alpha,
252 const std::complex<double>* a,
253 const int* lda,
254 const std::complex<double>* x,
255 const int* incx,
256 const std::complex<double>* beta,
257 std::complex<double>* y,
258 const int* incy,
259 int cluster_id);
260
261void cgemv_mth_(const char* transa,
262 const int* m,
263 const int* n,
264 const std::complex<float>* alpha,
265 const std::complex<float>* a,
266 const int* lda,
267 const std::complex<float>* x,
268 const int* incx,
269 const std::complex<float>* beta,
270 std::complex<float>* y,
271 const int* incy,
272 int cluster_id);
273
// NOTE(review): disabled alias. If re-enabled it would not match any declaration in
// this header: the declared routine is `zgemm_mt_` (with a trailing underscore), and
// it takes an extra trailing `cluster_id` argument.
274// #define zgemm_ zgemm_mt
275
276// DSP utilities follow. They may be moved to a separate file if this header grows too large.
277
// dsp_dav_subspace_reduce: sum a block of hcc and a block of scc over diag_comm, in place.
//
// For each of hcc and scc, the following steps run in order:
//   1. The contiguous range of notconv * nbase_x elements starting at offset
//      nbase * nbase_x is copied into the temporary buffer `swap`.
//   2. `swap` is summed with MPI_Reduce(MPI_SUM) into `target` on rank 0 of diag_comm.
//   3. `target` is copied back into the same range. This happens on every rank, not
//      only on rank 0.
// If the matrices are stored with leading dimension nbase_x, that range is the stripes
// nbase .. nbase + notconv - 1 (assumption -- confirm with the caller).
//
// Params:
//   hcc, scc  - matrices whose blocks are reduced in place (presumably the projected
//               Hamiltonian and overlap of the Davidson subspace -- confirm).
//   nbase     - offset, in units of nbase_x elements, of the first element reduced.
//   nbase_x   - number of elements per stripe; it scales both the offset and the length.
//   notconv   - number of nbase_x-element stripes reduced.
//   diag_comm - communicator. MPI_Reduce is collective, so every rank of diag_comm
//               must call this function.
//
// NOTE(review): only complex MPI datatypes are used. MPI_COMPLEX is chosen when
// get_current_precision reports "single", and MPI_DOUBLE_COMPLEX otherwise. T is
// therefore assumed to be std::complex<float> or std::complex<double>. A real T would
// presumably make MPI read twice as many bytes as each buffer holds.
// NOTE(review): MPI_Reduce fills `target` only on rank 0, yet `target` is copied back
// into hcc/scc on all ranks. On the other ranks those blocks are overwritten with the
// initial contents of the `new T[...]` buffer, which are zeros for std::complex.
// Confirm that non-root ranks never read these blocks afterwards; otherwise use
// MPI_Allreduce, or copy back only on rank 0.
// NOTE(review): T is known at compile time, so the runtime string comparison could be
// an `if constexpr` on the type. The raw new[]/delete[] pair could also be replaced by
// std::vector, so the buffers are released on any early exit.
278template <typename T>
279void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm)
280{
281
// Copy helper used for every buffer transfer below, called as (dst, src, count).
// NOTE(review): the right-hand side of this alias (listing line 283) is missing from
// this extract, so the exact memory-op type cannot be read here -- confirm against the
// real header.
282 using syncmem_complex_op
284
// Staging buffer for the local contribution, and receive buffer for the reduction.
285 auto* swap = new T[notconv * nbase_x];
286 auto* target = new T[notconv * nbase_x];
// --- hcc block ---
287 syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x);
288 if (base_device::get_current_precision(swap) == "single")
289 {
290 MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm);
291 }
292 else
293 {
294 MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm);
295 }
296
// Unconditional copy-back: on non-root ranks `target` was not written by MPI_Reduce.
297 syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x);
// --- scc block (same steps, reusing both buffers) ---
298 syncmem_complex_op()(swap, scc + nbase * nbase_x, notconv * nbase_x);
299
300 if (base_device::get_current_precision(swap) == "single")
301 {
302 MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm);
303 }
304 else
305 {
306 MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm);
307 }
308
309 syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x);
// Release the temporaries.
310 delete[] swap;
311 delete[] target;
312}
313} // namespace mtfunc
314
315#endif
316#endif
#define T
Definition exp.cpp:237
std::string get_current_precision(const T *var)
Get the precision string for a given numeric type.
Definition dsp_connector.cpp:14
void cgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:158
void zgemm_pack_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:406
void sgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:304
void sgemv_mt_(const char *transa, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:190
void dgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:336
void dgemv_mth_(const char *transa, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:592
void dgemv_mt_(const char *transa, const int *m, const int *n, const double *alpha, const double *a, const int *lda, const double *x, const int *incx, const double *beta, double *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:218
void zgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:126
void cgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:466
void sgemv_mth_(const char *transa, const int *m, const int *n, const float *alpha, const float *a, const int *lda, const float *x, const int *incx, const float *beta, float *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:564
void dspDestoryHandle(int id)
Definition dsp_connector.cpp:21
void dspInitHandle(int id)
Definition dsp_connector.cpp:15
void zgemv_mt_(const char *transa, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:246
void sgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:62
void free_ht(void *ptr)
Definition dsp_connector.cpp:56
void cgemv_mt_(const char *transa, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:274
void zgemv_mth_(const char *transa, const int *m, const int *n, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *x, const int *incx, const std::complex< double > *beta, std::complex< double > *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:620
void cgemm_pack_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:506
void dgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:94
void * malloc_ht(size_t bytes, int cluster_id)
Definition dsp_connector.cpp:48
void cgemv_mth_(const char *transa, const int *m, const int *n, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *x, const int *incx, const std::complex< float > *beta, std::complex< float > *y, const int *incy, int cluster_id)
Definition dsp_connector.cpp:656
void zgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:368