ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
lapack.h
Go to the documentation of this file.
1#ifndef ATEN_KERNELS_LAPACK_H_
2#define ATEN_KERNELS_LAPACK_H_
3
5#include <ATen/core/tensor.h>
7
9
10namespace container {
11namespace kernels {
12
13
14template <typename T, typename Device>
15struct set_matrix {
17 const char& uplo,
18 T* A,
19 const int& dim);
20};
21
22
23// --- 1. Matrix Decomposition ---
24template <typename T, typename Device>
27 const char& uplo,
28 const char& diag,
29 const int& dim,
30 T* Mat,
31 const int& lda);
32};
33
34
35template <typename T, typename Device>
38 const char& uplo,
39 const int& dim,
40 T* Mat,
41 const int& lda);
42};
43
44template <typename T, typename Device>
47 const int& m,
48 const int& n,
49 T* Mat,
50 const int& lda,
51 int* ipiv);
52};
53
54
55template <typename T, typename Device>
58 const int& n,
59 T* Mat,
60 const int& lda,
61 const int* ipiv,
62 T* work,
63 const int& lwork);
64};
65
66// In-place QR factorization: on exit, the input matrix A is overwritten
67// by the orthogonal/unitary factor Q.
68template <typename T, typename Device>
85 const int m,
86 const int n,
87 T *A,
88 const int lda);
89};
90
91// This is QR factorization
92// where [in]Mat will be kept and the results are stored in separate matrix Q
93// template <typename T, typename Device>
94// struct lapack_geqrf{
95// /**
96// * Perform QR factorization of a matrix using LAPACK's geqrf function.
97// *
98// * @param m The number of rows in the matrix.
99// * @param n The number of columns in the matrix.
100// * @param Mat The matrix to be factorized.
101// * On exit, the upper triangle contains the upper triangular matrix R,
102// * and the elements below the diagonal, with the array TAU, represent
103// * the unitary matrix Q as a product of min(m,n) elementary reflectors.
104// * @param lda The leading dimension of the matrix.
105// * @param tau Array of size min(m,n); on exit, holds the scalar factors of the elementary (Householder) reflectors.
106// */
107// void operator()(
108// const int m,
109// const int n,
110// T *Mat,
111// const int lda,
112// T *tau);
113// };
114
115
116// --- 2. Linear System Solvers ---
117template <typename T, typename Device>
120 const char& trans,
121 const int& n,
122 const int& nrhs,
123 T* A,
124 const int& lda,
125 const int* ipiv,
126 T* B,
127 const int& ldb);
128};
129
130
131
132// --- 3. Standard & Generalized Eigenvalue ---
133
134// ============================================================================
135// Standard Hermitian Eigenvalue Problem Solvers
136// ============================================================================
137// The following structures (lapack_heevd and lapack_heevx) implement solvers
138// for standard Hermitian eigenvalue problems of the form:
139// A * x = lambda * x
140// where:
141// - A is a Hermitian matrix
142// - lambda are the eigenvalues to be computed
143// - x are the corresponding eigenvectors
144//
145// ============================================================================
146template <typename T, typename Device>
148 // !> ZHEEVD computes all eigenvalues and, optionally, eigenvectors of a
149 // !> complex Hermitian matrix A. If eigenvectors are desired, it uses a
150 // !> divide and conquer algorithm.
151 // !> On exit, if JOBZ = 'V', then if INFO = 0, A contains the
152 // !> orthonormal eigenvectors of the matrix A.
174 using Real = typename GetTypeReal<T>::type;
176 const int dim,
177 T* Mat,
178 const int lda,
179 Real* eigen_val);
180};
181
182template <typename T, typename Device>
184 using Real = typename GetTypeReal<T>::type;
207 const int dim,
208 const int lda,
209 const T *Mat,
210 const int neig,
211 Real *eigen_val,
212 T *eigen_vec);
213};
214
215
216// ============================================================================
217// Generalized Hermitian-definite Eigenvalue Problem Solvers
218// ============================================================================
219// The following structures (lapack_hegvd and lapack_hegvx) implement solvers
220// for generalized Hermitian-definite eigenvalue problems of the form:
221// A * x = lambda * B * x
222// where:
223// - A is a Hermitian matrix
224// - B is a Hermitian positive definite matrix
225// - lambda are the eigenvalues to be computed
226// - x are the corresponding eigenvectors
227//
228// ============================================================================
229
230template <typename T, typename Device>
232 using Real = typename GetTypeReal<T>::type;
251 const int n,
252 const int lda,
253 T *Mat_A,
254 T *Mat_B,
255 Real *eigen_val,
256 T *eigen_vec);
257};
258
259template <typename T, typename Device>
261 using Real = typename GetTypeReal<T>::type;
284 const int n,
285 const int lda,
286 T *Mat_A,
287 T *Mat_B,
288 const int m,
289 Real *eigen_val,
290 T *eigen_vec);
291};
292
293
294#if defined(__CUDA) || defined(__ROCM)
295// TODO: Use C++ singleton to manage the GPU handles
296void createGpuSolverHandle(); // create the GPU solver handle (cuSOLVER on CUDA builds)
297void destroyGpuSolverHandle(); // destroy the GPU solver handle
298#endif
299
300} // namespace kernels
301} // namespace container
302
303#endif // ATEN_KERNELS_LAPACK_H_
This is a direct wrapper of some LAPACK routines. Column-Major version. Direct wrapping of standard L...
#define T
Definition exp.cpp:237
Definition tensor.cpp:8
T type
Definition tensor_types.h:89
void operator()(const int m, const int n, T *A, const int lda)
Perform in-place QR factorization of a matrix using LAPACK's geqrf function.
Definition lapack.h:45
void operator()(const int &m, const int &n, T *Mat, const int &lda, int *ipiv)
Definition lapack.h:56
void operator()(const int &n, T *Mat, const int &lda, const int *ipiv, T *work, const int &lwork)
Definition lapack.h:118
void operator()(const char &trans, const int &n, const int &nrhs, T *A, const int &lda, const int *ipiv, T *B, const int &ldb)
Definition lapack.h:147
void operator()(const int dim, T *Mat, const int lda, Real *eigen_val)
typename GetTypeReal< T >::type Real
Computes all eigenvalues and, optionally, eigenvectors of a complex Hermitian matrix.
Definition lapack.h:174
Definition lapack.h:183
typename GetTypeReal< T >::type Real
Definition lapack.h:184
void operator()(const int dim, const int lda, const T *Mat, const int neig, Real *eigen_val, T *eigen_vec)
Computes selected eigenvalues and, optionally, eigenvectors of a complex Hermitian matrix.
Definition lapack.h:231
void operator()(const int n, const int lda, T *Mat_A, T *Mat_B, Real *eigen_val, T *eigen_vec)
Computes all the eigenvalues and, optionally, the eigenvectors of a complex generalized Hermitian-def...
typename GetTypeReal< T >::type Real
Definition lapack.h:232
Definition lapack.h:260
void operator()(const int n, const int lda, T *Mat_A, T *Mat_B, const int m, Real *eigen_val, T *eigen_vec)
typename GetTypeReal< T >::type Real
Definition lapack.h:261
Definition lapack.h:36
void operator()(const char &uplo, const int &dim, T *Mat, const int &lda)
Definition lapack.h:25
void operator()(const char &uplo, const char &diag, const int &dim, T *Mat, const int &lda)
Definition lapack.h:15
void operator()(const char &uplo, T *A, const int &dim)
This file contains the definition of the DataType enum class.