gemm.h
Go to the documentation of this file.
1 //=================================================================================================
33 //=================================================================================================
34 
35 #ifndef _BLAZE_MATH_BLAS_GEMM_H_
36 #define _BLAZE_MATH_BLAS_GEMM_H_
37 
38 
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42 
43 #include <boost/cast.hpp>
51 #include <blaze/system/BLAS.h>
52 #include <blaze/system/Inline.h>
53 #include <blaze/util/Assert.h>
54 #include <blaze/util/Complex.h>
55 
56 
57 namespace blaze {
58 
59 //=================================================================================================
60 //
61 // BLAS WRAPPER FUNCTIONS (GEMM)
62 //
63 //=================================================================================================
64 
65 //*************************************************************************************************
#if BLAZE_BLAS_MODE

// Declarations of the low-level BLAS gemm() wrappers. Each overload forwards to the CBLAS
// ?gemm routine for one element type (float, double, complex<float>, complex<double>).
// The target pointer C MUST have the same element type as the operands A and B. These
// declarations have to match the corresponding definitions in this file exactly; otherwise
// they introduce additional, never-defined overloads.

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, float alpha, const float* A, int lda,
                               const float* B, int ldb, float beta, float* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, double alpha, const double* A, int lda,
                               const double* B, int ldb, double beta, double* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<float> alpha, const complex<float>* A,
                               int lda, const complex<float>* B, int ldb, complex<float> beta,
                               complex<float>* C, int ldc );

BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k, complex<double> alpha, const complex<double>* A,
                               int lda, const complex<double>* B, int ldb, complex<double> beta,
                               complex<double>* C, int ldc );

// High-level dense matrix interface; derives dimensions, leading dimensions, storage order
// and transposition flags from the given matrices and dispatches to the overloads above.
template< typename MT1, bool SO1, typename MT2, bool SO2, typename MT3, bool SO3, typename ST >
BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
                               const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta );

#endif
93 
94 //*************************************************************************************************
95 
96 
97 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief BLAS kernel for a dense matrix/dense matrix multiplication with single precision
//        operands (\f$ C = \alpha * op(A) * op(B) + \beta * C \f$).
//
// \param order Storage order of the matrices (\c CblasRowMajor or \c CblasColMajor).
// \param transA Whether \a A is used as is (\c CblasNoTrans) or transposed (\c CblasTrans).
// \param transB Whether \a B is used as is (\c CblasNoTrans) or transposed (\c CblasTrans).
// \param m Number of rows of \f$ op(A) \f$ and of \a C.
// \param n Number of columns of \f$ op(B) \f$ and of \a C.
// \param k Number of columns of \f$ op(A) \f$ and rows of \f$ op(B) \f$.
// \param alpha Scaling factor for the product \f$ op(A) * op(B) \f$.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of matrix \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of matrix \a B.
// \param beta Scaling factor for the previous contents of \a C.
// \param C Pointer to the first element of the target matrix \a C (updated in place).
// \param ldc Leading dimension of matrix \a C.
// \return void
//
// All arguments are forwarded unchanged to \c cblas_sgemm(); no checks are performed here.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k,
                               float alpha, const float* A, int lda,
                                            const float* B, int ldb,
                               float beta,        float* C, int ldc )
{
   cblas_sgemm( order, transA, transB,  // layout and operand transposition
                m, n, k,                // problem dimensions
                alpha, A, lda,          // scaled left-hand side operand
                B, ldb,                 // right-hand side operand
                beta, C, ldc );         // scaled target matrix
}
#endif
129 //*************************************************************************************************
130 
131 
132 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief BLAS kernel for a dense matrix/dense matrix multiplication with double precision
//        operands (\f$ C = \alpha * op(A) * op(B) + \beta * C \f$).
//
// \param order Storage order of the matrices (\c CblasRowMajor or \c CblasColMajor).
// \param transA Whether \a A is used as is (\c CblasNoTrans) or transposed (\c CblasTrans).
// \param transB Whether \a B is used as is (\c CblasNoTrans) or transposed (\c CblasTrans).
// \param m Number of rows of \f$ op(A) \f$ and of \a C.
// \param n Number of columns of \f$ op(B) \f$ and of \a C.
// \param k Number of columns of \f$ op(A) \f$ and rows of \f$ op(B) \f$.
// \param alpha Scaling factor for the product \f$ op(A) * op(B) \f$.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of matrix \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of matrix \a B.
// \param beta Scaling factor for the previous contents of \a C.
// \param C Pointer to the first element of the target matrix \a C (updated in place).
// \param ldc Leading dimension of matrix \a C.
// \return void
//
// All arguments are forwarded unchanged to \c cblas_dgemm(); no checks are performed here.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k,
                               double alpha, const double* A, int lda,
                                             const double* B, int ldb,
                               double beta,        double* C, int ldc )
{
   cblas_dgemm( order, transA, transB,  // layout and operand transposition
                m, n, k,                // problem dimensions
                alpha, A, lda,          // scaled left-hand side operand
                B, ldb,                 // right-hand side operand
                beta, C, ldc );         // scaled target matrix
}
#endif
164 //*************************************************************************************************
165 
166 
167 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief BLAS kernel for a dense matrix/dense matrix multiplication with single precision
//        complex operands (\f$ C = \alpha * op(A) * op(B) + \beta * C \f$).
//
// \param order Storage order of the matrices (\c CblasRowMajor or \c CblasColMajor).
// \param transA Transposition flag for matrix \a A.
// \param transB Transposition flag for matrix \a B.
// \param m Number of rows of \f$ op(A) \f$ and of \a C.
// \param n Number of columns of \f$ op(B) \f$ and of \a C.
// \param k Number of columns of \f$ op(A) \f$ and rows of \f$ op(B) \f$.
// \param alpha Complex scaling factor for the product \f$ op(A) * op(B) \f$.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of matrix \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of matrix \a B.
// \param beta Complex scaling factor for the previous contents of \a C.
// \param C Pointer to the first element of the target matrix \a C (updated in place).
// \param ldc Leading dimension of matrix \a C.
// \return void
//
// The arguments are forwarded to \c cblas_cgemm(); the complex scalars are handed over by
// address, since the CBLAS complex routines take \a alpha and \a beta through pointers.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k,
                               complex<float> alpha, const complex<float>* A, int lda,
                                                     const complex<float>* B, int ldb,
                               complex<float> beta,        complex<float>* C, int ldc )
{
   cblas_cgemm( order, transA, transB,  // layout and operand transposition
                m, n, k,                // problem dimensions
                &alpha, A, lda,         // scaled left-hand side operand (scalar by address)
                B, ldb,                 // right-hand side operand
                &beta, C, ldc );        // scaled target matrix (scalar by address)
}
#endif
200 //*************************************************************************************************
201 
202 
203 //*************************************************************************************************
#if BLAZE_BLAS_MODE

