gemm.h
Go to the documentation of this file.
1 //=================================================================================================
33 //=================================================================================================
34 
35 #ifndef _BLAZE_MATH_BLAS_GEMM_H_
36 #define _BLAZE_MATH_BLAS_GEMM_H_
37 
38 
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42 
43 #include <boost/cast.hpp>
44 #include <blaze/math/Aliases.h>
52 #include <blaze/system/BLAS.h>
53 #include <blaze/system/Inline.h>
54 #include <blaze/util/Assert.h>
55 #include <blaze/util/Complex.h>
56 
57 
58 namespace blaze {
59 
60 //=================================================================================================
61 //
62 // BLAS WRAPPER FUNCTIONS (GEMM)
63 //
64 //=================================================================================================
65 
66 //*************************************************************************************************
#if BLAZE_BLAS_MODE

// Forward declarations of the BLAS GEMM wrappers defined later in this file.
// For every low-level overload the element type of the output matrix C must match the element
// type of A, B, alpha and beta. A mismatching parameter type would not declare the wrapper
// defined below. It would silently introduce a separate, never-defined inline overload.

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, float alpha, const float* A, int lda,
                               const float* B, int ldb, float beta, float* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, double alpha, const double* A, int lda,
                               const double* B, int ldb, double beta, double* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<float> alpha, const complex<float>* A,
                               int lda, const complex<float>* B, int ldb, complex<float> beta,
                               complex<float>* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<double> alpha, const complex<double>* A,
                               int lda, const complex<double>* B, int ldb, complex<double> beta,
                               complex<double>* C, int ldc );

template< typename MT1, bool SO1, typename MT2, bool SO2, typename MT3, bool SO3, typename ST >
BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
                               const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta );

#endif
94 
95 //*************************************************************************************************
96 
97 
98 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief Single precision GEMM wrapper around \c cblas_sgemm().
// \ingroup blas
//
// \param order Storage order of the matrices passed to CBLAS.
// \param transA Transpose flag for matrix \a A.
// \param transB Transpose flag for matrix \a B.
// \param m Row count passed to CBLAS.
// \param n Column count passed to CBLAS.
// \param k Inner dimension passed to CBLAS.
// \param alpha Scaling factor applied to the product.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of \a B.
// \param beta Scaling factor applied to the previous content of \a C.
// \param C Pointer to the first element of the output matrix \a C.
// \param ldc Leading dimension of \a C.
// \return void
//
// Every argument is handed to \c cblas_sgemm() unchanged and in the same order. See the
// CBLAS/BLAS reference for the exact meaning of each argument.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, float alpha, const float* A, int lda,
                               const float* B, int ldb, float beta, float* C, int ldc )
{
   return cblas_sgemm( order, transA, transB,
                       m, n, k,
                       alpha, A, lda,
                       B, ldb,
                       beta, C, ldc );
}
#endif
130 //*************************************************************************************************
131 
132 
133 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief Double precision GEMM wrapper around \c cblas_dgemm().
// \ingroup blas
//
// \param order Storage order of the matrices passed to CBLAS.
// \param transA Transpose flag for matrix \a A.
// \param transB Transpose flag for matrix \a B.
// \param m Row count passed to CBLAS.
// \param n Column count passed to CBLAS.
// \param k Inner dimension passed to CBLAS.
// \param alpha Scaling factor applied to the product.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of \a B.
// \param beta Scaling factor applied to the previous content of \a C.
// \param C Pointer to the first element of the output matrix \a C.
// \param ldc Leading dimension of \a C.
// \return void
//
// Every argument is handed to \c cblas_dgemm() unchanged and in the same order. See the
// CBLAS/BLAS reference for the exact meaning of each argument.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, double alpha, const double* A, int lda,
                               const double* B, int ldb, double beta, double* C, int ldc )
{
   return cblas_dgemm( order, transA, transB,
                       m, n, k,
                       alpha, A, lda,
                       B, ldb,
                       beta, C, ldc );
}
#endif
165 //*************************************************************************************************
166 
167 
168 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief Single precision complex GEMM wrapper around \c cblas_cgemm().
// \ingroup blas
//
// \param order Storage order of the matrices passed to CBLAS.
// \param transA Transpose flag for matrix \a A.
// \param transB Transpose flag for matrix \a B.
// \param m Row count passed to CBLAS.
// \param n Column count passed to CBLAS.
// \param k Inner dimension passed to CBLAS.
// \param alpha Scaling factor applied to the product.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of \a B.
// \param beta Scaling factor applied to the previous content of \a C.
// \param C Pointer to the first element of the output matrix \a C.
// \param ldc Leading dimension of \a C.
// \return void
//
// The complex CBLAS interface takes the scalar factors by address, so \a alpha and \a beta
// are passed as pointers to the by-value parameters. They stay alive for the whole call.
// All other arguments are forwarded unchanged.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<float> alpha, const complex<float>* A,
                               int lda, const complex<float>* B, int ldb, complex<float> beta,
                               complex<float>* C, int ldc )
{
   return cblas_cgemm( order, transA, transB,
                       m, n, k,
                       &alpha, A, lda,
                       B, ldb,
                       &beta, C, ldc );
}
#endif
201 //*************************************************************************************************
202 
203 
204 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief Double precision complex GEMM wrapper around \c cblas_zgemm().
// \ingroup blas
//
// \param order Storage order of the matrices passed to CBLAS.
// \param transA Transpose flag for matrix \a A.
// \param transB Transpose flag for matrix \a B.
// \param m Row count passed to CBLAS.
// \param n Column count passed to CBLAS.
// \param k Inner dimension passed to CBLAS.
// \param alpha Scaling factor applied to the product.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of \a B.
// \param beta Scaling factor applied to the previous content of \a C.
// \param C Pointer to the first element of the output matrix \a C.
// \param ldc Leading dimension of \a C.
// \return void
//
// The complex CBLAS interface takes the scalar factors by address, so \a alpha and \a beta
// are passed as pointers to the by-value parameters. They stay alive for the whole call.
// All other arguments are forwarded unchanged.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<double> alpha, const complex<double>* A,
                               int lda, const complex<double>* B, int ldb, complex<double> beta,
                               complex<double>* C, int ldc )
{
   return cblas_zgemm( order, transA, transB,
                       m, n, k,
                       &alpha, A, lda,
                       B, ldb,
                       &beta, C, ldc );
}
#endif
237 //*************************************************************************************************
238 
239 
240 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief BLAS kernel for a dense matrix/dense matrix multiplication (\f$ C=\alpha*A*B+\beta*C \f$).
// \ingroup blas
//
// \param C The target dense matrix (receives the result).
// \param A The left-hand side multiplication operand.
// \param B The right-hand side multiplication operand.
// \param alpha The scaling factor for \f$ A*B \f$.
// \param beta The scaling factor for the previous content of \f$ C \f$.
// \return void
// \exception boost::numeric::bad_numeric_cast A dimension or spacing does not fit into an \c int.
//
// The storage order of \a C selects the CBLAS layout. An operand stored in the other order
// than \a C is passed with \c CblasTrans, because its raw buffer represents the transpose in
// \a C's layout. The leading dimensions are taken from the operands' spacing, so padded
// storage is handled.
*/
template< typename MT1   // Type of the left-hand side target matrix
        , bool SO1       // Storage order of the left-hand side target matrix
        , typename MT2   // Type of the left-hand side matrix operand
        , bool SO2       // Storage order of the left-hand side matrix operand
        , typename MT3   // Type of the right-hand side matrix operand
        , bool SO3       // Storage order of the right-hand side matrix operand
        , typename ST >  // Type of the scalar factors
BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
                               const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta )
{
   using boost::numeric_cast;

   // BLAS performs no size checking of its own. Inconsistent dimensions would make it
   // read or write outside the operands' buffers.
   BLAZE_INTERNAL_ASSERT( (~A).columns() == (~B).rows()   , "Invalid matrix sizes"      );
   BLAZE_INTERNAL_ASSERT( (~C).rows()    == (~A).rows()   , "Invalid number of rows"    );
   BLAZE_INTERNAL_ASSERT( (~C).columns() == (~B).columns(), "Invalid number of columns" );

   const int m  ( numeric_cast<int>( (~A).rows()    ) );
   const int n  ( numeric_cast<int>( (~B).columns() ) );
   const int k  ( numeric_cast<int>( (~A).columns() ) );
   const int lda( numeric_cast<int>( (~A).spacing() ) );
   const int ldb( numeric_cast<int>( (~B).spacing() ) );
   const int ldc( numeric_cast<int>( (~C).spacing() ) );

   gemm( ( IsRowMajorMatrix<MT1>::value )?( CblasRowMajor ):( CblasColMajor ),
         ( SO1 == SO2 )?( CblasNoTrans ):( CblasTrans ),
         ( SO1 == SO3 )?( CblasNoTrans ):( CblasTrans ),
         m, n, k, alpha, (~A).data(), lda, (~B).data(), ldb, beta, (~C).data(), ldc );
}
#endif
294 //*************************************************************************************************
295 
296 } // namespace blaze
297 
298 #endif
Constraint on the data type.
Header file for auxiliary alias declarations.
#define BLAZE_CONSTRAINT_MUST_HAVE_MUTABLE_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to m...
Definition: MutableDataAccess.h:61
#define BLAZE_CONSTRAINT_MUST_HAVE_CONST_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to c...
Definition: ConstDataAccess.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i...
Definition: Computation.h:81
constexpr bool spacing
Adding an additional spacing line between two log messages. This setting gives the opportunity to add ...
Definition: Logging.h:70
Constraint on the data type.
Header file for the IsSymmetric type trait.
Namespace of the Blaze C++ math library.
Definition: Blaze.h:57
#define BLAZE_ALWAYS_INLINE
Platform dependent setup of an enforced inline keyword.
Definition: Inline.h:85
Header file for the DenseMatrix base class.
BLAZE_ALWAYS_INLINE size_t columns(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of columns of the matrix.
Definition: Matrix.h:330
Constraint on the data type.
System settings for the BLAS mode.
Header file for run time assertion macros.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE(T)
Constraint on the data type. In case the given data type T is not a BLAS compatible data type (i...
Definition: BLASCompatible.h:61
BLAZE_ALWAYS_INLINE size_t rows(const Matrix< MT, SO > &matrix) noexcept
Returns the current number of rows of the matrix.
Definition: Matrix.h:314
Header file for the IsRowMajorMatrix type trait.
Header file for the complex data type.
System settings for the inline keywords.