Blaze 3.9
DMatSoftmaxExpr.h
Go to the documentation of this file.
1//=================================================================================================
33//=================================================================================================
34
35#ifndef _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
36#define _BLAZE_MATH_EXPRESSIONS_DMATSOFTMAXEXPR_H_
37
38
39//*************************************************************************************************
40// Includes
41//*************************************************************************************************
42
48
49
50namespace blaze {
51
52//=================================================================================================
53//
54// GLOBAL FUNCTIONS
55//
56//=================================================================================================
57
58//*************************************************************************************************
86template< typename MT // Type of the dense matrix
87 , bool SO > // Storage order
88auto softmax( const DenseMatrix<MT,SO>& dm )
89{
90 auto tmp( evaluate( exp( *dm - max( *dm ) ) ) );
91 const auto scalar( sum( tmp ) );
92 tmp /= scalar;
93 return tmp;
94}
95//*************************************************************************************************
96
97
98//*************************************************************************************************
134template< ReductionFlag RF // Reduction flag
135 , typename MT // Type of the dense matrix
136 , bool SO > // Storage order
138{
139 const size_t expansion( ( RF == rowwise ) ? (*dm).columns() : (*dm).rows() );
140 auto tmp( evaluate( exp( *dm - expand( max<RF>(*dm), expansion ) ) ) );
141
142 if( RF == rowwise ) {
143 for( size_t i=0UL; i<tmp.rows(); ++i ) {
144 auto r = row( tmp, i, unchecked );
145 const auto scalar( sum( r ) );
146 r /= scalar;
147 }
148 }
149 else {
150 for( size_t j=0UL; j<tmp.columns(); ++j ) {
151 auto c = column( tmp, j, unchecked );
152 const auto scalar( sum( c ) );
153 c /= scalar;
154 }
155 }
156
157 return tmp;
158}
159//*************************************************************************************************
160
161} // namespace blaze
162
163#endif
Header file for the blaze::checked and blaze::unchecked instances.
Header file for the reduction flags.
constexpr ReductionFlag rowwise
Reduction flag for row-wise reduction operations.
Definition: ReductionFlag.h:77
Base class for dense matrices.
Definition: DenseMatrix.h:82
Header file for the DenseMatrix base class.
decltype(auto) column(Matrix< MT, SO > &matrix, RCAs... args)
Creating a view on a specific column of the given matrix.
Definition: Column.h:137
decltype(auto) exp(const DenseMatrix< MT, SO > &dm)
Computes $e^x$ for each single element of the dense matrix dm.
Definition: DMatMapExpr.h:1801
decltype(auto) max(const DenseMatrix< MT1, SO1 > &lhs, const DenseMatrix< MT2, SO2 > &rhs)
Computes the componentwise maximum of the dense matrices lhs and rhs.
Definition: DMatDMatMapExpr.h:1375
auto softmax(const DenseMatrix< MT, SO > &dm)
Computes the row-/columnwise softmax function for the given dense matrix.
Definition: DMatSoftmaxExpr.h:137
decltype(auto) sum(const DenseMatrix< MT, SO > &dm)
Reduces the given dense matrix by means of addition.
Definition: DMatReduceExpr.h:2156
decltype(auto) expand(const DenseVector< VT, TF > &dv, size_t expansion)
Expansion of the given dense vector.
Definition: DVecExpandExpr.h:746
MT::ResultType evaluate(const Matrix< MT, SO > &matrix)
Evaluates the given matrix expression.
Definition: Matrix.h:1282
decltype(auto) row(Matrix< MT, SO > &, RRAs...)
Creating a view on a specific row of the given matrix.
Definition: Row.h:137
constexpr Unchecked unchecked
Global Unchecked instance.
Definition: Check.h:146
Header file for the implementation of the Column view.
Header file for the implementation of the Row view.