Blaze  3.6
gemm.h
Go to the documentation of this file.
1 //=================================================================================================
33 //=================================================================================================
34 
35 #ifndef _BLAZE_MATH_BLAS_GEMM_H_
36 #define _BLAZE_MATH_BLAS_GEMM_H_
37 
38 
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42 
43 #include <blaze/math/Aliases.h>
50 #include <blaze/system/BLAS.h>
51 #include <blaze/system/Inline.h>
52 #include <blaze/util/Assert.h>
53 #include <blaze/util/Complex.h>
54 #include <blaze/util/NumericCast.h>
56 
57 
58 namespace blaze {
59 
60 //=================================================================================================
61 //
62 // BLAS WRAPPER FUNCTIONS (GEMM)
63 //
64 //=================================================================================================
65 
66 //*************************************************************************************************
69 #if BLAZE_BLAS_MODE
70 
71 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
72  int m, int n, int k, float alpha, const float* A, int lda,
73  const float* B, int ldb, float beta, float* C, int ldc );
74 
75 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
76  int m, int n, int k, double alpha, const double* A, int lda,
77  const double* B, int ldb, double beta, float* C, int ldc );
78 
79 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
80  int m, int n, int k, complex<float> alpha, const complex<float>* A,
81  int lda, const complex<float>* B, int ldb, complex<float> beta,
82  float* C, int ldc );
83 
84 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
85  int m, int n, int k, complex<double> alpha, const complex<double>* A,
86  int lda, const complex<double>* B, int ldb, complex<double> beta,
87  float* C, int ldc );
88 
89 template< typename MT1, bool SO1, typename MT2, bool SO2, typename MT3, bool SO3, typename ST >
90 BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
91  const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta );
92 
93 #endif
94 
95 //*************************************************************************************************
96 
97 
98 //*************************************************************************************************
99 #if BLAZE_BLAS_MODE
100 
123 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
124  int m, int n, int k, float alpha, const float* A, int lda,
125  const float* B, int ldb, float beta, float* C, int ldc )
126 {
127  cblas_sgemm( order, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc );
128 }
129 #endif
130 //*************************************************************************************************
131 
132 
133 //*************************************************************************************************
134 #if BLAZE_BLAS_MODE
135 
158 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
159  int m, int n, int k, double alpha, const double* A, int lda,
160  const double* B, int ldb, double beta, double* C, int ldc )
161 {
162  cblas_dgemm( order, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc );
163 }
164 #endif
165 //*************************************************************************************************
166 
167 
168 //*************************************************************************************************
169 #if BLAZE_BLAS_MODE
170 
193 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
194  int m, int n, int k, complex<float> alpha, const complex<float>* A,
195  int lda, const complex<float>* B, int ldb, complex<float> beta,
196  complex<float>* C, int ldc )
197 {
198  BLAZE_STATIC_ASSERT( sizeof( complex<float> ) == 2UL*sizeof( float ) );
199 
200  cblas_cgemm( order, transA, transB, m, n, k, reinterpret_cast<const float*>( &alpha ),
201  reinterpret_cast<const float*>( A ), lda, reinterpret_cast<const float*>( B ),
202  ldb, reinterpret_cast<const float*>( &beta ), reinterpret_cast<float*>( C ), ldc );
203 }
204 #endif
205 //*************************************************************************************************
206 
207 
208 //*************************************************************************************************
209 #if BLAZE_BLAS_MODE
210 
233 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
234  int m, int n, int k, complex<double> alpha, const complex<double>* A,
235  int lda, const complex<double>* B, int ldb, complex<double> beta,
236  complex<double>* C, int ldc )
237 {
238  BLAZE_STATIC_ASSERT( sizeof( complex<double> ) == 2UL*sizeof( double ) );
239 
240  cblas_zgemm( order, transA, transB, m, n, k, reinterpret_cast<const double*>( &alpha ),
241  reinterpret_cast<const double*>( A ), lda, reinterpret_cast<const double*>( B ),
242  ldb, reinterpret_cast<const double*>( &beta ), reinterpret_cast<double*>( C ), ldc );
243 }
244 #endif
245 //*************************************************************************************************
246 
247 
248 //*************************************************************************************************
249 #if BLAZE_BLAS_MODE
250 
265 template< typename MT1 // Type of the left-hand side target matrix
266  , bool SO1 // Storage order of the left-hand side target matrix
267  , typename MT2 // Type of the left-hand side matrix operand
268  , bool SO2 // Storage order of the left-hand side matrix operand
269  , typename MT3 // Type of the right-hand side matrix operand
270  , bool SO3 // Storage order of the right-hand side matrix operand
271  , typename ST > // Type of the scalar factors
272 BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
273  const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta )
274 {
278 
282 
283  BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE( ElementType_t<MT1> );
284  BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE( ElementType_t<MT2> );
285  BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE( ElementType_t<MT3> );
286 
287  const int m ( numeric_cast<int>( (~A).rows() ) );
288  const int n ( numeric_cast<int>( (~B).columns() ) );
289  const int k ( numeric_cast<int>( (~A).columns() ) );
290  const int lda( numeric_cast<int>( (~A).spacing() ) );
291  const int ldb( numeric_cast<int>( (~B).spacing() ) );
292  const int ldc( numeric_cast<int>( (~C).spacing() ) );
293 
294  gemm( ( IsRowMajorMatrix_v<MT1> )?( CblasRowMajor ):( CblasColMajor ),
295  ( SO1 == SO2 )?( CblasNoTrans ):( CblasTrans ),
296  ( SO1 == SO3 )?( CblasNoTrans ):( CblasTrans ),
297  m, n, k, alpha, (~A).data(), lda, (~B).data(), ldb, beta, (~C).data(), ldc );
298 }
299 #endif
300 //*************************************************************************************************
301 
302 } // namespace blaze
303 
304 #endif
Constraint on the data type.
Header file for auxiliary alias declarations.
#define BLAZE_CONSTRAINT_MUST_HAVE_MUTABLE_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to m...
Definition: MutableDataAccess.h:61
#define BLAZE_CONSTRAINT_MUST_HAVE_CONST_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to c...
Definition: ConstDataAccess.h:61
MT::ElementType * data(DenseMatrix< MT, SO > &dm) noexcept
Low-level data access to the dense matrix elements.
Definition: DenseMatrix.h:170
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i....
Definition: Computation.h:81
Cast operators for numeric types.
constexpr size_t columns(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of columns of the matrix.
Definition: Matrix.h:514
size_t spacing(const DenseMatrix< MT, SO > &dm) noexcept
Returns the spacing between the beginning of two rows/columns.
Definition: DenseMatrix.h:253
Constraint on the data type.
Namespace of the Blaze C++ math library.
Definition: Blaze.h:58
#define BLAZE_ALWAYS_INLINE
Platform dependent setup of an enforced inline keyword.
Definition: Inline.h:85
Compile time assertion.
Header file for the DenseMatrix base class.
Constraint on the data type.
System settings for the BLAS mode.
Header file for run time assertion macros.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE(T)
Constraint on the data type. In case the given data type T is not a BLAS compatible data type (i....
Definition: BLASCompatible.h:61
constexpr size_t rows(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of rows of the matrix.
Definition: Matrix.h:498
Header file for the IsRowMajorMatrix type trait.
Header file for the complex data type.
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro. In case of an invalid compile time expression, a compilation error is cr...
Definition: StaticAssert.h:112
System settings for the inline keywords.