gemm.h
Go to the documentation of this file.
1 //=================================================================================================
33 //=================================================================================================
34 
35 #ifndef _BLAZE_MATH_BLAS_GEMM_H_
36 #define _BLAZE_MATH_BLAS_GEMM_H_
37 
38 
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42 
43 #include <boost/cast.hpp>
44 #include <blaze/math/Aliases.h>
52 #include <blaze/system/BLAS.h>
53 #include <blaze/system/Inline.h>
54 #include <blaze/util/Assert.h>
55 #include <blaze/util/Complex.h>
57 
58 
59 namespace blaze {
60 
61 //=================================================================================================
62 //
63 // BLAS WRAPPER FUNCTIONS (GEMM)
64 //
65 //=================================================================================================
66 
67 //*************************************************************************************************
70 #if BLAZE_BLAS_MODE
71 
72 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
73  int m, int n, int k, float alpha, const float* A, int lda,
74  const float* B, int ldb, float beta, float* C, int ldc );
75 
76 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
77  int m, int n, int k, double alpha, const double* A, int lda,
78  const double* B, int ldb, double beta, float* C, int ldc );
79 
80 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
81  int m, int n, int k, complex<float> alpha, const complex<float>* A,
82  int lda, const complex<float>* B, int ldb, complex<float> beta,
83  float* C, int ldc );
84 
85 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
86  int m, int n, int k, complex<double> alpha, const complex<double>* A,
87  int lda, const complex<double>* B, int ldb, complex<double> beta,
88  float* C, int ldc );
89 
90 template< typename MT1, bool SO1, typename MT2, bool SO2, typename MT3, bool SO3, typename ST >
91 BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
92  const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta );
93 
94 #endif
95 
96 //*************************************************************************************************
97 
98 
99 //*************************************************************************************************
100 #if BLAZE_BLAS_MODE
101 
124 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
125  int m, int n, int k, float alpha, const float* A, int lda,
126  const float* B, int ldb, float beta, float* C, int ldc )
127 {
128  cblas_sgemm( order, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc );
129 }
130 #endif
131 //*************************************************************************************************
132 
133 
134 //*************************************************************************************************
135 #if BLAZE_BLAS_MODE
136 
159 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
160  int m, int n, int k, double alpha, const double* A, int lda,
161  const double* B, int ldb, double beta, double* C, int ldc )
162 {
163  cblas_dgemm( order, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc );
164 }
165 #endif
166 //*************************************************************************************************
167 
168 
169 //*************************************************************************************************
170 #if BLAZE_BLAS_MODE
171 
194 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
195  int m, int n, int k, complex<float> alpha, const complex<float>* A,
196  int lda, const complex<float>* B, int ldb, complex<float> beta,
197  complex<float>* C, int ldc )
198 {
199  BLAZE_STATIC_ASSERT( sizeof( complex<float> ) == 2UL*sizeof( float ) );
200 
201  cblas_cgemm( order, transA, transB, m, n, k, reinterpret_cast<const float*>( &alpha ),
202  reinterpret_cast<const float*>( A ), lda, reinterpret_cast<const float*>( B ),
203  ldb, reinterpret_cast<const float*>( &beta ), reinterpret_cast<float*>( C ), ldc );
204 }
205 #endif
206 //*************************************************************************************************
207 
208 
209 //*************************************************************************************************
210 #if BLAZE_BLAS_MODE
211 
234 BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
235  int m, int n, int k, complex<double> alpha, const complex<double>* A,
236  int lda, const complex<double>* B, int ldb, complex<double> beta,
237  complex<double>* C, int ldc )
238 {
239  BLAZE_STATIC_ASSERT( sizeof( complex<double> ) == 2UL*sizeof( double ) );
240 
241  cblas_zgemm( order, transA, transB, m, n, k, reinterpret_cast<const double*>( &alpha ),
242  reinterpret_cast<const double*>( A ), lda, reinterpret_cast<const double*>( B ),
243  ldb, reinterpret_cast<const double*>( &beta ), reinterpret_cast<double*>( C ), ldc );
244 }
245 #endif
246 //*************************************************************************************************
247 
248 
249 //*************************************************************************************************
250 #if BLAZE_BLAS_MODE
251 
266 template< typename MT1 // Type of the left-hand side target matrix
267  , bool SO1 // Storage order of the left-hand side target matrix
268  , typename MT2 // Type of the left-hand side matrix operand
269  , bool SO2 // Storage order of the left-hand side matrix operand
270  , typename MT3 // Type of the right-hand side matrix operand
271  , bool SO3 // Storage order of the right-hand side matrix operand
272  , typename ST > // Type of the scalar factors
273 BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
274  const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta )
275 {
276  using boost::numeric_cast;
277 
281 
285 
289 
290  const int m ( numeric_cast<int>( (~A).rows() ) );
291  const int n ( numeric_cast<int>( (~B).columns() ) );
292  const int k ( numeric_cast<int>( (~A).columns() ) );
293  const int lda( numeric_cast<int>( (~A).spacing() ) );
294  const int ldb( numeric_cast<int>( (~B).spacing() ) );
295  const int ldc( numeric_cast<int>( (~C).spacing() ) );
296 
297  gemm( ( IsRowMajorMatrix<MT1>::value )?( CblasRowMajor ):( CblasColMajor ),
298  ( SO1 == SO2 )?( CblasNoTrans ):( CblasTrans ),
299  ( SO1 == SO3 )?( CblasNoTrans ):( CblasTrans ),
300  m, n, k, alpha, (~A).data(), lda, (~B).data(), ldb, beta, (~C).data(), ldc );
301 }
302 #endif
303 //*************************************************************************************************
304 
305 } // namespace blaze
306 
307 #endif
Constraint on the data type.
BLAZE_ALWAYS_INLINE size_t spacing(const DenseMatrix< MT, SO > &dm) noexcept
Returns the spacing between the beginning of two rows/columns.
Definition: DenseMatrix.h:102
Header file for auxiliary alias declarations.
#define BLAZE_CONSTRAINT_MUST_HAVE_MUTABLE_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to m...
Definition: MutableDataAccess.h:61
#define BLAZE_CONSTRAINT_MUST_HAVE_CONST_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to c...
Definition: ConstDataAccess.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i...
Definition: Computation.h:81
Constraint on the data type.
Header file for the IsSymmetric type trait.
Namespace of the Blaze C++ math library.
Definition: Blaze.h:57
#define BLAZE_ALWAYS_INLINE
Platform dependent setup of an enforced inline keyword.
Definition: Inline.h:85
Compile time assertion.
Header file for the DenseMatrix base class.
BLAZE_ALWAYS_INLINE size_t columns(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of columns of the matrix.
Definition: Matrix.h:336
Constraint on the data type.
System settings for the BLAS mode.
Header file for run time assertion macros.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE(T)
Constraint on the data type. In case the given data type T is not a BLAS compatible data type (i...
Definition: BLASCompatible.h:61
BLAZE_ALWAYS_INLINE size_t rows(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of rows of the matrix.
Definition: Matrix.h:320
Header file for the IsRowMajorMatrix type trait.
Header file for the complex data type.
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro. In case of an invalid compile time expression, a compilation error is cr...
Definition: StaticAssert.h:112
System settings for the inline keywords.