// NOTE(review): this listing is a sparse excerpt. Original line numbers are embedded in the
// text, and braces, loop headers and several statements are missing. The comments below
// describe only what the visible lines show.
//
// The line below carries the include guard of this header and the start of the first mmm()
// overload, the one taking its target as DenseMatrix<MT1,false>.
//
// mmm( C, A, B, alpha, beta ): blocked kernel that adds alpha-scaled dot products of rows of A
// and columns of B into C.
//   C     - target matrix; every visible write has the form (~C)(row,column) += value * alpha.
//   A     - left operand; A.rows() is M and A.columns() is K.
//   B     - right operand; B.columns() is N.
//   alpha - factor applied to each accumulated sum before it is added to C.
//   beta  - only the test !isOne( beta ) is visible; what is done to C in that branch is not
//           part of this excerpt.
35 #ifndef _BLAZE_MATH_DENSE_MMM_H_ 36 #define _BLAZE_MATH_DENSE_MMM_H_ 105 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
106 void mmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands and the SIMD pack type derived from C's element type.
108 using ET1 = ElementType_<MT1>;
109 using ET2 = ElementType_<MT2>;
110 using ET3 = ElementType_<MT3>;
111 using SIMDType = SIMDTrait_<ET1>;
// remainder is true as soon as A or B is not padded; it enables the scalar tail section at
// original line 452. SIMDSIZE is used below but its definition is not visible here.
129 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Block sizes: KBLOCK along the shared dimension K, JBLOCK along the columns of B.
131 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
132 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands.
137 const size_t M( A.rows() );
138 const size_t N( B.columns() );
139 const size_t K( A.columns() );
// Scratch copies of the current panels: A2 holds up to M x KBLOCK elements of A, B2 up to
// KBLOCK x JBLOCK elements of B. The two buffers use opposite storage-order flags
// (false/true); presumably row-major and column-major - confirm against DynamicMatrix.
143 DynamicMatrix<ET2,false> A2( M, KBLOCK );
144 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// Second branch of the beta handling; the first branch and both bodies are not visible.
149 else if( !
isOne( beta ) ) {
154 size_t kblock( 0UL );
// Outer loop over K in steps of kblock. With remainder set, the loop stops once fewer than
// SIMDSIZE columns are left; those are handled by the tail section.
156 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a partial last block is masked with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
159 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a partial last block takes everything that is left. The condition selecting
// between the two variants is not visible (presumably the remainder flag - confirm).
162 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this k-block: starts at kk when IsLower<MT2> holds and ends at
// kk+kblock when IsUpper<MT2> holds; otherwise all M rows are used.
165 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
166 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
167 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2; serial() wraps the submatrix expression.
169 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
172 size_t jblock( 0UL );
// Width of the current column block, clipped at N. The loop over jj is not visible here.
176 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Structure test on B: true when the kblock x jblock panel lies at or beyond kk+kblock in
// columns for a lower B, or entirely before column kk for an upper B. The body is not
// visible (presumably it skips the block - confirm).
178 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
179 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
184 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
// The five-row tiles follow this floating point test. The matching else-branch is not
// visible (the four-row tiles at original line 264 presumably belong to it - confirm).
188 if( IsFloatingPoint<ET1>::value )
// Tiles of five rows of A2.
190 for( ; (i+5UL) <= isize; i+=5UL )
// 5x2 tile: five rows of A2 against two columns of B2, one SIMD accumulator per C element.
// The accumulation statements (original lines 209-219) are not part of this excerpt.
194 for( ; (j+2UL) <= jblock; j+=2UL )
196 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
198 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
200 const SIMDType a1( A2.load(i ,k) );
201 const SIMDType a2( A2.load(i+1UL,k) );
202 const SIMDType a3( A2.load(i+2UL,k) );
203 const SIMDType a4( A2.load(i+3UL,k) );
204 const SIMDType a5( A2.load(i+4UL,k) );
206 const SIMDType b1( B2.load(k,j ) );
207 const SIMDType b2( B2.load(k,j+1UL) );
// Reduce each accumulator with sum(), scale by alpha and add to the matching C element.
// Tile-local indices are shifted by ibegin (rows) and jj (columns).
221 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
222 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
223 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
224 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
225 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
226 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
227 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
228 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
229 (~C)(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
230 (~C)(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5x1 tile: the same five rows against a single leftover column of B2.
235 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
237 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
239 const SIMDType a1( A2.load(i ,k) );
240 const SIMDType a2( A2.load(i+1UL,k) );
241 const SIMDType a3( A2.load(i+2UL,k) );
242 const SIMDType a4( A2.load(i+3UL,k) );
243 const SIMDType a5( A2.load(i+4UL,k) );
245 const SIMDType b1( B2.load(k,j) );
254 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
255 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
256 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
257 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
258 (~C)(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Tiles of four rows of A2.
264 for( ; (i+4UL) <= isize; i+=4UL )
// 4x2 tile.
268 for( ; (j+2UL) <= jblock; j+=2UL )
270 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
272 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
274 const SIMDType a1( A2.load(i ,k) );
275 const SIMDType a2( A2.load(i+1UL,k) );
276 const SIMDType a3( A2.load(i+2UL,k) );
277 const SIMDType a4( A2.load(i+3UL,k) );
279 const SIMDType b1( B2.load(k,j ) );
280 const SIMDType b2( B2.load(k,j+1UL) );
292 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
293 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
294 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
295 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
296 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
297 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
298 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
299 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4x1 tile for a single leftover column.
304 SIMDType xmm1, xmm2, xmm3, xmm4;
306 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
308 const SIMDType a1( A2.load(i ,k) );
309 const SIMDType a2( A2.load(i+1UL,k) );
310 const SIMDType a3( A2.load(i+2UL,k) );
311 const SIMDType a4( A2.load(i+3UL,k) );
313 const SIMDType b1( B2.load(k,j) );
321 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
322 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
323 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
324 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Tiles of two rows of A2 pick up the rows left over by the wider tiles.
329 for( ; (i+2UL) <= isize; i+=2UL )
// 2x4 tile: two rows against four columns of B2.
333 for( ; (j+4UL) <= jblock; j+=4UL )
335 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
337 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
339 const SIMDType a1( A2.load(i ,k) );
340 const SIMDType a2( A2.load(i+1UL,k) );
342 const SIMDType b1( B2.load(k,j ) );
343 const SIMDType b2( B2.load(k,j+1UL) );
344 const SIMDType b3( B2.load(k,j+2UL) );
345 const SIMDType b4( B2.load(k,j+3UL) );
357 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
358 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
359 (~C)(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
360 (~C)(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
361 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
362 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
363 (~C)(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
364 (~C)(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile.
367 for( ; (j+2UL) <= jblock; j+=2UL )
369 SIMDType xmm1, xmm2, xmm3, xmm4;
371 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
373 const SIMDType a1( A2.load(i ,k) );
374 const SIMDType a2( A2.load(i+1UL,k) );
376 const SIMDType b1( B2.load(k,j ) );
377 const SIMDType b2( B2.load(k,j+1UL) );
385 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
386 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
387 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
388 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// 2x1 tile for a single leftover column (the accumulator declaration is not visible).
395 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
397 const SIMDType a1( A2.load(i ,k) );
398 const SIMDType a2( A2.load(i+1UL,k) );
400 const SIMDType b1( B2.load(k,j) );
406 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
407 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single leftover row: 1x2 tiles. Here the accumulation statements are visible.
415 for( ; (j+2UL) <= jblock; j+=2UL )
419 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
421 const SIMDType a1( A2.load(i,k) );
423 xmm1 += a1 * B2.load(k,j );
424 xmm2 += a1 * B2.load(k,j+1UL);
427 (~C)(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
428 (~C)(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// Single leftover row and single leftover column: 1x1 tile.
435 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
437 const SIMDType a1( A2.load(i,k) );
439 xmm1 += a1 * B2.load(k,j);
442 (~C)(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: runs only when remainder is set and the blocked loop stopped before K.
// It covers the last ksize = K - kk positions of the shared dimension, one element per step.
// A2 and B2 are read with operator(); the statements that refill them for this section are
// not visible in this excerpt.
452 if( remainder && kk < K )
454 const size_t ksize( K - kk );
// Row range for the tail: only the lower bound depends on the structure of A.
456 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
457 const size_t isize ( M - ibegin );
462 size_t jblock( 0UL );
466 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Structure test on B for the tail; only the upper case is checked here. The body is not
// visible (presumably it skips the block - confirm).
468 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
// The tail uses the same row tiling as the SIMD section (5, 4, 2, 1 rows), with plain loops
// over k that update C on every step.
477 if( IsFloatingPoint<ET1>::value )
479 for( ; (i+5UL) <= isize; i+=5UL )
// 5x2 scalar tile.
483 for( ; (j+2UL) <= jblock; j+=2UL ) {
484 for(
size_t k=0UL; k<ksize; ++k ) {
485 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
486 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
487 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
488 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
489 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
490 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
491 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
492 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
493 (~C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
494 (~C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
// 5x1 scalar tile.
499 for(
size_t k=0UL; k<ksize; ++k ) {
500 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
501 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
502 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
503 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
504 (~C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Four-row scalar tiles.
511 for( ; (i+4UL) <= isize; i+=4UL )
515 for( ; (j+2UL) <= jblock; j+=2UL ) {
516 for(
size_t k=0UL; k<ksize; ++k ) {
517 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
518 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
519 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
520 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
521 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
522 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
523 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
524 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
529 for(
size_t k=0UL; k<ksize; ++k ) {
530 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
531 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
532 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
533 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Two-row scalar tiles. Unlike the SIMD section, no four-column variant is visible here.
539 for( ; (i+2UL) <= isize; i+=2UL )
543 for( ; (j+2UL) <= jblock; j+=2UL ) {
544 for(
size_t k=0UL; k<ksize; ++k ) {
545 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
546 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
547 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
548 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
553 for(
size_t k=0UL; k<ksize; ++k ) {
554 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
555 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single leftover row in the tail: 1x2 tiles, then a final 1x1 tile.
564 for( ; (j+2UL) <= jblock; j+=2UL ) {
565 for(
size_t k=0UL; k<ksize; ++k ) {
566 (~C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
567 (~C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
572 for(
size_t k=0UL; k<ksize; ++k ) {
573 (~C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// NOTE(review): sparse excerpt. Original line numbers are embedded in the text, and braces,
// loop headers and several statements are missing. The comments describe only visible lines.
//
// mmm( C, A, B, alpha, beta ): overload for a target of type DenseMatrix<MT1,true>. It is a
// blocked kernel that adds alpha-scaled dot products of rows of A and columns of B into C.
// Here the panel of B is kept whole per k-block and A is processed in row blocks of IBLOCK.
//   C     - target matrix; every visible write has the form (~C)(row,column) += value * alpha.
//   A     - left operand; A.rows() is M and A.columns() is K.
//   B     - right operand; B.columns() is N.
//   alpha - factor applied to each accumulated sum before it is added to C.
//   beta  - only the test !isOne( beta ) is visible; the branch bodies are not shown.
605 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
606 void mmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands and the SIMD pack type derived from C's element type.
608 using ET1 = ElementType_<MT1>;
609 using ET2 = ElementType_<MT2>;
610 using ET3 = ElementType_<MT3>;
611 using SIMDType = SIMDTrait_<ET1>;
// remainder is true as soon as A or B is not padded; it enables the scalar tail section at
// original line 936. SIMDSIZE is used below but its definition is not visible here.
629 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Block sizes: KBLOCK along the shared dimension K, IBLOCK along the rows of A.
631 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
632 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands.
637 const size_t M( A.rows() );
638 const size_t N( B.columns() );
639 const size_t K( A.columns() );
// Scratch copies of the current panels: A2 holds up to IBLOCK x KBLOCK elements of A, B2 up
// to KBLOCK x N elements of B. The buffers use opposite storage-order flags (false/true).
643 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
644 DynamicMatrix<ET3,true> B2( KBLOCK, N );
// Second branch of the beta handling; the first branch and both bodies are not visible.
649 else if( !
isOne( beta ) ) {
654 size_t kblock( 0UL );
// Outer loop over K in steps of kblock. With remainder set, the loop stops once fewer than
// SIMDSIZE columns are left; those are handled by the tail section.
656 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a partial last block is masked with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
659 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a partial last block takes everything that is left. The condition selecting
// between the two variants is not visible (presumably the remainder flag - confirm).
662 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this k-block: starts at kk when IsUpper<MT3> holds and ends at
// kk+kblock when IsLower<MT3> holds; otherwise all N columns are used.
665 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
666 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
667 const size_t jsize ( jend - jbegin );
// Copy the kblock x jsize panel of B into B2.
669 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
672 size_t iblock( 0UL );
// Height of the current row block, clipped at M. The loop over ii is not visible here.
676 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Structure test on A: true when the iblock x kblock panel lies entirely above row kk for a
// lower A, or starts at or below row kk+kblock for an upper A. The body is not visible
// (presumably it skips the block - confirm).
678 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
679 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
// Copy the iblock x kblock panel of A into A2.
684 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// NOTE(review): this test uses ET3, whereas the tail section at original line 961 tests
// IsFloatingPoint<ET1>, as does the DenseMatrix<MT1,false> overload of mmm(). Confirm which
// element type is intended.
688 if( IsFloatingPoint<ET3>::value )
// Tiles of five columns of B2.
690 for( ; (j+5UL) <= jsize; j+=5UL )
// 2x5 tile: two rows of A2 against five columns of B2, one SIMD accumulator per C element.
// The accumulation statements (original lines 709-719) are not part of this excerpt.
694 for( ; (i+2UL) <= iblock; i+=2UL )
696 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
698 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
700 const SIMDType a1( A2.load(i ,k) );
701 const SIMDType a2( A2.load(i+1UL,k) );
703 const SIMDType b1( B2.load(k,j ) );
704 const SIMDType b2( B2.load(k,j+1UL) );
705 const SIMDType b3( B2.load(k,j+2UL) );
706 const SIMDType b4( B2.load(k,j+3UL) );
707 const SIMDType b5( B2.load(k,j+4UL) );
// Reduce each accumulator with sum(), scale by alpha and add to the matching C element.
// Tile-local indices are shifted by ii (rows) and jbegin (columns).
721 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
722 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
723 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
724 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
725 (~C)(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
726 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
727 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
728 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
729 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
730 (~C)(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// 1x5 tile: a single leftover row of A2 against the same five columns.
735 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
737 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
739 const SIMDType a1( A2.load(i,k) );
741 xmm1 += a1 * B2.load(k,j );
742 xmm2 += a1 * B2.load(k,j+1UL);
743 xmm3 += a1 * B2.load(k,j+2UL);
744 xmm4 += a1 * B2.load(k,j+3UL);
745 xmm5 += a1 * B2.load(k,j+4UL);
748 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
749 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
750 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
751 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
752 (~C)(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Tiles of four columns of B2.
758 for( ; (j+4UL) <= jsize; j+=4UL )
// 2x4 tile.
762 for( ; (i+2UL) <= iblock; i+=2UL )
764 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
766 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
768 const SIMDType a1( A2.load(i ,k) );
769 const SIMDType a2( A2.load(i+1UL,k) );
771 const SIMDType b1( B2.load(k,j ) );
772 const SIMDType b2( B2.load(k,j+1UL) );
773 const SIMDType b3( B2.load(k,j+2UL) );
774 const SIMDType b4( B2.load(k,j+3UL) );
786 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
787 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
788 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
789 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
790 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
791 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
792 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
793 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// 1x4 tile for a single leftover row.
798 SIMDType xmm1, xmm2, xmm3, xmm4;
800 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
802 const SIMDType a1( A2.load(i,k) );
804 xmm1 += a1 * B2.load(k,j );
805 xmm2 += a1 * B2.load(k,j+1UL);
806 xmm3 += a1 * B2.load(k,j+2UL);
807 xmm4 += a1 * B2.load(k,j+3UL);
810 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
811 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
812 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
813 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Tiles of two columns of B2 pick up the columns left over by the wider tiles.
818 for( ; (j+2UL) <= jsize; j+=2UL )
// 4x2 tile: four rows of A2 against two columns of B2.
822 for( ; (i+4UL) <= iblock; i+=4UL )
824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
826 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
828 const SIMDType a1( A2.load(i ,k) );
829 const SIMDType a2( A2.load(i+1UL,k) );
830 const SIMDType a3( A2.load(i+2UL,k) );
831 const SIMDType a4( A2.load(i+3UL,k) );
833 const SIMDType b1( B2.load(k,j ) );
834 const SIMDType b2( B2.load(k,j+1UL) );
846 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
847 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
848 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
849 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
850 (~C)(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
851 (~C)(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
852 (~C)(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
853 (~C)(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile.
856 for( ; (i+2UL) <= iblock; i+=2UL )
858 SIMDType xmm1, xmm2, xmm3, xmm4;
860 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
862 const SIMDType a1( A2.load(i ,k) );
863 const SIMDType a2( A2.load(i+1UL,k) );
865 const SIMDType b1( B2.load(k,j ) );
866 const SIMDType b2( B2.load(k,j+1UL) );
874 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
875 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
876 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
877 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// 1x2 tile for a single leftover row (the accumulator declaration is not visible).
884 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
886 const SIMDType a1( A2.load(i,k) );
888 xmm1 += a1 * B2.load(k,j );
889 xmm2 += a1 * B2.load(k,j+1UL);
892 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
893 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single leftover column: 2x1 tiles.
901 for( ; (i+2UL) <= iblock; i+=2UL )
905 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
907 const SIMDType b1( B2.load(k,j) );
909 xmm1 += A2.load(i ,k) * b1;
910 xmm2 += A2.load(i+1UL,k) * b1;
913 (~C)(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
914 (~C)(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// Single leftover column and single leftover row: 1x1 tile.
921 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
923 xmm1 += A2.load(i,k) * B2.load(k,j);
926 (~C)(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: runs only when remainder is set and the blocked loop stopped before K.
// It covers the last ksize = K - kk positions of the shared dimension, one element per step.
// A2 and B2 are read with operator(); the statements that refill them for this section are
// not visible in this excerpt.
936 if( remainder && kk < K )
938 const size_t ksize( K - kk );
// Column range for the tail: only the lower bound depends on the structure of B.
940 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
941 const size_t jsize ( N - jbegin );
946 size_t iblock( 0UL );
950 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Structure test on A for the tail; only the lower case is checked here. The body is not
// visible (presumably it skips the block - confirm).
952 if( IsLower<MT2>::value && ii+iblock <= kk ) {
// The tail uses the same column tiling as the SIMD section (5, 4, 2, 1 columns), with plain
// loops over k that update C on every step.
961 if( IsFloatingPoint<ET1>::value )
963 for( ; (j+5UL) <= jsize; j+=5UL )
// 2x5 scalar tile.
967 for( ; (i+2UL) <= iblock; i+=2UL ) {
968 for(
size_t k=0UL; k<ksize; ++k ) {
969 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
970 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
971 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
972 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
973 (~C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
974 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
975 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
976 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
977 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
978 (~C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
// 1x5 scalar tile.
983 for(
size_t k=0UL; k<ksize; ++k ) {
984 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
985 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
986 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
987 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
988 (~C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Four-column scalar tiles.
995 for( ; (j+4UL) <= jsize; j+=4UL )
999 for( ; (i+2UL) <= iblock; i+=2UL ) {
1000 for(
size_t k=0UL; k<ksize; ++k ) {
1001 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1002 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1003 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1004 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1005 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1006 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1007 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1008 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
1013 for(
size_t k=0UL; k<ksize; ++k ) {
1014 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1015 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1016 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1017 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Two-column scalar tiles. Unlike the SIMD section, no four-row variant is visible here.
1023 for( ; (j+2UL) <= jsize; j+=2UL )
1027 for( ; (i+2UL) <= iblock; i+=2UL ) {
1028 for(
size_t k=0UL; k<ksize; ++k ) {
1029 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1030 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1031 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1032 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1037 for(
size_t k=0UL; k<ksize; ++k ) {
1038 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1039 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single leftover column in the tail: 2x1 tiles, then a final 1x1 tile.
1048 for( ; (i+2UL) <= iblock; i+=2UL ) {
1049 for(
size_t k=0UL; k<ksize; ++k ) {
1050 (~C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1051 (~C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1056 for(
size_t k=0UL; k<ksize; ++k ) {
1057 (~C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of mmm() without scaling factors.
// It forwards to the five-argument mmm() with alpha = ET1(1) and beta = ET1(0), where ET1 is
// the element type of the target C.
//   C - target matrix, passed through unchanged.
//   A - left operand, passed through unchanged.
//   B - right operand, passed through unchanged.
// NOTE(review): sparse excerpt; original lines 1092-1095 and the braces are not visible.
1086 template<
typename MT1,
typename MT2,
typename MT3 >
1087 inline void mmm( MT1& C,
const MT2& A,
const MT3& B )
1089 using ET1 = ElementType_<MT1>;
// ET2 and ET3 are not used by any line visible in this excerpt (presumably consumed by
// statements on the missing lines - confirm).
1090 using ET2 = ElementType_<MT2>;
1091 using ET3 = ElementType_<MT3>;
// Delegate to the scaled kernel with alpha = 1 and beta = 0.
1096 mmm( C, A, B, ET1(1), ET1(0) );
// NOTE(review): sparse excerpt. Original line numbers are embedded in the text, and braces,
// loop headers and several statements are missing. The comments describe only visible lines.
//
// lmmm( C, A, B, alpha, beta ): overload for a target of type DenseMatrix<MT1,false>. It is a
// blocked kernel with the same tiling as mmm(), but each row tile is limited to columns up to
// the tile's last row index, so elements to the right of that bound are never written.
//   C     - target matrix; written through the derestricted view c as c(row,column) += ...
//   A     - left operand; A.rows() is M and A.columns() is K.
//   B     - right operand; B.columns() is N.
//   alpha - factor applied to each accumulated sum before it is added to C.
//   beta  - only the test !isOne( beta ) is visible; the branch bodies are not shown.
1129 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1130 void lmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands and the SIMD pack type derived from C's element type.
1132 using ET1 = ElementType_<MT1>;
1133 using ET2 = ElementType_<MT2>;
1134 using ET3 = ElementType_<MT3>;
1135 using SIMDType = SIMDTrait_<ET1>;
// remainder is true as soon as A or B is not padded; it enables the scalar tail section at
// original line 1492. SIMDSIZE is used below but its definition is not visible here.
1157 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Block sizes: KBLOCK along the shared dimension K, JBLOCK along the columns of B.
1159 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1160 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands.
1165 const size_t M( A.rows() );
1166 const size_t N( B.columns() );
1167 const size_t K( A.columns() );
// Scratch copies of the current panels: A2 holds up to M x KBLOCK elements of A, B2 up to
// KBLOCK x JBLOCK elements of B. The buffers use opposite storage-order flags (false/true).
1171 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1172 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// All element updates below go through c, the result of derestrict( ~C ).
1174 DerestrictTrait_<MT1> c( derestrict( ~C ) );
// Second branch of the beta handling; the first branch and both bodies are not visible.
1179 else if( !
isOne( beta ) ) {
1184 size_t kblock( 0UL );
// Outer loop over K in steps of kblock. With remainder set, the loop stops once fewer than
// SIMDSIZE columns are left; those are handled by the tail section.
1186 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a partial last block is masked with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
1189 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a partial last block takes everything that is left. The condition selecting
// between the two variants is not visible (presumably the remainder flag - confirm).
1192 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this k-block: starts at kk when IsLower<MT2> holds and ends at
// kk+kblock when IsUpper<MT2> holds; otherwise all M rows are used.
1195 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
1196 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
1197 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2.
1199 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
1202 size_t jblock( 0UL );
// Width of the current column block, clipped at N. The loop over jj is not visible here.
1206 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Structure test on B; the body is not visible (presumably it skips the block - confirm).
1208 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
1209 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
1214 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
// The five-row tiles follow this floating point test; the matching else-branch is not
// visible in this excerpt.
1218 if( IsFloatingPoint<ET1>::value )
// Tiles of five rows of A2.
1220 for( ; (i+5UL) <= isize; i+=5UL )
// Skip the tile when the column block starts to the right of the tile's last row index.
1222 if( jj > ibegin+i+4UL )
continue;
// Limit the columns so that the last global column handled is at most ibegin+i+4, the last
// row of this tile.
1224 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// 5x2 tile; the accumulation statements (original lines 1242-1252) are not visible.
1227 for( ; (j+2UL) <= jend; j+=2UL )
1229 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1231 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1233 const SIMDType a1( A2.load(i ,k) );
1234 const SIMDType a2( A2.load(i+1UL,k) );
1235 const SIMDType a3( A2.load(i+2UL,k) );
1236 const SIMDType a4( A2.load(i+3UL,k) );
1237 const SIMDType a5( A2.load(i+4UL,k) );
1239 const SIMDType b1( B2.load(k,j ) );
1240 const SIMDType b2( B2.load(k,j+1UL) );
// Reduce each accumulator with sum(), scale by alpha and add to the matching element of c.
1254 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1255 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1256 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1257 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1258 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1259 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1260 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1261 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
1262 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
1263 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5x1 tile for a single leftover column.
1268 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1270 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1272 const SIMDType a1( A2.load(i ,k) );
1273 const SIMDType a2( A2.load(i+1UL,k) );
1274 const SIMDType a3( A2.load(i+2UL,k) );
1275 const SIMDType a4( A2.load(i+3UL,k) );
1276 const SIMDType a5( A2.load(i+4UL,k) );
1278 const SIMDType b1( B2.load(k,j) );
1287 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1288 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1289 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1290 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
1291 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Tiles of four rows of A2, with the same skip test and column bound adapted to four rows.
1297 for( ; (i+4UL) <= isize; i+=4UL )
1299 if( jj > ibegin+i+3UL )
continue;
1301 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
// 4x2 tile.
1304 for( ; (j+2UL) <= jend; j+=2UL )
1306 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1308 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1310 const SIMDType a1( A2.load(i ,k) );
1311 const SIMDType a2( A2.load(i+1UL,k) );
1312 const SIMDType a3( A2.load(i+2UL,k) );
1313 const SIMDType a4( A2.load(i+3UL,k) );
1315 const SIMDType b1( B2.load(k,j ) );
1316 const SIMDType b2( B2.load(k,j+1UL) );
1328 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1329 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1330 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1331 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1332 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1333 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1334 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1335 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4x1 tile for a single leftover column.
1340 SIMDType xmm1, xmm2, xmm3, xmm4;
1342 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1344 const SIMDType a1( A2.load(i ,k) );
1345 const SIMDType a2( A2.load(i+1UL,k) );
1346 const SIMDType a3( A2.load(i+2UL,k) );
1347 const SIMDType a4( A2.load(i+3UL,k) );
1349 const SIMDType b1( B2.load(k,j) );
1357 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1358 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1359 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1360 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Tiles of two rows of A2, with the skip test and column bound adapted to two rows.
1365 for( ; (i+2UL) <= isize; i+=2UL )
1367 if( jj > ibegin+i+1UL )
continue;
1369 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// 2x4 tile.
1372 for( ; (j+4UL) <= jend; j+=4UL )
1374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1376 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1378 const SIMDType a1( A2.load(i ,k) );
1379 const SIMDType a2( A2.load(i+1UL,k) );
1381 const SIMDType b1( B2.load(k,j ) );
1382 const SIMDType b2( B2.load(k,j+1UL) );
1383 const SIMDType b3( B2.load(k,j+2UL) );
1384 const SIMDType b4( B2.load(k,j+3UL) );
1396 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1397 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1398 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
1399 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
1400 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1402 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
1403 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile.
1406 for( ; (j+2UL) <= jend; j+=2UL )
1408 SIMDType xmm1, xmm2, xmm3, xmm4;
1410 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1412 const SIMDType a1( A2.load(i ,k) );
1413 const SIMDType a2( A2.load(i+1UL,k) );
1415 const SIMDType b1( B2.load(k,j ) );
1416 const SIMDType b2( B2.load(k,j+1UL) );
1424 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1425 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1426 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1427 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// 2x1 tile for a single leftover column.
1432 SIMDType xmm1, xmm2;
1434 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1436 const SIMDType a1( A2.load(i ,k) );
1437 const SIMDType a2( A2.load(i+1UL,k) );
1439 const SIMDType b1( B2.load(k,j) );
1445 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1446 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single leftover row, handled only when the column block starts at or before that row.
1450 if( i<isize && jj <= ibegin+i )
// NOTE(review): the bound uses +2UL, the same as the two-row tiles above, so the last column
// allowed is ibegin+i+1, one past this row's own index. Confirm that +1UL was not intended.
1452 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// 1x2 tile.
1455 for( ; (j+2UL) <= jend; j+=2UL )
1457 SIMDType xmm1, xmm2;
1459 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1461 const SIMDType a1( A2.load(i,k) );
1463 xmm1 += a1 * B2.load(k,j );
1464 xmm2 += a1 * B2.load(k,j+1UL);
1467 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
1468 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// 1x1 tile (the accumulator declaration is not visible).
1475 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1477 const SIMDType a1( A2.load(i,k) );
1479 xmm1 += a1 * B2.load(k,j);
1482 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: runs only when remainder is set and the blocked loop stopped before K.
// It covers the last ksize = K - kk positions of the shared dimension, one element per step.
// A2 and B2 are read with operator(); the statements that refill them for this section are
// not visible in this excerpt.
1492 if( remainder && kk < K )
1494 const size_t ksize( K - kk );
// Row range for the tail: only the lower bound depends on the structure of A.
1496 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
1497 const size_t isize ( M - ibegin );
1502 size_t jblock( 0UL );
1506 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Structure test on B for the tail; only the upper case is checked here. The body is not
// visible (presumably it skips the block - confirm).
1508 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
// The tail repeats the row tiling, skip tests and column bounds of the SIMD section, with
// plain loops over k that update c on every step.
1517 if( IsFloatingPoint<ET1>::value )
1519 for( ; (i+5UL) <= isize; i+=5UL )
1521 if( jj > ibegin+i+4UL )
continue;
1523 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// 5x2 scalar tile.
1526 for( ; (j+2UL) <= jend; j+=2UL ) {
1527 for(
size_t k=0UL; k<ksize; ++k ) {
1528 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1529 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1530 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1531 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1532 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1533 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1534 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1535 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1536 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1537 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
// 5x1 scalar tile.
1542 for(
size_t k=0UL; k<ksize; ++k ) {
1543 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1544 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1546 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1547 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Four-row scalar tiles.
1554 for( ; (i+4UL) <= isize; i+=4UL )
1556 if( jj > ibegin+i+3UL )
continue;
1558 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
1561 for( ; (j+2UL) <= jend; j+=2UL ) {
1562 for(
size_t k=0UL; k<ksize; ++k ) {
1563 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1564 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1565 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1566 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1567 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1568 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1569 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1570 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1575 for(
size_t k=0UL; k<ksize; ++k ) {
1576 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1577 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1578 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1579 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Two-row scalar tiles.
1585 for( ; (i+2UL) <= isize; i+=2UL )
1587 if( jj > ibegin+i+1UL )
continue;
1589 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1592 for( ; (j+2UL) <= jend; j+=2UL ) {
1593 for(
size_t k=0UL; k<ksize; ++k ) {
1594 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1595 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1596 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1597 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1602 for(
size_t k=0UL; k<ksize; ++k ) {
1603 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1604 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single leftover row in the tail, handled only when the column block starts at or before it.
1609 if( i<isize && jj <= ibegin+i )
// NOTE(review): same +2UL bound as in the SIMD single-row case; confirm against +1UL.
1611 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1614 for( ; (j+2UL) <= jend; j+=2UL ) {
1615 for(
size_t k=0UL; k<ksize; ++k ) {
1616 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1617 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1622 for(
size_t k=0UL; k<ksize; ++k ) {
1623 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel `lmmm` for a target passed as `DenseMatrix<MT1,true>`. It adds
// alpha-scaled products of A and B into `c`, the derestricted view of C.
// A while loop over kk processes panels of at most KBLOCK inner-dimension entries:
// the B panel is copied into the scratch matrix B2, each strip of at most IBLOCK
// rows of A is copied into A2, and the products are accumulated from those copies.
// For every group of columns the row index starts at max( 0, jbegin+j-ii ), i.e. at
// the row whose global index equals the first column of the group; rows above it
// are not written.
// NOTE(review): the statements that apply `beta` are not visible in this excerpt;
// presumably C is reset/scaled before the accumulation -- confirm in the full file.
// NOTE(review): in Blaze the storage-order flag `true` presumably denotes a
// column-major target -- confirm against the DenseMatrix documentation.
1655 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1656 void lmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
1658 using ET1 = ElementType_<MT1>;
1659 using ET2 = ElementType_<MT2>;
1660 using ET3 = ElementType_<MT3>;
1661 using SIMDType = SIMDTrait_<ET1>;
// `remainder` is true as soon as A or B is not padded; the section guarded by
// `remainder && kk < K` further down then handles the trailing K - kk entries.
1683 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
1685 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1686 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1691 const size_t M( A.rows() );
1692 const size_t N( B.columns() );
1693 const size_t K( A.columns() );
// Scratch copies of the current A strip and B panel.
1697 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1698 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1700 DerestrictTrait_<MT1> c( derestrict( ~C ) );
1705 else if( !
isOne( beta ) ) {
1710 size_t kblock( 0UL );
1712 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative kblock assignments follow (the selecting condition is not visible
// here). The first masks K - kk with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
1715 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
1718 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this panel, narrowed when MT3 is upper or lower.
1721 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
1722 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
1723 const size_t jsize ( jend - jbegin );
1725 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
1728 size_t iblock( 0UL );
1732 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Detects A strips that lie entirely above (lower MT2) or entirely below (upper MT2)
// the current panel; the body of this branch is not visible in this excerpt.
1734 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
1735 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
1740 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// SIMD section: columns are consumed in groups of 5, then 4, then 2, then a single
// trailing column; inside each group rows are taken in pairs plus a single row.
// NOTE(review): this test uses ET3 while the remainder section below tests ET1 --
// confirm whether the difference is intended.
1744 if( IsFloatingPoint<ET3>::value )
1746 for( ; (j+5UL) <= jsize; j+=5UL )
1748 if( ii+iblock < jbegin )
continue;
1750 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1752 for( ; (i+2UL) <= iblock; i+=2UL )
1754 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1756 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1758 const SIMDType a1( A2.load(i ,k) );
1759 const SIMDType a2( A2.load(i+1UL,k) );
1761 const SIMDType b1( B2.load(k,j ) );
1762 const SIMDType b2( B2.load(k,j+1UL) );
1763 const SIMDType b3( B2.load(k,j+2UL) );
1764 const SIMDType b4( B2.load(k,j+3UL) );
1765 const SIMDType b5( B2.load(k,j+4UL) );
1779 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1780 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1781 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1782 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1783 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
1784 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
1787 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
1788 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
1793 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1795 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1797 const SIMDType a1( A2.load(i,k) );
1799 xmm1 += a1 * B2.load(k,j );
1800 xmm2 += a1 * B2.load(k,j+1UL);
1801 xmm3 += a1 * B2.load(k,j+2UL);
1802 xmm4 += a1 * B2.load(k,j+3UL);
1803 xmm5 += a1 * B2.load(k,j+4UL);
1806 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1807 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1808 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1809 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1810 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
1816 for( ; (j+4UL) <= jsize; j+=4UL )
1818 if( ii+iblock < jbegin )
continue;
1820 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1822 for( ; (i+2UL) <= iblock; i+=2UL )
1824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1826 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1828 const SIMDType a1( A2.load(i ,k) );
1829 const SIMDType a2( A2.load(i+1UL,k) );
1831 const SIMDType b1( B2.load(k,j ) );
1832 const SIMDType b2( B2.load(k,j+1UL) );
1833 const SIMDType b3( B2.load(k,j+2UL) );
1834 const SIMDType b4( B2.load(k,j+3UL) );
1846 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1847 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1848 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1849 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1850 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1852 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
1853 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
1858 SIMDType xmm1, xmm2, xmm3, xmm4;
1860 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1862 const SIMDType a1( A2.load(i,k) );
1864 xmm1 += a1 * B2.load(k,j );
1865 xmm2 += a1 * B2.load(k,j+1UL);
1866 xmm3 += a1 * B2.load(k,j+2UL);
1867 xmm4 += a1 * B2.load(k,j+3UL);
1870 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1871 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1872 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1873 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1878 for( ; (j+2UL) <= jsize; j+=2UL )
1880 if( ii+iblock < jbegin )
continue;
1882 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1884 for( ; (i+4UL) <= iblock; i+=4UL )
1886 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1888 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1890 const SIMDType a1( A2.load(i ,k) );
1891 const SIMDType a2( A2.load(i+1UL,k) );
1892 const SIMDType a3( A2.load(i+2UL,k) );
1893 const SIMDType a4( A2.load(i+3UL,k) );
1895 const SIMDType b1( B2.load(k,j ) );
1896 const SIMDType b2( B2.load(k,j+1UL) );
1908 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1909 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1910 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1911 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
1912 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1913 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1914 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
1915 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
1918 for( ; (i+2UL) <= iblock; i+=2UL )
1920 SIMDType xmm1, xmm2, xmm3, xmm4;
1922 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1924 const SIMDType a1( A2.load(i ,k) );
1925 const SIMDType a2( A2.load(i+1UL,k) );
1927 const SIMDType b1( B2.load(k,j ) );
1928 const SIMDType b2( B2.load(k,j+1UL) );
1936 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1937 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1938 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1939 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
1944 SIMDType xmm1, xmm2;
1946 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1948 const SIMDType a1( A2.load(i,k) );
1950 xmm1 += a1 * B2.load(k,j );
1951 xmm2 += a1 * B2.load(k,j+1UL);
1954 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1955 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single trailing column of the SIMD section.
1959 if( j<jsize && ii+iblock >= jbegin )
1961 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1963 for( ; (i+2UL) <= iblock; i+=2UL )
1965 SIMDType xmm1, xmm2;
1967 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1969 const SIMDType b1( B2.load(k,j) );
1971 xmm1 += A2.load(i ,k) * b1;
1972 xmm2 += A2.load(i+1UL,k) * b1;
1975 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
1976 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
1983 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1985 xmm1 += A2.load(i,k) * B2.load(k,j);
1988 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the ksize = K - kk inner-dimension entries left over by the
// panel loop above; it uses element access A2(i,k) / B2(k,j) instead of SIMD loads.
1998 if( remainder && kk < K )
2000 const size_t ksize( K - kk );
2002 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
2003 const size_t jsize ( N - jbegin );
2008 size_t iblock( 0UL );
2012 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2014 if( IsLower<MT2>::value && ii+iblock <= kk ) {
2023 if( IsFloatingPoint<ET1>::value )
2025 for( ; (j+5UL) <= jsize; j+=5UL )
2027 if( ii+iblock < jbegin )
continue;
2029 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2031 for( ; (i+2UL) <= iblock; i+=2UL ) {
2032 for(
size_t k=0UL; k<ksize; ++k ) {
2033 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2034 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2035 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2036 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2037 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2038 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2039 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2041 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2042 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2047 for(
size_t k=0UL; k<ksize; ++k ) {
2048 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2049 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2050 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2051 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2052 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
2059 for( ; (j+4UL) <= jsize; j+=4UL )
2061 if( ii+iblock < jbegin )
continue;
2063 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2065 for( ; (i+2UL) <= iblock; i+=2UL ) {
2066 for(
size_t k=0UL; k<ksize; ++k ) {
2067 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2068 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2069 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2070 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2071 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2072 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2073 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2074 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2079 for(
size_t k=0UL; k<ksize; ++k ) {
2080 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2081 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2082 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2083 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2089 for( ; (j+2UL) <= jsize; j+=2UL )
2091 if( ii+iblock < jbegin )
continue;
2093 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2095 for( ; (i+2UL) <= iblock; i+=2UL ) {
2096 for(
size_t k=0UL; k<ksize; ++k ) {
2097 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2098 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2099 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2100 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2105 for(
size_t k=0UL; k<ksize; ++k ) {
2106 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2107 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2114 if( ii+iblock < jbegin )
continue;
2116 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2118 for( ; (i+2UL) <= iblock; i+=2UL ) {
2119 for(
size_t k=0UL; k<ksize; ++k ) {
2120 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2121 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2126 for(
size_t k=0UL; k<ksize; ++k ) {
2127 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of lmmm without explicit scaling factors: it forwards to the
// five-argument lmmm with alpha = ET1(1) and beta = ET1(0), where ET1 is the element
// type of the target matrix type MT1.
2156 template<
typename MT1,
typename MT2,
typename MT3 >
2157 inline void lmmm( MT1& C,
const MT2& A,
const MT3& B )
2159 using ET1 = ElementType_<MT1>;
2160 using ET2 = ElementType_<MT2>;
2161 using ET3 = ElementType_<MT3>;
2166 lmmm( C, A, B, ET1(1), ET1(0) );
// Blocked kernel `ummm` for a target passed as `DenseMatrix<MT1,false>`. It adds
// alpha-scaled products of A and B into `c`, the derestricted view of C.
// A while loop over kk processes panels of at most KBLOCK inner-dimension entries:
// the A panel is copied into the scratch matrix A2, each strip of at most JBLOCK
// columns of B is copied into B2, and the products are accumulated from those copies.
// For every group of rows the column index starts at max( 0, ibegin+i-jj ), i.e. at
// the column whose global index equals the first row of the group; columns to the
// left of it are not written.
// NOTE(review): the statements that apply `beta` are not visible in this excerpt;
// presumably C is reset/scaled before the accumulation -- confirm in the full file.
// NOTE(review): in Blaze the storage-order flag `false` presumably denotes a
// row-major target -- confirm against the DenseMatrix documentation.
2199 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2200 void ummm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2202 using ET1 = ElementType_<MT1>;
2203 using ET2 = ElementType_<MT2>;
2204 using ET3 = ElementType_<MT3>;
2205 using SIMDType = SIMDTrait_<ET1>;
// `remainder` is true as soon as A or B is not padded; the section guarded by
// `remainder && kk < K` further down then handles the trailing K - kk entries.
2227 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
2229 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2230 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
2235 const size_t M( A.rows() );
2236 const size_t N( B.columns() );
2237 const size_t K( A.columns() );
// Scratch copies of the current A panel and B strip.
2241 DynamicMatrix<ET2,false> A2( M, KBLOCK );
2242 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
2244 DerestrictTrait_<MT1> c( derestrict( ~C ) );
2249 else if( !
isOne( beta ) ) {
2254 size_t kblock( 0UL );
2256 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative kblock assignments follow (the selecting condition is not visible
// here). The first masks K - kk with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
2259 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
2262 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this panel, narrowed when MT2 is lower or upper.
2265 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
2266 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
2267 const size_t isize ( iend - ibegin );
2269 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
2272 size_t jblock( 0UL );
2276 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Detects B strips that lie entirely right of (lower MT3) or entirely left of
// (upper MT3) the current panel; the body of this branch is not visible here.
2278 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
2279 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
2284 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
// SIMD section: rows are consumed in groups of 5, then 4, then 2, then a single
// trailing row; inside each group columns are taken in pairs plus a single column.
2288 if( IsFloatingPoint<ET1>::value )
2290 for( ; (i+5UL) <= isize; i+=5UL )
2292 if( jj+jblock < ibegin )
continue;
2294 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2296 for( ; (j+2UL) <= jblock; j+=2UL )
2298 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2300 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2302 const SIMDType a1( A2.load(i ,k) );
2303 const SIMDType a2( A2.load(i+1UL,k) );
2304 const SIMDType a3( A2.load(i+2UL,k) );
2305 const SIMDType a4( A2.load(i+3UL,k) );
2306 const SIMDType a5( A2.load(i+4UL,k) );
2308 const SIMDType b1( B2.load(k,j ) );
2309 const SIMDType b2( B2.load(k,j+1UL) );
2323 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2324 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2325 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2326 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2327 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2328 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2329 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2330 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
2331 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
2332 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
2337 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2339 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2341 const SIMDType a1( A2.load(i ,k) );
2342 const SIMDType a2( A2.load(i+1UL,k) );
2343 const SIMDType a3( A2.load(i+2UL,k) );
2344 const SIMDType a4( A2.load(i+3UL,k) );
2345 const SIMDType a5( A2.load(i+4UL,k) );
2347 const SIMDType b1( B2.load(k,j) );
2356 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2357 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2358 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2359 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
2360 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
2366 for( ; (i+4UL) <= isize; i+=4UL )
2368 if( jj+jblock < ibegin )
continue;
2370 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2372 for( ; (j+2UL) <= jblock; j+=2UL )
2374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2376 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2378 const SIMDType a1( A2.load(i ,k) );
2379 const SIMDType a2( A2.load(i+1UL,k) );
2380 const SIMDType a3( A2.load(i+2UL,k) );
2381 const SIMDType a4( A2.load(i+3UL,k) );
2383 const SIMDType b1( B2.load(k,j ) );
2384 const SIMDType b2( B2.load(k,j+1UL) );
2396 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2397 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2398 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2399 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2400 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2401 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2402 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2403 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
2408 SIMDType xmm1, xmm2, xmm3, xmm4;
2410 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2412 const SIMDType a1( A2.load(i ,k) );
2413 const SIMDType a2( A2.load(i+1UL,k) );
2414 const SIMDType a3( A2.load(i+2UL,k) );
2415 const SIMDType a4( A2.load(i+3UL,k) );
2417 const SIMDType b1( B2.load(k,j) );
2425 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2426 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2427 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2428 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
2433 for( ; (i+2UL) <= isize; i+=2UL )
2435 if( jj+jblock < ibegin )
continue;
2437 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2439 for( ; (j+4UL) <= jblock; j+=4UL )
2441 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2443 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2445 const SIMDType a1( A2.load(i ,k) );
2446 const SIMDType a2( A2.load(i+1UL,k) );
2448 const SIMDType b1( B2.load(k,j ) );
2449 const SIMDType b2( B2.load(k,j+1UL) );
2450 const SIMDType b3( B2.load(k,j+2UL) );
2451 const SIMDType b4( B2.load(k,j+3UL) );
2463 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2464 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2465 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
2466 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
2467 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
2468 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2469 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
2470 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
2473 for( ; (j+2UL) <= jblock; j+=2UL )
2475 SIMDType xmm1, xmm2, xmm3, xmm4;
2477 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2479 const SIMDType a1( A2.load(i ,k) );
2480 const SIMDType a2( A2.load(i+1UL,k) );
2482 const SIMDType b1( B2.load(k,j ) );
2483 const SIMDType b2( B2.load(k,j+1UL) );
2491 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2492 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2493 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2494 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2499 SIMDType xmm1, xmm2;
2501 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2503 const SIMDType a1( A2.load(i ,k) );
2504 const SIMDType a2( A2.load(i+1UL,k) );
2506 const SIMDType b1( B2.load(k,j) );
2512 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2513 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single trailing row of the SIMD section.
2517 if( i<isize && jj+jblock >= ibegin )
2519 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2521 for( ; (j+2UL) <= jblock; j+=2UL )
2523 SIMDType xmm1, xmm2;
2525 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2527 const SIMDType a1( A2.load(i,k) );
2529 xmm1 += a1 * B2.load(k,j );
2530 xmm2 += a1 * B2.load(k,j+1UL);
2533 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
2534 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2541 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2543 const SIMDType a1( A2.load(i,k) );
2545 xmm1 += a1 * B2.load(k,j);
2548 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the ksize = K - kk inner-dimension entries left over by the
// panel loop above; it uses element access A2(i,k) / B2(k,j) instead of SIMD loads.
2558 if( remainder && kk < K )
2560 const size_t ksize( K - kk );
2562 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
2563 const size_t isize ( M - ibegin );
2568 size_t jblock( 0UL );
2572 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
2574 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
2583 if( IsFloatingPoint<ET1>::value )
2585 for( ; (i+5UL) <= isize; i+=5UL )
2587 if( jj+jblock < ibegin )
continue;
2589 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2591 for( ; (j+2UL) <= jblock; j+=2UL ) {
2592 for(
size_t k=0UL; k<ksize; ++k ) {
2593 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2594 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2595 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2596 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2597 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2598 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2599 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2600 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2601 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
2602 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
2607 for(
size_t k=0UL; k<ksize; ++k ) {
2608 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2609 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2610 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2611 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2612 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
2619 for( ; (i+4UL) <= isize; i+=4UL )
2621 if( jj+jblock < ibegin )
continue;
2623 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2625 for( ; (j+2UL) <= jblock; j+=2UL ) {
2626 for(
size_t k=0UL; k<ksize; ++k ) {
2627 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2628 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2629 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2630 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2631 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2632 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2633 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2634 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2639 for(
size_t k=0UL; k<ksize; ++k ) {
2640 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2641 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2642 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2643 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2649 for( ; (i+2UL) <= isize; i+=2UL )
2651 if( jj+jblock < ibegin )
continue;
2653 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2655 for( ; (j+2UL) <= jblock; j+=2UL ) {
2656 for(
size_t k=0UL; k<ksize; ++k ) {
2657 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2658 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2659 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2660 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2665 for(
size_t k=0UL; k<ksize; ++k ) {
2666 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2667 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2672 if( i<isize && jj+jblock >= ibegin )
2674 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2676 for( ; (j+2UL) <= jblock; j+=2UL ) {
2677 for(
size_t k=0UL; k<ksize; ++k ) {
2678 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
2679 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2684 for(
size_t k=0UL; k<ksize; ++k ) {
2685 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel `ummm` for a target passed as `DenseMatrix<MT1,true>`. It adds
// alpha-scaled products of A and B into `c`, the derestricted view of C.
// A while loop over kk processes panels of at most KBLOCK inner-dimension entries:
// the B panel is copied into the scratch matrix B2, each strip of at most IBLOCK
// rows of A is copied into A2, and the products are accumulated from those copies.
// For a group of g columns starting at global column jbegin+j, a strip is skipped
// when ii > jbegin+j+g-1, and rows are limited to iend = min( iblock, jbegin+j-ii+g ),
// i.e. to rows whose global index does not exceed the last column of the group.
// NOTE(review): the statements that apply `beta` are not visible in this excerpt;
// presumably C is reset/scaled before the accumulation -- confirm in the full file.
// NOTE(review): in Blaze the storage-order flag `true` presumably denotes a
// column-major target -- confirm against the DenseMatrix documentation.
2717 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2718 void ummm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2720 using ET1 = ElementType_<MT1>;
2721 using ET2 = ElementType_<MT2>;
2722 using ET3 = ElementType_<MT3>;
2723 using SIMDType = SIMDTrait_<ET1>;
// `remainder` is true as soon as A or B is not padded; the section guarded by
// `remainder && kk < K` further down then handles the trailing K - kk entries.
2745 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
2747 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2748 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2753 const size_t M( A.rows() );
2754 const size_t N( B.columns() );
2755 const size_t K( A.columns() );
// Scratch copies of the current A strip and B panel.
2759 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2760 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2762 DerestrictTrait_<MT1> c( derestrict( ~C ) );
2767 else if( !
isOne( beta ) ) {
2772 size_t kblock( 0UL );
2774 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative kblock assignments follow (the selecting condition is not visible
// here). The first masks K - kk with size_t(-SIMDSIZE), which rounds it down to a
// multiple of SIMDSIZE when SIMDSIZE is a power of two.
2777 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
2780 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this panel, narrowed when MT3 is upper or lower.
2783 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
2784 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
2785 const size_t jsize ( jend - jbegin );
2787 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
2790 size_t iblock( 0UL );
2794 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Detects A strips that lie entirely above (lower MT2) or entirely below (upper MT2)
// the current panel; the body of this branch is not visible in this excerpt.
2796 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
2797 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
2802 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// SIMD section: columns are consumed in groups of 5, then 4, then 2, then a single
// trailing column; inside each group rows are taken in pairs plus a single row.
// NOTE(review): this test uses ET3 while the remainder section below tests ET1 --
// confirm whether the difference is intended.
2806 if( IsFloatingPoint<ET3>::value )
2808 for( ; (j+5UL) <= jsize; j+=5UL )
2810 if( ii > jbegin+j+4UL )
continue;
2812 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
2815 for( ; (i+2UL) <= iend; i+=2UL )
2817 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2819 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2821 const SIMDType a1( A2.load(i ,k) );
2822 const SIMDType a2( A2.load(i+1UL,k) );
2824 const SIMDType b1( B2.load(k,j ) );
2825 const SIMDType b2( B2.load(k,j+1UL) );
2826 const SIMDType b3( B2.load(k,j+2UL) );
2827 const SIMDType b4( B2.load(k,j+3UL) );
2828 const SIMDType b5( B2.load(k,j+4UL) );
2842 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2843 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2844 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2845 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2846 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
2847 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
2850 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
2851 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
2856 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2858 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2860 const SIMDType a1( A2.load(i,k) );
2862 xmm1 += a1 * B2.load(k,j );
2863 xmm2 += a1 * B2.load(k,j+1UL);
2864 xmm3 += a1 * B2.load(k,j+2UL);
2865 xmm4 += a1 * B2.load(k,j+3UL);
2866 xmm5 += a1 * B2.load(k,j+4UL);
2869 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2870 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2871 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2872 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2873 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
2879 for( ; (j+4UL) <= jsize; j+=4UL )
2881 if( ii > jbegin+j+3UL )
continue;
2883 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
2886 for( ; (i+2UL) <= iend; i+=2UL )
2888 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2890 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2892 const SIMDType a1( A2.load(i ,k) );
2893 const SIMDType a2( A2.load(i+1UL,k) );
2895 const SIMDType b1( B2.load(k,j ) );
2896 const SIMDType b2( B2.load(k,j+1UL) );
2897 const SIMDType b3( B2.load(k,j+2UL) );
2898 const SIMDType b4( B2.load(k,j+3UL) );
2910 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2911 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2912 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2913 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2914 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2916 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
2917 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
2922 SIMDType xmm1, xmm2, xmm3, xmm4;
2924 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2926 const SIMDType a1( A2.load(i,k) );
2928 xmm1 += a1 * B2.load(k,j );
2929 xmm2 += a1 * B2.load(k,j+1UL);
2930 xmm3 += a1 * B2.load(k,j+2UL);
2931 xmm4 += a1 * B2.load(k,j+3UL);
2934 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2935 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2936 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2937 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2942 for( ; (j+2UL) <= jsize; j+=2UL )
2944 if( ii > jbegin+j+1UL )
continue;
2946 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
2949 for( ; (i+4UL) <= iend; i+=4UL )
2951 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2953 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2955 const SIMDType a1( A2.load(i ,k) );
2956 const SIMDType a2( A2.load(i+1UL,k) );
2957 const SIMDType a3( A2.load(i+2UL,k) );
2958 const SIMDType a4( A2.load(i+3UL,k) );
2960 const SIMDType b1( B2.load(k,j ) );
2961 const SIMDType b2( B2.load(k,j+1UL) );
2973 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2974 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2975 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
2976 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
2977 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2978 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2979 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
2980 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
2983 for( ; (i+2UL) <= iend; i+=2UL )
2985 SIMDType xmm1, xmm2, xmm3, xmm4;
2987 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2989 const SIMDType a1( A2.load(i ,k) );
2990 const SIMDType a2( A2.load(i+1UL,k) );
2992 const SIMDType b1( B2.load(k,j ) );
2993 const SIMDType b2( B2.load(k,j+1UL) );
3001 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
3002 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
3003 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
3004 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
3009 SIMDType xmm1, xmm2;
3011 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3013 const SIMDType a1( A2.load(i,k) );
3015 xmm1 += a1 * B2.load(k,j );
3016 xmm2 += a1 * B2.load(k,j+1UL);
3019 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
3020 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single trailing column of the SIMD section.
3024 if( j<jsize && ii <= jbegin+j )
3026 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3029 for( ; (i+2UL) <= iend; i+=2UL )
3031 SIMDType xmm1, xmm2;
3033 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3035 const SIMDType b1( B2.load(k,j) );
3037 xmm1 += A2.load(i ,k) * b1;
3038 xmm2 += A2.load(i+1UL,k) * b1;
3041 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
3042 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
3049 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3051 xmm1 += A2.load(i,k) * B2.load(k,j);
3054 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the ksize = K - kk inner-dimension entries left over by the
// panel loop above; it uses element access A2(i,k) / B2(k,j) instead of SIMD loads.
3064 if( remainder && kk < K )
3066 const size_t ksize( K - kk );
3068 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
3069 const size_t jsize ( N - jbegin );
3074 size_t iblock( 0UL );
3078 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
3080 if( IsLower<MT2>::value && ii+iblock <= kk ) {
3089 if( IsFloatingPoint<ET1>::value )
3091 for( ; (j+5UL) <= jsize; j+=5UL )
3093 if( ii > jbegin+j+4UL )
continue;
3095 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
3098 for( ; (i+2UL) <= iend; i+=2UL ) {
3099 for(
size_t k=0UL; k<ksize; ++k ) {
3100 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3101 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3102 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3103 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3104 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3105 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3106 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3108 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3109 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3114 for(
size_t k=0UL; k<ksize; ++k ) {
3115 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3116 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3117 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3118 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3119 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
3126 for( ; (j+4UL) <= jsize; j+=4UL )
3128 if( ii > jbegin+j+3UL )
continue;
3130 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
3133 for( ; (i+2UL) <= iend; i+=2UL ) {
3134 for(
size_t k=0UL; k<ksize; ++k ) {
3135 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3136 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3137 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3138 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3139 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3140 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3141 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3142 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3147 for(
size_t k=0UL; k<ksize; ++k ) {
3148 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3149 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3150 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3151 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3157 for( ; (j+2UL) <= jsize; j+=2UL )
3159 if( ii > jbegin+j+1UL )
continue;
3161 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3164 for( ; (i+2UL) <= iend; i+=2UL ) {
3165 for(
size_t k=0UL; k<ksize; ++k ) {
3166 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3167 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3168 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3169 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3174 for(
size_t k=0UL; k<ksize; ++k ) {
3175 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3176 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3181 if( j<jsize && ii <= jbegin+j )
3183 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3186 for( ; (i+2UL) <= iend; i+=2UL ) {
3187 for(
size_t k=0UL; k<ksize; ++k ) {
3188 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3189 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3194 for(
size_t k=0UL; k<ksize; ++k ) {
3195 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of ummm without explicit scaling factors: it forwards to the
// five-argument ummm with alpha = ET1(1) and beta = ET1(0), where ET1 is the element
// type of the target matrix type MT1.
3224 template<
typename MT1,
typename MT2,
typename MT3 >
3225 inline void ummm( MT1& C,
const MT2& A,
const MT3& B )
3227 using ET1 = ElementType_<MT1>;
3228 using ET2 = ElementType_<MT2>;
3229 using ET3 = ElementType_<MT3>;
3234 ummm( C, A, B, ET1(1), ET1(0) );
// smmm for a target passed as `DenseMatrix<MT1,false>`: it first calls lmmm with
// the given alpha and beta = ST(0), then copies (~C)(j,i) into (~C)(i,j) for the
// entries with column index j greater than row index i, working tile by tile.
// NOTE(review): BLOCK_SIZE is defined in lines not visible in this excerpt.
3266 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3267 void smmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
3269 using ET1 = ElementType_<MT1>;
3270 using ET2 = ElementType_<MT2>;
3271 using ET3 = ElementType_<MT3>;
3287 const size_t M( A.rows() );
3288 const size_t N( B.columns() );
3292 lmmm( C, A, B, alpha, ST(0) );
// Mirroring pass over row tiles of height BLOCK_SIZE.
3294 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3296 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal tile: only the entries with j > i are overwritten.
3298 for(
size_t i=ii; i<iend; ++i ) {
3299 for(
size_t j=i+1UL; j<iend; ++j ) {
3300 (~C)(i,j) = (~C)(j,i);
// Tiles to the right of the diagonal tile: every entry is overwritten.
3304 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3305 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3306 for(
size_t i=ii; i<iend; ++i ) {
3307 for(
size_t j=jj; j<jend; ++j ) {
3308 (~C)(i,j) = (~C)(j,i);
3336 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3337 void smmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
3339 using ET1 = ElementType_<MT1>;
3340 using ET2 = ElementType_<MT2>;
3341 using ET3 = ElementType_<MT3>;
3357 const size_t M( A.rows() );
3358 const size_t N( B.columns() );
3362 ummm( C, A, B, alpha, ST(0) );
3364 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3366 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3368 for(
size_t j=jj; j<jend; ++j ) {
3369 for(
size_t i=jj+1UL; i<jend; ++i ) {
3370 (~C)(i,j) = (~C)(j,i);
3374 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3375 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3376 for(
size_t j=jj; j<jend; ++j ) {
3377 for(
size_t i=ii; i<iend; ++i ) {
3378 (~C)(i,j) = (~C)(j,i);
3404 template<
typename MT1,
typename MT2,
typename MT3 >
3405 inline void smmm( MT1& C,
const MT2& A,
const MT3& B )
3407 using ET1 = ElementType_<MT1>;
3408 using ET2 = ElementType_<MT2>;
3409 using ET3 = ElementType_<MT3>;
3414 smmm( C, A, B, ET1(1) );
3446 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3447 void hmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
3449 using ET1 = ElementType_<MT1>;
3450 using ET2 = ElementType_<MT2>;
3451 using ET3 = ElementType_<MT3>;
3467 const size_t M( A.rows() );
3468 const size_t N( B.columns() );
3472 lmmm( C, A, B, alpha, ST(0) );
3474 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3476 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3478 for(
size_t i=ii; i<iend; ++i ) {
3479 for(
size_t j=i+1UL; j<iend; ++j ) {
3480 (~C)(i,j) =
conj( (~C)(j,i) );
3484 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3485 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3486 for(
size_t i=ii; i<iend; ++i ) {
3487 for(
size_t j=jj; j<jend; ++j ) {
3488 (~C)(i,j) =
conj( (~C)(j,i) );
3516 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3517 void hmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
3519 using ET1 = ElementType_<MT1>;
3520 using ET2 = ElementType_<MT2>;
3521 using ET3 = ElementType_<MT3>;
3537 const size_t M( A.rows() );
3538 const size_t N( B.columns() );
3542 ummm( C, A, B, alpha, ST(0) );
3544 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3546 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3548 for(
size_t j=jj; j<jend; ++j ) {
3549 for(
size_t i=jj+1UL; i<jend; ++i ) {
3550 (~C)(i,j) =
conj( (~C)(j,i) );
3554 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3555 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3556 for(
size_t j=jj; j<jend; ++j ) {
3557 for(
size_t i=ii; i<iend; ++i ) {
3558 (~C)(i,j) =
conj( (~C)(j,i) );
3584 template<
typename MT1,
typename MT2,
typename MT3 >
3585 inline void hmmm( MT1& C,
const MT2& A,
const MT3& B )
3587 using ET1 = ElementType_<MT1>;
3588 using ET2 = ElementType_<MT2>;
3589 using ET3 = ElementType_<MT3>;
3594 hmmm( C, A, B, ET1(1) );
Header file for the implementation of the Submatrix view.
Constraint on the data type.
const DMatForEachExpr< MT, Conj, SO > conj(const DenseMatrix< MT, SO > &dm)
Returns a matrix containing the complex conjugate of each single element of dm.
Definition: DMatForEachExpr.h:1214
Header file for auxiliary alias declarations.
Header file for kernel specific block sizes.
Constraint on the data type.
Header file for mathematical functions.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly upper triangular matrix type...
Definition: StrictlyUpper.h:81
Header file for basic type definitions.
BLAZE_ALWAYS_INLINE const complex< int8_t > sum(const SIMDcint8 &a) noexcept
Returns the sum of all elements in the 8-bit integral complex SIMD vector.
Definition: Reduction.h:63
Header file for the serial shim.
BLAZE_ALWAYS_INLINE size_t size(const Vector< VT, TF > &vector) noexcept
Returns the current size/dimension of the vector.
Definition: Vector.h:261
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i...
Definition: Computation.h:81
#define BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a dense, N-dimensional matrix type...
Definition: DenseMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper unitriangular matrix type...
Definition: UniUpper.h:81
void reset(const DiagonalProxy< MT > &proxy)
Resetting the represented element to the default initial values.
Definition: DiagonalProxy.h:533
const ElementType_< MT > min(const DenseMatrix< MT, SO > &dm)
Returns the smallest element of the dense matrix.
Definition: DenseMatrix.h:1755
const DMatSerialExpr< MT, SO > serial(const DenseMatrix< MT, SO > &dm)
Forces the serial evaluation of the given dense matrix expression dm.
Definition: DMatSerialExpr.h:721
Constraint on the data type.
Constraints on the storage order of matrix types.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE(T)
Constraint on the data type. In case the given data type T is an adaptor type (as for instance LowerMa...
Definition: Adaptor.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly lower triangular matrix type...
Definition: StrictlyLower.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES(T1, T2)
Constraint on the data type. In case the given data types T1 and T2 are not SIMD combinable (i...
Definition: SIMDCombinable.h:61
Namespace of the Blaze C++ math library.
Definition: Blaze.h:57
Header file for the IsFloatingPoint type trait.
#define BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a column-major dense or sparse matri...
Definition: ColumnMajorMatrix.h:61
Header file for the DenseMatrix base class.
Header file for all SIMD functionality.
Header file for the IsLower type trait.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper triangular matrix type...
Definition: Upper.h:81
Header file for the implementation of a dynamic MxN matrix.
Constraint on the data type.
Header file for the IsPadded type trait.
Header file for the isOne shim.
Header file for the DerestrictTrait class template.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a symmetric matrix type, a compilation error is created.
Definition: Symmetric.h:79
#define BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a row-major dense or sparse matrix t...
Definition: RowMajorMatrix.h:61
Header file for run time assertion macros.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower triangular matrix type...
Definition: Lower.h:81
bool isOne(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is 1.
Definition: DiagonalProxy.h:635
Header file for the isDefault shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower unitriangular matrix type...
Definition: UniLower.h:81
Constraint on the data type.
Constraints on the storage order of matrix types.
Header file for the IsRowMajorMatrix type trait.
bool isDefault(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is in default state.
Definition: DiagonalProxy.h:573
#define BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a Hermitian matrix type, a compilation error is created.
Definition: Hermitian.h:79
Header file for the IsUpper type trait.
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro. In case of an invalid compile time expression, a compilation error is cr...
Definition: StaticAssert.h:112
SubmatrixExprTrait_< MT, unaligned > submatrix(Matrix< MT, SO > &matrix, size_t row, size_t column, size_t m, size_t n)
Creating a view on a specific submatrix of the given matrix.
Definition: Submatrix.h:168
#define BLAZE_INTERNAL_ASSERT(expr, msg)
Run time assertion macro for internal checks. In case of an invalid run time expression, the program execution is terminated. The BLAZE_INTERNAL_ASSERT macro can be disabled by setting the BLAZE_USER_ASSERTION flag to zero or by defining NDEBUG during the compilation.
Definition: Assert.h:101
Constraint on the data type.