ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
dsp_connector.h
Go to the documentation of this file.
1#ifndef DSP_CONNECTOR_H
2#define DSP_CONNECTOR_H
3#ifdef __DSP
4
8
9namespace mtfunc
10{
// Base DSP functions (defined in dsp_connector.cpp).

/// Initialise the DSP handle identified by `id`.
/// NOTE(review): presumably `id` is the same cluster index passed as `cluster_id`
/// to the mtblas functions — confirm in dsp_connector.cpp.
void dspInitHandle(int id);

/// Destroy the DSP handle identified by `id` (presumably the one set up by dspInitHandle).
/// The spelling "Destory" is part of the public symbol name and is kept for compatibility.
void dspDestoryHandle(int id);

/// Allocate `bytes` bytes associated with DSP cluster `cluster_id`.
/// NOTE(review): presumably released with free_ht(); the kind of memory returned
/// is not visible from this header — confirm in dsp_connector.cpp.
void* malloc_ht(size_t bytes, int cluster_id);

/// Release a pointer (presumably one obtained from malloc_ht()).
void free_ht(void* ptr);

// mtblas functions
//
// GEMM entry points that take the same pointer-based argument list as the Fortran
// BLAS ?gemm_ routines (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)
// plus a trailing DSP `cluster_id`. Definitions are in dsp_connector.cpp.

/// Single-precision real GEMM with the BLAS sgemm_ argument list, run on DSP cluster
/// `cluster_id`. Presumably computes C := alpha*op(A)*op(B) + beta*C like SGEMM —
/// confirm against dsp_connector.cpp.
void sgemm_mt_(const char* transa,
               const char* transb,
               const int* m,
               const int* n,
               const int* k,
               const float* alpha,
               const float* a,
               const int* lda,
               const float* b,
               const int* ldb,
               const float* beta,
               float* c,
               const int* ldc,
               int cluster_id);

/// Double-precision real GEMM with the Fortran BLAS dgemm_ argument list
/// (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc), run on DSP
/// cluster `cluster_id`. Presumably computes C := alpha*op(A)*op(B) + beta*C like
/// DGEMM — confirm against dsp_connector.cpp.
void dgemm_mt_(const char* transa,
               const char* transb,
               const int* m,
               const int* n,
               const int* k,
               const double* alpha,
               const double* a,
               const int* lda,
               const double* b,
               const int* ldb,
               const double* beta,
               double* c,
               const int* ldc,
               int cluster_id);

/// Double-precision complex GEMM with the Fortran BLAS zgemm_ argument list
/// (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc), run on DSP
/// cluster `cluster_id`. Presumably computes C := alpha*op(A)*op(B) + beta*C like
/// ZGEMM — confirm against dsp_connector.cpp.
void zgemm_mt_(const char* transa,
               const char* transb,
               const int* m,
               const int* n,
               const int* k,
               const std::complex<double>* alpha,
               const std::complex<double>* a,
               const int* lda,
               const std::complex<double>* b,
               const int* ldb,
               const std::complex<double>* beta,
               std::complex<double>* c,
               const int* ldc,
               int cluster_id);

/// Single-precision complex GEMM with the Fortran BLAS cgemm_ argument list
/// (transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc), run on DSP
/// cluster `cluster_id`. Presumably computes C := alpha*op(A)*op(B) + beta*C like
/// CGEMM — confirm against dsp_connector.cpp.
void cgemm_mt_(const char* transa,
               const char* transb,
               const int* m,
               const int* n,
               const int* k,
               const std::complex<float>* alpha,
               const std::complex<float>* a,
               const int* lda,
               const std::complex<float>* b,
               const int* ldb,
               const std::complex<float>* beta,
               std::complex<float>* c,
               const int* ldc,
               int cluster_id);

/// Single-precision real GEMM, "mth" variant: same argument list as BLAS sgemm_
/// plus the DSP `cluster_id`. How it differs from sgemm_mt_ is not visible in this
/// header — TODO(review): document (confirm in dsp_connector.cpp).
void sgemm_mth_(const char* transa,
                const char* transb,
                const int* m,
                const int* n,
                const int* k,
                const float* alpha,
                const float* a,
                const int* lda,
                const float* b,
                const int* ldb,
                const float* beta,
                float* c,
                const int* ldc,
                int cluster_id);

/// Double-precision real GEMM, "mth" variant: same argument list as BLAS dgemm_
/// plus the DSP `cluster_id`. How it differs from dgemm_mt_ is not visible in this
/// header — TODO(review): document (confirm in dsp_connector.cpp).
void dgemm_mth_(const char* transa,
                const char* transb,
                const int* m,
                const int* n,
                const int* k,
                const double* alpha,
                const double* a,
                const int* lda,
                const double* b,
                const int* ldb,
                const double* beta,
                double* c,
                const int* ldc,
                int cluster_id);

/// Double-precision complex GEMM, "mth" variant: same argument list as BLAS zgemm_
/// plus the DSP `cluster_id`. How it differs from zgemm_mt_ is not visible in this
/// header — TODO(review): document (confirm in dsp_connector.cpp).
void zgemm_mth_(const char* transa,
                const char* transb,
                const int* m,
                const int* n,
                const int* k,
                const std::complex<double>* alpha,
                const std::complex<double>* a,
                const int* lda,
                const std::complex<double>* b,
                const int* ldb,
                const std::complex<double>* beta,
                std::complex<double>* c,
                const int* ldc,
                int cluster_id);

/// Single-precision complex GEMM, "mth" variant: same argument list as BLAS cgemm_
/// plus the DSP `cluster_id`. How it differs from cgemm_mt_ is not visible in this
/// header — TODO(review): document (confirm in dsp_connector.cpp).
void cgemm_mth_(const char* transa,
                const char* transb,
                const int* m,
                const int* n,
                const int* k,
                const std::complex<float>* alpha,
                const std::complex<float>* a,
                const int* lda,
                const std::complex<float>* b,
                const int* ldb,
                const std::complex<float>* beta,
                std::complex<float>* c,
                const int* ldc,
                int cluster_id);

// #define zgemm_ zgemm_mt

// The following are DSP utilities. They may be moved to another file if this one grows too large.

143template <typename T>
144void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv, MPI_Comm diag_comm)
145{
146
147 using syncmem_complex_op
149
150 auto* swap = new T[notconv * nbase_x];
151 auto* target = new T[notconv * nbase_x];
152 syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x);
153 if (base_device::get_current_precision(swap) == "single")
154 {
155 MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm);
156 }
157 else
158 {
159 MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm);
160 }
161
162 syncmem_complex_op()(hcc + nbase * nbase_x, target, notconv * nbase_x);
163 syncmem_complex_op()(swap, scc + nbase * nbase_x, notconv * nbase_x);
164
165 if (base_device::get_current_precision(swap) == "single")
166 {
167 MPI_Reduce(swap, target, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm);
168 }
169 else
170 {
171 MPI_Reduce(swap, target, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm);
172 }
173
174 syncmem_complex_op()(scc + nbase * nbase_x, target, notconv * nbase_x);
175 delete[] swap;
176 delete[] target;
177}
178} // namespace mtfunc
179
180#endif
181#endif
#define T
Definition exp.cpp:237
std::string get_current_precision(const float *var)
Definition device.cpp:36
Definition dsp_connector.cpp:14
void cgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:161
void sgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:195
void dgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:227
void zgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:129
void cgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< float > *alpha, const std::complex< float > *a, const int *lda, const std::complex< float > *b, const int *ldb, const std::complex< float > *beta, std::complex< float > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:296
void dspDestoryHandle(int id)
Definition dsp_connector.cpp:21
void dspInitHandle(int id)
Definition dsp_connector.cpp:15
void sgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, float *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:65
void free_ht(void *ptr)
Definition dsp_connector.cpp:56
void dgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha, const double *a, const int *lda, const double *b, const int *ldb, const double *beta, double *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:97
void * malloc_ht(size_t bytes, int cluster_id)
Definition dsp_connector.cpp:46
void zgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex< double > *alpha, const std::complex< double > *a, const int *lda, const std::complex< double > *b, const int *ldb, const std::complex< double > *beta, std::complex< double > *c, const int *ldc, int cluster_id)
Definition dsp_connector.cpp:259