22 #ifndef _BLAZE_MATH_EXPRESSIONS_TDMATTSMATMULTEXPR_H_
23 #define _BLAZE_MATH_EXPRESSIONS_TDMATTSMATMULTEXPR_H_
31 #include <boost/type_traits/remove_reference.hpp>
87 template<
typename MT1
89 class TDMatTSMatMultExpr :
public DenseMatrix< TDMatTSMatMultExpr<MT1,MT2>, true >
95 typedef typename MT1::ResultType
RT1;
96 typedef typename MT2::ResultType
RT2;
97 typedef typename MT1::ElementType
ET1;
98 typedef typename MT2::ElementType
ET2;
99 typedef typename MT1::CompositeType
CT1;
100 typedef typename MT2::CompositeType
CT2;
128 enum { vectorizable = 0 };
156 typedef typename boost::remove_reference<CT2>::type::ConstIterator ConstIterator;
161 if(
lhs_.columns() == 0UL )
169 const ConstIterator end( B.end(j) );
170 ConstIterator element( B.begin(j) );
177 tmp =
lhs_(i,element->index()) * element->value();
179 for( ; element!=end; ++element )
180 tmp +=
lhs_(i,element->index()) * element->value();
186 for(
size_t k=1UL; k<
lhs_.columns(); ++k ) {
211 return rhs_.columns();
241 template<
typename T >
243 return (
lhs_.isAliased( alias ) ||
rhs_.isAliased( alias ) );
253 template<
typename T >
255 return (
lhs_.isAliased( alias ) ||
rhs_.isAliased( alias ) );
279 template<
typename MT >
288 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
300 for(
size_t i=0UL; i<A.rows(); ++i ) {
301 for(
size_t j=0UL; j<B.columns(); ++j )
303 ConstIterator element( B.begin(j) );
304 const ConstIterator end( B.end(j) );
306 if( element == end ) {
307 reset( (~lhs)(i,j) );
311 (~lhs)(i,j) = A(i,element->index()) * element->value();
313 for( ; element!=end; ++element )
314 (~lhs)(i,j) += A(i,element->index()) * element->value();
335 template<
typename MT >
344 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
356 for(
size_t j=0UL; j<B.columns(); ++j ) {
357 for(
size_t i=0UL; i<(~lhs).
rows(); ++i ) {
358 reset( (~lhs)(i,j) );
360 ConstIterator element( B.begin(j) );
361 const ConstIterator end( B.end(j) );
362 for( ; element!=end; ++element ) {
363 for(
size_t i=0UL; i<A.rows(); ++i ) {
364 if(
isDefault( (~lhs)(element->index(),j) ) )
365 (~lhs)(i,j) = A(i,element->index()) * element->value();
367 (~lhs)(i,j) += A(i,element->index()) * element->value();
389 template<
typename MT >
390 friend inline typename DisableIf< IsResizable<typename MT::ElementType> >::Type
398 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
410 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
411 const size_t iend( A.rows() & size_t(-4) );
413 for(
size_t i=0UL; i<iend; i+=4UL ) {
414 for(
size_t j=0UL; j<B.columns(); ++j )
416 ConstIterator element( B.begin(j) );
417 const ConstIterator end( B.end(j) );
419 if( element == end ) {
420 reset( (~lhs)(i ,j) );
421 reset( (~lhs)(i+1UL,j) );
422 reset( (~lhs)(i+2UL,j) );
423 reset( (~lhs)(i+3UL,j) );
427 (~lhs)(i ,j) = A(i ,element->index()) * element->value();
428 (~lhs)(i+1UL,j) = A(i+1UL,element->index()) * element->value();
429 (~lhs)(i+2UL,j) = A(i+2UL,element->index()) * element->value();
430 (~lhs)(i+3UL,j) = A(i+3UL,element->index()) * element->value();
432 for( ; element!=end; ++element ) {
433 (~lhs)(i ,j) += A(i ,element->index()) * element->value();
434 (~lhs)(i+1UL,j) += A(i+1UL,element->index()) * element->value();
435 (~lhs)(i+2UL,j) += A(i+2UL,element->index()) * element->value();
436 (~lhs)(i+3UL,j) += A(i+3UL,element->index()) * element->value();
441 for(
size_t i=iend; i<A.rows(); ++i ) {
442 for(
size_t j=0UL; j<B.columns(); ++j )
444 ConstIterator element( B.begin(j) );
445 const ConstIterator end( B.end(j) );
447 if( element == end ) {
448 reset( (~lhs)(i,j) );
452 (~lhs)(i,j) = A(i,element->index()) * element->value();
454 for( ; element!=end; ++element )
455 (~lhs)(i,j) += A(i,element->index()) * element->value();
476 template<
typename MT >
477 friend inline typename DisableIf< IsResizable<typename MT::ElementType> >::Type
485 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
497 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
498 const size_t iend( A.rows() & size_t(-4) );
500 for(
size_t j=0UL; j<B.columns(); ++j ) {
501 for(
size_t i=0UL; i<iend; i+=4UL ) {
502 reset( (~lhs)(i ,j) );
503 reset( (~lhs)(i+1UL,j) );
504 reset( (~lhs)(i+2UL,j) );
505 reset( (~lhs)(i+3UL,j) );
507 for(
size_t i=iend; i<(~lhs).
rows(); ++i ) {
508 reset( (~lhs)(i,j) );
510 ConstIterator element( B.begin(j) );
511 const ConstIterator end( B.end(j) );
513 while( element!=end )
515 const ET2 v1( element->value() );
516 const size_t i1( element->index() );
519 if( element != end ) {
520 const ET2 v2( element->value() );
521 const size_t i2( element->index() );
524 if( element != end ) {
525 const ET2 v3( element->value() );
526 const size_t i3( element->index() );
529 if( element != end ) {
530 const ET2 v4( element->value() );
531 const size_t i4( element->index() );
534 for(
size_t i=0UL; i<iend; i+=4UL ) {
535 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3 + A(i ,i4) * v4;
536 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3 + A(i+1UL,i4) * v4;
537 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3 + A(i+2UL,i4) * v4;
538 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3 + A(i+3UL,i4) * v4;
540 for(
size_t i=iend; i<A.rows(); ++i ) {
541 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3 + A(i,i4) * v4;
545 for(
size_t i=0UL; i<iend; i+=4UL ) {
546 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3;
547 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3;
548 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3;
549 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3;
551 for(
size_t i=iend; i<A.rows(); ++i ) {
552 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3;
557 for(
size_t i=0UL; i<iend; i+=4UL ) {
558 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2;
559 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2;
560 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2;
561 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2;
563 for(
size_t i=iend; i<A.rows(); ++i ) {
564 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2;
569 for(
size_t i=0UL; i<iend; i+=4UL ) {
570 (~lhs)(i ,j) += A(i ,i1) * v1;
571 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1;
572 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1;
573 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1;
575 for(
size_t i=iend; i<A.rows(); ++i ) {
576 (~lhs)(i,j) += A(i,i1) * v1;
597 template<
typename MT
603 typedef typename SelectType< SO, ResultType, OppositeType >::Type TmpType;
615 const TmpType tmp( rhs );
634 template<
typename MT >
642 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
654 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
655 const size_t iend( A.rows() & size_t(-4) );
657 for(
size_t i=0UL; i<iend; i+=4UL ) {
658 for(
size_t j=0UL; j<B.columns(); ++j )
660 ConstIterator element( B.begin(j) );
661 const ConstIterator end( B.end(j) );
663 for( ; element!=end; ++element ) {
664 (~lhs)(i ,j) += A(i ,element->index()) * element->value();
665 (~lhs)(i+1UL,j) += A(i+1UL,element->index()) * element->value();
666 (~lhs)(i+2UL,j) += A(i+2UL,element->index()) * element->value();
667 (~lhs)(i+3UL,j) += A(i+3UL,element->index()) * element->value();
672 for(
size_t i=iend; i<A.rows(); ++i ) {
673 for(
size_t j=0UL; j<B.columns(); ++j )
675 ConstIterator element( B.begin(j) );
676 const ConstIterator end( B.end(j) );
678 for( ; element!=end; ++element )
679 (~lhs)(i,j) += A(i,element->index()) * element->value();
699 template<
typename MT >
707 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
719 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
720 const size_t iend( A.rows() & size_t(-4) );
722 for(
size_t j=0UL; j<B.columns(); ++j )
724 ConstIterator element( B.begin(j) );
725 const ConstIterator end( B.end(j) );
727 while( element!=end )
729 const ET2 v1( element->value() );
730 const size_t i1( element->index() );
733 if( element != end ) {
734 const ET2 v2( element->value() );
735 const size_t i2( element->index() );
738 if( element != end ) {
739 const ET2 v3( element->value() );
740 const size_t i3( element->index() );
743 if( element != end ) {
744 const ET2 v4( element->value() );
745 const size_t i4( element->index() );
748 for(
size_t i=0UL; i<iend; i+=4UL ) {
749 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3 + A(i ,i4) * v4;
750 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3 + A(i+1UL,i4) * v4;
751 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3 + A(i+2UL,i4) * v4;
752 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3 + A(i+3UL,i4) * v4;
754 for(
size_t i=iend; i<A.rows(); ++i ) {
755 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3 + A(i,i4) * v4;
759 for(
size_t i=0UL; i<iend; i+=4UL ) {
760 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3;
761 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3;
762 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3;
763 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3;
765 for(
size_t i=iend; i<A.rows(); ++i ) {
766 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3;
771 for(
size_t i=0UL; i<iend; i+=4UL ) {
772 (~lhs)(i ,j) += A(i ,i1) * v1 + A(i ,i2) * v2;
773 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2;
774 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2;
775 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2;
777 for(
size_t i=iend; i<A.rows(); ++i ) {
778 (~lhs)(i,j) += A(i,i1) * v1 + A(i,i2) * v2;
783 for(
size_t i=0UL; i<iend; i+=4UL ) {
784 (~lhs)(i ,j) += A(i ,i1) * v1;
785 (~lhs)(i+1UL,j) += A(i+1UL,i1) * v1;
786 (~lhs)(i+2UL,j) += A(i+2UL,i1) * v1;
787 (~lhs)(i+3UL,j) += A(i+3UL,i1) * v1;
789 for(
size_t i=iend; i<A.rows(); ++i ) {
790 (~lhs)(i,j) += A(i,i1) * v1;
816 template<
typename MT >
824 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
836 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
837 const size_t iend( A.rows() & size_t(-4) );
839 for(
size_t i=0UL; i<iend; i+=4 ) {
840 for(
size_t j=0UL; j<B.columns(); ++j )
842 ConstIterator element( B.begin(j) );
843 const ConstIterator end( B.end(j) );
845 for( ; element!=end; ++element ) {
846 (~lhs)(i ,j) -= A(i ,element->index()) * element->value();
847 (~lhs)(i+1UL,j) -= A(i+1UL,element->index()) * element->value();
848 (~lhs)(i+2UL,j) -= A(i+2UL,element->index()) * element->value();
849 (~lhs)(i+3UL,j) -= A(i+3UL,element->index()) * element->value();
854 for(
size_t i=iend; i<A.rows(); ++i ) {
855 for(
size_t j=0UL; j<B.columns(); ++j )
857 ConstIterator element( B.begin(j) );
858 const ConstIterator end( B.end(j) );
860 for( ; element!=end; ++element )
861 (~lhs)(i,j) -= A(i,element->index()) * element->value();
881 template<
typename MT >
889 typedef typename boost::remove_reference<RT>::type::ConstIterator ConstIterator;
// iend is A.rows() rounded down to a multiple of four, computed by masking off
// the two lowest bits (size_t(-4)). The assertion checks that this mask form
// agrees with the arithmetic form rows - rows % 4.
901 BLAZE_INTERNAL_ASSERT( ( A.rows() - ( A.rows() % 4UL ) ) == ( A.rows() & size_t(-4) ),
"Invalid end calculation" );
902 const size_t iend( A.rows() & size_t(-4) );
904 for(
size_t j=0UL; j<B.columns(); ++j )
906 ConstIterator element( B.begin(j) );
907 const ConstIterator end( B.end(j) );
909 while( element!=end )
911 const ET2 v1( element->value() );
912 const size_t i1( element->index() );
915 if( element != end ) {
916 const ET2 v2( element->value() );
917 const size_t i2( element->index() );
920 if( element != end ) {
921 const ET2 v3( element->value() );
922 const size_t i3( element->index() );
925 if( element != end ) {
926 const ET2 v4( element->value() );
927 const size_t i4( element->index() );
930 for(
size_t i=0UL; i<iend; i+=4UL ) {
931 (~lhs)(i ,j) -= A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3 + A(i ,i4) * v4;
932 (~lhs)(i+1UL,j) -= A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3 + A(i+1UL,i4) * v4;
933 (~lhs)(i+2UL,j) -= A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3 + A(i+2UL,i4) * v4;
934 (~lhs)(i+3UL,j) -= A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3 + A(i+3UL,i4) * v4;
936 for(
size_t i=iend; i<A.rows(); ++i ) {
937 (~lhs)(i,j) -= A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3 + A(i,i4) * v4;
941 for(
size_t i=0UL; i<iend; i+=4UL ) {
942 (~lhs)(i ,j) -= A(i ,i1) * v1 + A(i ,i2) * v2 + A(i ,i3) * v3;
943 (~lhs)(i+1UL,j) -= A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2 + A(i+1UL,i3) * v3;
944 (~lhs)(i+2UL,j) -= A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2 + A(i+2UL,i3) * v3;
945 (~lhs)(i+3UL,j) -= A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2 + A(i+3UL,i3) * v3;
947 for(
size_t i=iend; i<A.rows(); ++i ) {
948 (~lhs)(i,j) -= A(i,i1) * v1 + A(i,i2) * v2 + A(i,i3) * v3;
953 for(
size_t i=0UL; i<iend; i+=4UL ) {
954 (~lhs)(i ,j) -= A(i ,i1) * v1 + A(i ,i2) * v2;
955 (~lhs)(i+1UL,j) -= A(i+1UL,i1) * v1 + A(i+1UL,i2) * v2;
956 (~lhs)(i+2UL,j) -= A(i+2UL,i1) * v1 + A(i+2UL,i2) * v2;
957 (~lhs)(i+3UL,j) -= A(i+3UL,i1) * v1 + A(i+3UL,i2) * v2;
959 for(
size_t i=iend; i<A.rows(); ++i ) {
960 (~lhs)(i,j) -= A(i,i1) * v1 + A(i,i2) * v2;
965 for(
size_t i=0UL; i<iend; i+=4UL ) {
966 (~lhs)(i ,j) -= A(i ,i1) * v1;
967 (~lhs)(i+1UL,j) -= A(i+1UL,i1) * v1;
968 (~lhs)(i+2UL,j) -= A(i+2UL,i1) * v1;
969 (~lhs)(i+3UL,j) -= A(i+3UL,i1) * v1;
971 for(
size_t i=iend; i<A.rows(); ++i ) {
972 (~lhs)(i,j) -= A(i,i1) * v1;
1042 template<
typename T1
1044 inline const TDMatTSMatMultExpr<T1,T2>
1050 throw std::invalid_argument(
"Matrix sizes do not match" );
1078 template<
typename MT1
1080 inline typename RowExprTrait< TDMatTSMatMultExpr<MT1,MT2> >::Type
1081 row(
const TDMatTSMatMultExpr<MT1,MT2>& dm,
size_t index )
1085 return row( dm.leftOperand(), index ) * dm.rightOperand();
1104 template<
typename MT1
1106 inline typename ColumnExprTrait< TDMatTSMatMultExpr<MT1,MT2> >::Type
1107 column(
const TDMatTSMatMultExpr<MT1,MT2>& dm,
size_t index )
1111 return dm.leftOperand() *
column( dm.rightOperand(), index );
// Specialization of TDMatDVecMultExprTrait for a TDMatTSMatMultExpr on the
// left-hand side and a vector type VT on the right-hand side.
// If MT1 is a column-major dense matrix, MT2 is a column-major sparse matrix
// and VT is a non-transpose dense vector, Type is obtained by nesting the
// traits: TSMatDVecMultExprTrait is applied to (MT2, VT) first and
// TDMatDVecMultExprTrait is then applied to (MT1, that result).
// In every other case Type is INVALID_TYPE.
1127 template<
typename MT1,
typename MT2,
typename VT >
1128 struct TDMatDVecMultExprTrait< TDMatTSMatMultExpr<MT1,MT2>, VT >
1132 typedef typename SelectType< IsDenseMatrix<MT1>::value && IsColumnMajorMatrix<MT1>::value &&
1133 IsSparseMatrix<MT2>::value && IsColumnMajorMatrix<MT2>::value &&
1134 IsDenseVector<VT>::value && !IsTransposeVector<VT>::value
1135 ,
typename TDMatDVecMultExprTrait< MT1, typename TSMatDVecMultExprTrait<MT2,VT>::Type >::Type
1136 , INVALID_TYPE >::Type Type;
// Specialization of TDMatSVecMultExprTrait for a TDMatTSMatMultExpr on the
// left-hand side and a vector type VT on the right-hand side.
// The selection condition requires MT1 to be a column-major dense matrix, MT2
// to be a column-major sparse matrix and VT to be a non-transpose sparse
// vector; in every other case Type is INVALID_TYPE.
//
// NOTE(review): although the condition demands a sparse VT, the selected type
// is built from the dense-vector traits (TSMatDVecMultExprTrait applied to
// (MT2, VT), then TDMatDVecMultExprTrait applied to (MT1, that result)). The
// TSVecTDMatMultExprTrait specialization in this file does switch its inner
// trait to the sparse-vector variant, so this looks like a copy-paste from the
// dense-vector specialization -- presumably the sparse-vector counterparts of
// both traits were intended. TODO confirm against the trait definitions.
1145 template<
typename MT1,
typename MT2,
typename VT >
1146 struct TDMatSVecMultExprTrait< TDMatTSMatMultExpr<MT1,MT2>, VT >
1150 typedef typename SelectType< IsDenseMatrix<MT1>::value && IsColumnMajorMatrix<MT1>::value &&
1151 IsSparseMatrix<MT2>::value && IsColumnMajorMatrix<MT2>::value &&
1152 IsSparseVector<VT>::value && !IsTransposeVector<VT>::value
1153 ,
typename TDMatDVecMultExprTrait< MT1, typename TSMatDVecMultExprTrait<MT2,VT>::Type >::Type
1154 , INVALID_TYPE >::Type Type;
// Specialization of TDVecTDMatMultExprTrait for a vector type VT on the
// left-hand side and a TDMatTSMatMultExpr on the right-hand side.
// If VT is a transpose dense vector, MT1 is a column-major dense matrix and
// MT2 is a column-major sparse matrix, Type is obtained by nesting the traits:
// TDVecTDMatMultExprTrait is applied to (VT, MT1) first and
// TDVecTSMatMultExprTrait is then applied to (that result, MT2).
// In every other case Type is INVALID_TYPE.
1163 template<
typename VT,
typename MT1,
typename MT2 >
1164 struct TDVecTDMatMultExprTrait< VT, TDMatTSMatMultExpr<MT1,MT2> >
1168 typedef typename SelectType< IsDenseVector<VT>::value && IsTransposeVector<VT>::value &&
1169 IsDenseMatrix<MT1>::value && IsColumnMajorMatrix<MT1>::value &&
1170 IsSparseMatrix<MT2>::value && IsColumnMajorMatrix<MT2>::value
1171 ,
typename TDVecTSMatMultExprTrait< typename TDVecTDMatMultExprTrait<VT,MT1>::Type, MT2 >::Type
1172 , INVALID_TYPE >::Type Type;
// Specialization of TSVecTDMatMultExprTrait for a vector type VT on the
// left-hand side and a TDMatTSMatMultExpr on the right-hand side.
// If VT is a transpose sparse vector, MT1 is a column-major dense matrix and
// MT2 is a column-major sparse matrix, Type is obtained by nesting the traits:
// TSVecTDMatMultExprTrait is applied to (VT, MT1) first and
// TDVecTSMatMultExprTrait is then applied to (that result, MT2).
// In every other case Type is INVALID_TYPE.
1181 template<
typename VT,
typename MT1,
typename MT2 >
1182 struct TSVecTDMatMultExprTrait< VT, TDMatTSMatMultExpr<MT1,MT2> >
1186 typedef typename SelectType< IsSparseVector<VT>::value && IsTransposeVector<VT>::value &&
1187 IsDenseMatrix<MT1>::value && IsColumnMajorMatrix<MT1>::value &&
1188 IsSparseMatrix<MT2>::value && IsColumnMajorMatrix<MT2>::value
1189 ,
typename TDVecTSMatMultExprTrait< typename TSVecTDMatMultExprTrait<VT,MT1>::Type, MT2 >::Type
1190 , INVALID_TYPE >::Type Type;
// Specialization of RowExprTrait for TDMatTSMatMultExpr.
// Type is the MultExprTrait result for (RowExprTrait<const MT1>::Type, MT2),
// which mirrors the row() overload in this file that returns
// row( left operand ) * right operand.
1199 template<
typename MT1,
typename MT2 >
1200 struct RowExprTrait< TDMatTSMatMultExpr<MT1,MT2> >
1204 typedef typename MultExprTrait< typename RowExprTrait<const MT1>::Type, MT2 >::Type Type;
// Specialization of ColumnExprTrait for TDMatTSMatMultExpr.
// Type is the MultExprTrait result for (MT1, ColumnExprTrait<const MT2>::Type),
// which mirrors the column() overload in this file that returns
// left operand * column( right operand ).
1213 template<
typename MT1,
typename MT2 >
1214 struct ColumnExprTrait< TDMatTSMatMultExpr<MT1,MT2> >
1218 typedef typename MultExprTrait< MT1, typename ColumnExprTrait<const MT2>::Type >::Type Type;