/*!\brief BLAS kernel for a dense matrix/dense matrix multiplication with double precision
//        complex operands (\f$ C = \alpha * op(A) * op(B) + \beta * C \f$).
//
// \param order Storage order of the matrices (\c CblasRowMajor or \c CblasColMajor).
// \param transA Transposition flag for matrix \a A.
// \param transB Transposition flag for matrix \a B.
// \param m Number of rows of \f$ op(A) \f$ and of \a C.
// \param n Number of columns of \f$ op(B) \f$ and of \a C.
// \param k Number of columns of \f$ op(A) \f$ and rows of \f$ op(B) \f$.
// \param alpha Complex scaling factor for the product \f$ op(A) * op(B) \f$.
// \param A Pointer to the first element of matrix \a A.
// \param lda Leading dimension of matrix \a A.
// \param B Pointer to the first element of matrix \a B.
// \param ldb Leading dimension of matrix \a B.
// \param beta Complex scaling factor for the previous contents of \a C.
// \param C Pointer to the first element of the target matrix \a C (updated in place).
// \param ldc Leading dimension of matrix \a C.
// \return void
//
// The arguments are forwarded to \c cblas_zgemm(); the complex scalars are handed over by
// address, since the CBLAS complex routines take \a alpha and \a beta through pointers.
*/
BLAZE_ALWAYS_INLINE void gemm( CBLAS_ORDER order, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                               int m, int n, int k,
                               complex<double> alpha, const complex<double>* A, int lda,
                                                      const complex<double>* B, int ldb,
                               complex<double> beta,        complex<double>* C, int ldc )
{
   cblas_zgemm( order, transA, transB,  // layout and operand transposition
                m, n, k,                // problem dimensions
                &alpha, A, lda,         // scaled left-hand side operand (scalar by address)
                B, ldb,                 // right-hand side operand
                &beta, C, ldc );        // scaled target matrix (scalar by address)
}
#endif
236 //*************************************************************************************************
237 
238 
239 //*************************************************************************************************
240 #if BLAZE_BLAS_MODE
241 
// Purpose: high-level BLAS gemm for Blaze dense matrices. Extracts the dimensions, leading
// dimensions (spacing) and raw data pointers of C, A and B and forwards them to the matching
// low-level gemm() overload, i.e. C = alpha * op(A) * op(B) + beta * C in BLAS terms.
//
// \param C     The target dense matrix; written through (~C).data().
// \param A     The left-hand side dense matrix operand (m x k).
// \param B     The right-hand side dense matrix operand (its columns() gives n).
// \param alpha Scaling factor for the product of A and B.
// \param beta  Scaling factor for the previous contents of C.
//
// Each size/spacing value is converted with boost::numeric_cast<int>, which throws a
// boost::numeric::bad_numeric_cast if the value does not fit into an int.
// NOTE(review): presumably the caller guarantees compatible sizes (A.columns() == B.rows(),
// C of size A.rows() x B.columns()); no such check is visible in the lines shown here.
// NOTE(review): alpha/beta are passed as ST to the low-level overloads, so ST presumably has
// to match (or convert unambiguously to) the element type of the matrices — confirm.
256 template< typename MT1 // Type of the left-hand side target matrix
257  , bool SO1 // Storage order of the left-hand side target matrix
258  , typename MT2 // Type of the left-hand side matrix operand
259  , bool SO2 // Storage order of the left-hand side matrix operand
260  , typename MT3 // Type of the right-hand side matrix operand
261  , bool SO3 // Storage order of the right-hand side matrix operand
262  , typename ST > // Type of the scalar factors
263 BLAZE_ALWAYS_INLINE void gemm( DenseMatrix<MT1,SO1>& C, const DenseMatrix<MT2,SO2>& A,
264  const DenseMatrix<MT3,SO3>& B, ST alpha, ST beta )
265 {
266  using boost::numeric_cast;
267 
// NOTE(review): header lines 268-279 are missing from this listing. The cross-reference notes
// of this page mention BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE, ..._MUST_HAVE_MUTABLE_DATA_ACCESS,
// ..._MUST_HAVE_CONST_DATA_ACCESS and ..._MUST_BE_BLAS_COMPATIBLE_TYPE, so these lines presumably
// held such compile-time checks — confirm against the actual header.
271 
275 
279 
// BLAS works on int sizes; numeric_cast guards against silent truncation of the size_t values.
280  const int m ( numeric_cast<int>( (~A).rows() ) );
281  const int n ( numeric_cast<int>( (~B).columns() ) );
282  const int k ( numeric_cast<int>( (~A).columns() ) );
283  const int lda( numeric_cast<int>( (~A).spacing() ) );
284  const int ldb( numeric_cast<int>( (~B).spacing() ) );
285  const int ldc( numeric_cast<int>( (~C).spacing() ) );
286 
// The BLAS layout is taken from the target matrix C. An operand whose storage order differs
// from C's is reinterpreted in C's layout, i.e. it is passed to BLAS as transposed.
287  gemm( ( IsRowMajorMatrix<MT1>::value )?( CblasRowMajor ):( CblasColMajor ),
288  ( SO1 == SO2 )?( CblasNoTrans ):( CblasTrans ),
289  ( SO1 == SO3 )?( CblasNoTrans ):( CblasTrans ),
290  m, n, k, alpha, (~A).data(), lda, (~B).data(), ldb, beta, (~C).data(), ldc );
291 }
292 #endif
293 //*************************************************************************************************
294 
295 } // namespace blaze
296 
297 #endif
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_HAVE_MUTABLE_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to m...
Definition: MutableDataAccess.h:79
#define BLAZE_CONSTRAINT_MUST_HAVE_CONST_DATA_ACCESS(T)
Constraint on the data type. In case the given data type T does not provide low-level data access to c...
Definition: ConstDataAccess.h:79
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i...
Definition: Computation.h:118
BLAZE_ALWAYS_INLINE size_t rows(const Matrix< MT, SO > &matrix)
Returns the current number of rows of the matrix.
Definition: Matrix.h:308
Header file for the IsSymmetric type trait.
Namespace of the Blaze C++ math library.
Definition: Blaze.h:57
#define BLAZE_ALWAYS_INLINE
Platform dependent setup of an enforced inline keyword.
Definition: Inline.h:85
Header file for the DenseMatrix base class.
Type ElementType
Type of the sparse matrix elements.
Definition: CompressedMatrix.h:2586
Constraint on the data type.
Constraint on the data type.
const bool spacing
Adding an additional spacing line between two log messages. This setting gives the opportunity to add ...
Definition: Logging.h:70
System settings for the BLAS mode.
Header file for run time assertion macros.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_BLAS_COMPATIBLE_TYPE(T)
Constraint on the data type. In case the given data type T is not a BLAS compatible data type (i...
Definition: BlasCompatible.h:79
Header file for the IsRowMajorMatrix type trait.
BLAZE_ALWAYS_INLINE size_t columns(const Matrix< MT, SO > &matrix)
Returns the current number of columns of the matrix.
Definition: Matrix.h:324
Header file for the complex data type.
System settings for the inline keywords.