ABACUS develop
Atomic-orbital Based Ab-initio Computation at UStc
Loading...
Searching...
No Matches
dgemm_vbatch.h
Go to the documentation of this file.
1#pragma once
2
3#include <cuda_runtime.h>
4
// Template version: C(batch_id) = alpha * A(batch_id) * B(batch_id) + C(batch_id)
// As with gemm_tn_vbatch, the C accumulator is always double regardless of the
// input type T so the per-block reduction and device-side atomicAdd run in fp64
// (double-precision atomicAdd requires compute capability 6.0+).
//
// Parameters. NOTE(review): the shape/layout notes below assume the usual BLAS
// m/n/k/ld conventions (A is m x k, B is k x n, C is m x n) -- confirm against
// the .cu implementation before relying on them.
//   max_m, max_n, max_k  presumably upper bounds of m_d/n_d/k_d over the batch,
//                        used to size the launch -- TODO confirm
//   m_d, n_d, k_d        per-batch dimensions, presumably batchCount entries each
//                        (the _d suffix suggests device memory)
//   A_array_d, lda_d     per-batch pointers to A and its leading dimension
//   B_array_d, ldb_d     per-batch pointers to B and its leading dimension
//   C_array_d, ldc_d     per-batch pointers to the double accumulator C and its
//                        leading dimension; the product is ADDED to the existing
//                        contents of C (see the formula above)
//   batchCount           number of batch entries
//   stream               CUDA stream for the work (presumably asynchronous
//                        with respect to the host -- confirm)
//   alpha                scaling factor; defaults to nullptr, which presumably
//                        means alpha = 1 -- TODO confirm in the implementation
template<typename T>
void gemm_nn_vbatch(
    int max_m, int max_n, int max_k,
    const int* m_d, const int* n_d, const int* k_d,
    const T* const* A_array_d, const int* lda_d,
    const T* const* B_array_d, const int* ldb_d,
    double** C_array_d, const int* ldc_d,
    int batchCount, cudaStream_t stream,
    const T* alpha = nullptr);
17
// Template version: C(batch_id) = alpha * A(batch_id)^T * B(batch_id) + C(batch_id)
// The C accumulator is always double regardless of input type T: a fp32 GEMM
// path (T=float) feeds fp32 multiplies into fp64 accumulators (registers and
// device-side atomicAdds) to avoid catastrophic precision loss across many
// atom-pair contributions to the same hr_gint element. Double-precision
// atomicAdd requires compute capability 6.0+.
//
// Parameters. NOTE(review): because A is transposed, A is presumably stored as
// k x m (BLAS "TN" convention) with B k x n and C m x n -- confirm against the
// .cu implementation before relying on these shapes.
//   max_m, max_n, max_k  presumably upper bounds of m_d/n_d/k_d over the batch,
//                        used to size the launch -- TODO confirm
//   m_d, n_d, k_d        per-batch dimensions, presumably batchCount entries each
//                        (the _d suffix suggests device memory)
//   A_array_d, lda_d     per-batch pointers to A (used transposed) and its
//                        leading dimension
//   B_array_d, ldb_d     per-batch pointers to B and its leading dimension
//   C_array_d, ldc_d     per-batch pointers to the double accumulator C and its
//                        leading dimension; the product is ADDED to the existing
//                        contents of C (see the formula above)
//   batchCount           number of batch entries
//   stream               CUDA stream for the work (presumably asynchronous
//                        with respect to the host -- confirm)
//   alpha                scaling factor; defaults to nullptr, which presumably
//                        means alpha = 1 -- TODO confirm in the implementation
template<typename T>
void gemm_tn_vbatch(
    int max_m, int max_n, int max_k,
    const int* m_d, const int* n_d, const int* k_d,
    const T* const* A_array_d, const int* lda_d,
    const T* const* B_array_d, const int* ldb_d,
    double** C_array_d, const int* ldc_d,
    int batchCount, cudaStream_t stream,
    const T* alpha = nullptr);
32
// ---------------------------------------------------------------------------
// Legacy double-only entry points, kept so that callers written before the
// templated API keep compiling unchanged.
// ---------------------------------------------------------------------------

// dgemm_nn_vbatch: forwards every argument, untouched and in order, to the
// T = double instantiation of gemm_nn_vbatch. With T = double the inputs
// (A, B, alpha) and the accumulator (C) all share one precision, so no
// conversion happens here.
inline void dgemm_nn_vbatch(int max_m, int max_n, int max_k,
                            const int* m_d, const int* n_d, const int* k_d,
                            const double* const* A_array_d, const int* lda_d,
                            const double* const* B_array_d, const int* ldb_d,
                            double** C_array_d, const int* ldc_d,
                            int batchCount, cudaStream_t stream,
                            const double* alpha = nullptr)
{
    // The explicit <double> pins the instantiation rather than relying on deduction.
    gemm_nn_vbatch<double>(
        max_m, max_n, max_k,
        m_d, n_d, k_d,
        A_array_d, lda_d,
        B_array_d, ldb_d,
        C_array_d, ldc_d,
        batchCount, stream,
        alpha);
}
47
// dgemm_tn_vbatch: legacy double-only alias of gemm_tn_vbatch
// (C += alpha * A^T * B per batch entry). It forwards every argument, untouched
// and in order, to the T = double instantiation.
inline void dgemm_tn_vbatch(int max_m, int max_n, int max_k,
                            const int* m_d, const int* n_d, const int* k_d,
                            const double* const* A_array_d, const int* lda_d,
                            const double* const* B_array_d, const int* ldb_d,
                            double** C_array_d, const int* ldc_d,
                            int batchCount, cudaStream_t stream,
                            const double* alpha = nullptr)
{
    // The templated API always accumulates into double** C. For T = double that
    // coincides exactly with this legacy signature, so the arguments pass
    // straight through with no conversion.
    gemm_tn_vbatch<double>(
        max_m, max_n, max_k,
        m_d, n_d, k_d,
        A_array_d, lda_d,
        B_array_d, ldb_d,
        C_array_d, ldc_d,
        batchCount, stream,
        alpha);
}
void dgemm_tn_vbatch(int max_m, int max_n, int max_k, const int *m_d, const int *n_d, const int *k_d, const double *const *A_array_d, const int *lda_d, const double *const *B_array_d, const int *ldb_d, double **C_array_d, const int *ldc_d, int batchCount, cudaStream_t stream, const double *alpha=nullptr)
Definition dgemm_vbatch.h:48
void gemm_tn_vbatch(int max_m, int max_n, int max_k, const int *m_d, const int *n_d, const int *k_d, const T *const *A_array_d, const int *lda_d, const T *const *B_array_d, const int *ldb_d, double **C_array_d, const int *ldc_d, int batchCount, cudaStream_t stream, const T *alpha=nullptr)
void dgemm_nn_vbatch(int max_m, int max_n, int max_k, const int *m_d, const int *n_d, const int *k_d, const double *const *A_array_d, const int *lda_d, const double *const *B_array_d, const int *ldb_d, double **C_array_d, const int *ldc_d, int batchCount, cudaStream_t stream, const double *alpha=nullptr)
Definition dgemm_vbatch.h:34
void gemm_nn_vbatch(int max_m, int max_n, int max_k, const int *m_d, const int *n_d, const int *k_d, const T *const *A_array_d, const int *lda_d, const T *const *B_array_d, const int *ldb_d, double **C_array_d, const int *ldc_d, int batchCount, cudaStream_t stream, const T *alpha=nullptr)
#define T
Definition exp.cpp:237