Blaze  3.6
DMatSoftmaxExpr.h
Go to the documentation of this file.
1 //=================================================================================================
33 //=================================================================================================
34 
35 #ifndef _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
36 #define _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
37 
38 
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42 
45 #include <blaze/math/views/Check.h>
47 #include <blaze/math/views/Row.h>
48 
49 
50 namespace blaze {
51 
52 //=================================================================================================
53 //
54 // GLOBAL FUNCTIONS
55 //
56 //=================================================================================================
57 
58 //*************************************************************************************************
86 template< typename MT // Type of the dense matrix
87  , bool SO > // Storage order
88 auto softmax( const DenseMatrix<MT,SO>& dm )
89 {
90  auto tmp( evaluate( exp( ~dm ) ) );
91  const auto scalar( sum( tmp ) );
92  tmp /= scalar;
93  return tmp;
94 }
95 //*************************************************************************************************
96 
97 
98 //*************************************************************************************************
134 template< bool RF // Reduction flag
135  , typename MT // Type of the dense matrix
136  , bool SO > // Storage order
137 auto softmax( const DenseMatrix<MT,SO>& dm )
138 {
139  auto tmp( evaluate( exp( ~dm ) ) );
140 
141  if( RF == rowwise ) {
142  for( size_t i=0UL; i<tmp.rows(); ++i ) {
143  auto r = row( tmp, i, unchecked );
144  const auto scalar( sum( r ) );
145  r /= scalar;
146  }
147  }
148  else {
149  for( size_t j=0UL; j<tmp.columns(); ++j ) {
150  auto c = column( tmp, j, unchecked );
151  const auto scalar( sum( c ) );
152  c /= scalar;
153  }
154  }
155 
156  return tmp;
157 }
158 //*************************************************************************************************
159 
160 } // namespace blaze
161 
162 #endif
decltype(auto) column(Matrix< MT, SO > &matrix, RCAs... args)
Creating a view on a specific column of the given matrix.
Definition: Column.h:133
Header file for the blaze::checked and blaze::unchecked instances.
constexpr Unchecked unchecked
Global Unchecked instance. The blaze::unchecked instance is an optional token for the creation of view...
Definition: Check.h:138
const MT::ResultType evaluate(const Matrix< MT, SO > &matrix)
Evaluates the given matrix expression.
Definition: Matrix.h:912
Base class for dense matrices. The DenseMatrix class is a base class for all dense matrix classes....
Definition: DenseMatrix.h:81
Namespace of the Blaze C++ math library.
Definition: Blaze.h:58
decltype(auto) sum(const DenseMatrix< MT, SO > &dm)
Reduces the given dense matrix by means of addition.
Definition: DMatReduceExpr.h:2147
Header file for the DenseMatrix base class.
decltype(auto) exp(const DenseMatrix< MT, SO > &dm)
Computes $e^x$ for each single element of the dense matrix dm.
Definition: DMatMapExpr.h:1632
Header file for the implementation of the Column view.
decltype(auto) row(Matrix< MT, SO > &, RRAs...)
Creating a view on a specific row of the given matrix.
Definition: Row.h:133
constexpr size_t rowwise
Reduction flag for row-wise reduction operations.
Definition: ReductionFlag.h:70
auto softmax(const DenseMatrix< MT, SO > &dm)
Computes the softmax function for the given dense matrix.
Definition: DMatSoftmaxExpr.h:88
Header file for the implementation of the Row view.
Header file for the reduction flags.