// NOTE(review): this listing carries embedded source line numbers and skips lines (e.g. 109 -> 127),
// so braces, loop headers and some statements are not visible. The comments describe only what is shown.
//
// Include guard of the header, followed by the blocked matrix-matrix multiplication kernel `mmm`
// for a target of type `DenseMatrix<MT1,false>`. The visible statements add `alpha * (A * B)`
// contributions to C with `+=`, walking the inner dimension K in panels of at most KBLOCK and the
// columns of B in panels of at most JBLOCK. `beta` only appears in an `isOne( beta )` test in this
// excerpt; what is done with it is not visible.
35 #ifndef _BLAZE_MATH_DENSE_MMM_H_ 36 #define _BLAZE_MATH_DENSE_MMM_H_ 103 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
104 void mmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands; the SIMD accumulators use the SIMD type of C's elements.
106 using ET1 = ElementType_<MT1>;
107 using ET2 = ElementType_<MT2>;
108 using ET3 = ElementType_<MT3>;
109 using SIMDType = SIMDTrait_<ET1>;
// True if the type of A or of B is not padded. In that case the K loop stops early and the
// scalar tail guarded by `if( remainder && kk < K )` processes the leftover columns.
127 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Panel sizes: KBLOCK scales MMM_OUTER_BLOCK_SIZE by 16/sizeof(ET1), JBLOCK is MMM_INNER_BLOCK_SIZE.
129 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
130 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands: A is M x K, B is K x N.
135 const size_t M( A.rows() );
136 const size_t N( B.columns() );
137 const size_t K( A.columns() );
// Scratch matrices for the current panel of A (M x KBLOCK) and of B (KBLOCK x JBLOCK); they are
// declared with opposite storage-order flags (false / true).
141 DynamicMatrix<ET2,false> A2( M, KBLOCK );
142 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// Only the tail of the beta handling is visible: a branch taken when beta is not one.
147 else if( !
isOne( beta ) ) {
// Main loop over K panels starting at kk. With `remainder` set it ends as soon as fewer than
// SIMDSIZE columns are left.
152 size_t kblock( 0UL );
154 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel width: KBLOCK, or for the last panel the rest of K rounded down to a multiple of SIMDSIZE
// by masking with -SIMDSIZE (assumes SIMDSIZE is a power of two - TODO confirm).
157 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Unmasked variant of the panel width (the condition selecting between the two is not visible).
160 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A that takes part: it starts at kk for a lower MT2 and ends at kk+kblock for an upper MT2.
163 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
164 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
165 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2.
167 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
// Width of the column panel of B that starts at jj (the enclosing jj loop header is not visible).
170 size_t jblock( 0UL );
174 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Detects a B panel lying strictly above the diagonal of a lower MT3 or strictly below the
// diagonal of an upper MT3 (the branch body is not visible).
176 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
177 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
182 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
// Floating point gate in front of the 5-row tiles (its extent is not visible).
186 if( IsFloatingPoint<ET1>::value )
// Five rows of A2 per step.
188 for( ; (i+5UL) <= isize; i+=5UL )
// 5x2 tile: ten SIMD accumulators over the k loop; each is reduced with sum(), scaled by alpha
// and added to the matching element of C.
192 for( ; (j+2UL) <= jblock; j+=2UL )
194 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
196 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
198 const SIMDType a1( A2.load(i ,k) );
199 const SIMDType a2( A2.load(i+1UL,k) );
200 const SIMDType a3( A2.load(i+2UL,k) );
201 const SIMDType a4( A2.load(i+3UL,k) );
202 const SIMDType a5( A2.load(i+4UL,k) );
204 const SIMDType b1( B2.load(k,j ) );
205 const SIMDType b2( B2.load(k,j+1UL) );
219 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
220 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
221 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
222 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
223 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
224 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
225 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
226 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
227 (~C)(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
228 (~C)(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5x1 tile for a single column.
233 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
235 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
237 const SIMDType a1( A2.load(i ,k) );
238 const SIMDType a2( A2.load(i+1UL,k) );
239 const SIMDType a3( A2.load(i+2UL,k) );
240 const SIMDType a4( A2.load(i+3UL,k) );
241 const SIMDType a5( A2.load(i+4UL,k) );
243 const SIMDType b1( B2.load(k,j) );
252 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
253 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
254 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
255 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
256 (~C)(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Four rows of A2 per step.
262 for( ; (i+4UL) <= isize; i+=4UL )
// 4x2 tile with eight accumulators.
266 for( ; (j+2UL) <= jblock; j+=2UL )
268 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
270 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
272 const SIMDType a1( A2.load(i ,k) );
273 const SIMDType a2( A2.load(i+1UL,k) );
274 const SIMDType a3( A2.load(i+2UL,k) );
275 const SIMDType a4( A2.load(i+3UL,k) );
277 const SIMDType b1( B2.load(k,j ) );
278 const SIMDType b2( B2.load(k,j+1UL) );
290 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
291 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
292 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
293 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
294 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
295 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
296 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
297 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4x1 tile for a single column.
302 SIMDType xmm1, xmm2, xmm3, xmm4;
304 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
306 const SIMDType a1( A2.load(i ,k) );
307 const SIMDType a2( A2.load(i+1UL,k) );
308 const SIMDType a3( A2.load(i+2UL,k) );
309 const SIMDType a4( A2.load(i+3UL,k) );
311 const SIMDType b1( B2.load(k,j) );
319 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
320 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
321 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
322 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Two rows of A2 per step.
327 for( ; (i+2UL) <= isize; i+=2UL )
// 2x4 tile with eight accumulators.
331 for( ; (j+4UL) <= jblock; j+=4UL )
333 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
335 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
337 const SIMDType a1( A2.load(i ,k) );
338 const SIMDType a2( A2.load(i+1UL,k) );
340 const SIMDType b1( B2.load(k,j ) );
341 const SIMDType b2( B2.load(k,j+1UL) );
342 const SIMDType b3( B2.load(k,j+2UL) );
343 const SIMDType b4( B2.load(k,j+3UL) );
355 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
356 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
357 (~C)(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
358 (~C)(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
359 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
360 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
361 (~C)(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
362 (~C)(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile with four accumulators.
365 for( ; (j+2UL) <= jblock; j+=2UL )
367 SIMDType xmm1, xmm2, xmm3, xmm4;
369 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
371 const SIMDType a1( A2.load(i ,k) );
372 const SIMDType a2( A2.load(i+1UL,k) );
374 const SIMDType b1( B2.load(k,j ) );
375 const SIMDType b2( B2.load(k,j+1UL) );
383 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
384 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
385 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
386 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// 2x1 tile for a single column.
393 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
395 const SIMDType a1( A2.load(i ,k) );
396 const SIMDType a2( A2.load(i+1UL,k) );
398 const SIMDType b1( B2.load(k,j) );
404 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
405 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single row of A2: 1x2 tile (the enclosing row condition is not visible).
413 for( ; (j+2UL) <= jblock; j+=2UL )
417 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
419 const SIMDType a1( A2.load(i,k) );
421 xmm1 += a1 * B2.load(k,j );
422 xmm2 += a1 * B2.load(k,j+1UL);
425 (~C)(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
426 (~C)(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// Single row, single column: 1x1 tile.
433 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
435 const SIMDType a1( A2.load(i,k) );
437 xmm1 += a1 * B2.load(k,j);
440 (~C)(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: when `remainder` is set and columns of K are left over, the remaining
// ksize = K - kk columns are multiplied element by element (no SIMD loads).
450 if( remainder && kk < K )
452 const size_t ksize( K - kk );
// Rows start at kk for a lower MT2; no upper bound is applied here.
454 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
455 const size_t isize ( M - ibegin );
// Width of the column panel of B that starts at jj.
460 size_t jblock( 0UL );
464 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Detects a B panel lying strictly below the diagonal of an upper MT3 (branch body not visible).
466 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
// Floating point gate in front of the 5-row tiles (its extent is not visible).
475 if( IsFloatingPoint<ET1>::value )
// 5-row tiles: 5x2, then 5x1.
477 for( ; (i+5UL) <= isize; i+=5UL )
481 for( ; (j+2UL) <= jblock; j+=2UL ) {
482 for(
size_t k=0UL; k<ksize; ++k ) {
483 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
484 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
485 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
486 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
487 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
488 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
489 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
490 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
491 (~C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
492 (~C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
497 for(
size_t k=0UL; k<ksize; ++k ) {
498 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
499 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
500 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
501 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
502 (~C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// 4-row tiles: 4x2, then 4x1.
509 for( ; (i+4UL) <= isize; i+=4UL )
513 for( ; (j+2UL) <= jblock; j+=2UL ) {
514 for(
size_t k=0UL; k<ksize; ++k ) {
515 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
516 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
517 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
518 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
519 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
520 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
521 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
522 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
527 for(
size_t k=0UL; k<ksize; ++k ) {
528 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
529 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
530 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
531 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// 2-row tiles: 2x2, then 2x1.
537 for( ; (i+2UL) <= isize; i+=2UL )
541 for( ; (j+2UL) <= jblock; j+=2UL ) {
542 for(
size_t k=0UL; k<ksize; ++k ) {
543 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
544 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
545 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
546 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
551 for(
size_t k=0UL; k<ksize; ++k ) {
552 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
553 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single row: 1x2, then 1x1.
562 for( ; (j+2UL) <= jblock; j+=2UL ) {
563 for(
size_t k=0UL; k<ksize; ++k ) {
564 (~C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
565 (~C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
570 for(
size_t k=0UL; k<ksize; ++k ) {
571 (~C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// NOTE(review): this listing carries embedded source line numbers and skips lines, so braces,
// loop headers and some statements are not visible. The comments describe only what is shown.
//
// Blocked matrix-matrix multiplication kernel `mmm` for a target of type `DenseMatrix<MT1,true>`.
// The visible statements add `alpha * (A * B)` contributions to C with `+=`, walking the inner
// dimension K in panels of at most KBLOCK and the rows of A in panels of at most IBLOCK.
// `beta` only appears in an `isOne( beta )` test in this excerpt.
603 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
604 void mmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands; the SIMD accumulators use the SIMD type of C's elements.
606 using ET1 = ElementType_<MT1>;
607 using ET2 = ElementType_<MT2>;
608 using ET3 = ElementType_<MT3>;
609 using SIMDType = SIMDTrait_<ET1>;
// True if the type of A or of B is not padded. In that case the K loop stops early and the
// scalar tail guarded by `if( remainder && kk < K )` processes the leftover columns.
627 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Panel sizes: KBLOCK scales MMM_OUTER_BLOCK_SIZE by 16/sizeof(ET1), IBLOCK is MMM_INNER_BLOCK_SIZE.
629 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
630 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands: A is M x K, B is K x N.
635 const size_t M( A.rows() );
636 const size_t N( B.columns() );
637 const size_t K( A.columns() );
// Scratch matrices for the current panel of A (IBLOCK x KBLOCK) and of B (KBLOCK x N); they are
// declared with opposite storage-order flags (false / true).
641 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
642 DynamicMatrix<ET3,true> B2( KBLOCK, N );
// Only the tail of the beta handling is visible: a branch taken when beta is not one.
647 else if( !
isOne( beta ) ) {
// Main loop over K panels starting at kk. With `remainder` set it ends as soon as fewer than
// SIMDSIZE columns are left.
652 size_t kblock( 0UL );
654 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel width: KBLOCK, or for the last panel the rest of K rounded down to a multiple of SIMDSIZE
// by masking with -SIMDSIZE (assumes SIMDSIZE is a power of two - TODO confirm).
657 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Unmasked variant of the panel width (the condition selecting between the two is not visible).
660 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B that takes part: it starts at kk for an upper MT3 and ends at kk+kblock for a lower MT3.
663 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
664 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
665 const size_t jsize ( jend - jbegin );
// Copy the kblock x jsize panel of B into B2.
667 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
// Height of the row panel of A that starts at ii (the enclosing ii loop header is not visible).
670 size_t iblock( 0UL );
674 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Detects an A panel lying strictly above the diagonal of a lower MT2 or strictly below the
// diagonal of an upper MT2 (the branch body is not visible).
676 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
677 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
// Copy the iblock x kblock panel of A into A2.
682 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// Floating point gate in front of the 5-column tiles (its extent is not visible).
// NOTE(review): this gate tests ET3 while the scalar tail below tests ET1 - confirm this is intended.
686 if( IsFloatingPoint<ET3>::value )
// Five columns of B2 per step.
688 for( ; (j+5UL) <= jsize; j+=5UL )
// 2x5 tile: ten SIMD accumulators over the k loop; each is reduced with sum(), scaled by alpha
// and added to the matching element of C.
692 for( ; (i+2UL) <= iblock; i+=2UL )
694 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
696 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
698 const SIMDType a1( A2.load(i ,k) );
699 const SIMDType a2( A2.load(i+1UL,k) );
701 const SIMDType b1( B2.load(k,j ) );
702 const SIMDType b2( B2.load(k,j+1UL) );
703 const SIMDType b3( B2.load(k,j+2UL) );
704 const SIMDType b4( B2.load(k,j+3UL) );
705 const SIMDType b5( B2.load(k,j+4UL) );
719 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
720 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
721 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
722 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
723 (~C)(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
724 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
725 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
726 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
727 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
728 (~C)(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// 1x5 tile for a single row.
733 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
735 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
737 const SIMDType a1( A2.load(i,k) );
739 xmm1 += a1 * B2.load(k,j );
740 xmm2 += a1 * B2.load(k,j+1UL);
741 xmm3 += a1 * B2.load(k,j+2UL);
742 xmm4 += a1 * B2.load(k,j+3UL);
743 xmm5 += a1 * B2.load(k,j+4UL);
746 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
747 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
748 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
749 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
750 (~C)(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Four columns of B2 per step.
756 for( ; (j+4UL) <= jsize; j+=4UL )
// 2x4 tile with eight accumulators.
760 for( ; (i+2UL) <= iblock; i+=2UL )
762 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
764 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
766 const SIMDType a1( A2.load(i ,k) );
767 const SIMDType a2( A2.load(i+1UL,k) );
769 const SIMDType b1( B2.load(k,j ) );
770 const SIMDType b2( B2.load(k,j+1UL) );
771 const SIMDType b3( B2.load(k,j+2UL) );
772 const SIMDType b4( B2.load(k,j+3UL) );
784 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
785 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
786 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
787 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
788 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
789 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
790 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
791 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// 1x4 tile for a single row.
796 SIMDType xmm1, xmm2, xmm3, xmm4;
798 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
800 const SIMDType a1( A2.load(i,k) );
802 xmm1 += a1 * B2.load(k,j );
803 xmm2 += a1 * B2.load(k,j+1UL);
804 xmm3 += a1 * B2.load(k,j+2UL);
805 xmm4 += a1 * B2.load(k,j+3UL);
808 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
809 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
810 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
811 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Two columns of B2 per step.
816 for( ; (j+2UL) <= jsize; j+=2UL )
// 4x2 tile with eight accumulators.
820 for( ; (i+4UL) <= iblock; i+=4UL )
822 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
824 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
826 const SIMDType a1( A2.load(i ,k) );
827 const SIMDType a2( A2.load(i+1UL,k) );
828 const SIMDType a3( A2.load(i+2UL,k) );
829 const SIMDType a4( A2.load(i+3UL,k) );
831 const SIMDType b1( B2.load(k,j ) );
832 const SIMDType b2( B2.load(k,j+1UL) );
844 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
845 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
846 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
847 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
848 (~C)(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
849 (~C)(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
850 (~C)(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
851 (~C)(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile with four accumulators.
854 for( ; (i+2UL) <= iblock; i+=2UL )
856 SIMDType xmm1, xmm2, xmm3, xmm4;
858 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
860 const SIMDType a1( A2.load(i ,k) );
861 const SIMDType a2( A2.load(i+1UL,k) );
863 const SIMDType b1( B2.load(k,j ) );
864 const SIMDType b2( B2.load(k,j+1UL) );
872 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
873 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
874 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
875 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// 1x2 tile for a single row.
882 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
884 const SIMDType a1( A2.load(i,k) );
886 xmm1 += a1 * B2.load(k,j );
887 xmm2 += a1 * B2.load(k,j+1UL);
890 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
891 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single column of B2: 2x1 tile (the enclosing column condition is not visible).
899 for( ; (i+2UL) <= iblock; i+=2UL )
903 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
905 const SIMDType b1( B2.load(k,j) );
907 xmm1 += A2.load(i ,k) * b1;
908 xmm2 += A2.load(i+1UL,k) * b1;
911 (~C)(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
912 (~C)(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// Single column, single row: 1x1 tile.
919 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
921 xmm1 += A2.load(i,k) * B2.load(k,j);
924 (~C)(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: when `remainder` is set and columns of K are left over, the remaining
// ksize = K - kk columns are multiplied element by element (no SIMD loads).
934 if( remainder && kk < K )
936 const size_t ksize( K - kk );
// Columns start at kk for an upper MT3; no upper bound is applied here.
938 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
939 const size_t jsize ( N - jbegin );
// Height of the row panel of A that starts at ii.
944 size_t iblock( 0UL );
948 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Detects an A panel lying strictly above the diagonal of a lower MT2 (branch body not visible).
950 if( IsLower<MT2>::value && ii+iblock <= kk ) {
// Floating point gate in front of the 5-column tiles (its extent is not visible).
959 if( IsFloatingPoint<ET1>::value )
// 5-column tiles: 2x5, then 1x5.
961 for( ; (j+5UL) <= jsize; j+=5UL )
965 for( ; (i+2UL) <= iblock; i+=2UL ) {
966 for(
size_t k=0UL; k<ksize; ++k ) {
967 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
968 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
969 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
970 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
971 (~C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
972 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
973 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
974 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
975 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
976 (~C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
981 for(
size_t k=0UL; k<ksize; ++k ) {
982 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
983 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
984 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
985 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
986 (~C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// 4-column tiles: 2x4, then 1x4.
993 for( ; (j+4UL) <= jsize; j+=4UL )
997 for( ; (i+2UL) <= iblock; i+=2UL ) {
998 for(
size_t k=0UL; k<ksize; ++k ) {
999 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1000 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1001 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1002 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1003 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1004 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1005 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1006 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
1011 for(
size_t k=0UL; k<ksize; ++k ) {
1012 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1013 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1014 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1015 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// 2-column tiles: 2x2, then 1x2.
1021 for( ; (j+2UL) <= jsize; j+=2UL )
1025 for( ; (i+2UL) <= iblock; i+=2UL ) {
1026 for(
size_t k=0UL; k<ksize; ++k ) {
1027 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1028 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1029 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1030 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1035 for(
size_t k=0UL; k<ksize; ++k ) {
1036 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1037 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single column: 2x1, then 1x1.
1046 for( ; (i+2UL) <= iblock; i+=2UL ) {
1047 for(
size_t k=0UL; k<ksize; ++k ) {
1048 (~C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1049 (~C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1054 for(
size_t k=0UL; k<ksize; ++k ) {
1055 (~C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of `mmm` without scaling factors: it forwards to the five-argument
// `mmm` with alpha = ET1(1) and beta = ET1(0), where ET1 is the element type of C.
// NOTE(review): embedded line numbers skip (1089 -> 1094), so lines are missing from this
// listing; ET2 and ET3 are not used in the statements that are visible.
1084 template<
typename MT1,
typename MT2,
typename MT3 >
1085 inline void mmm( MT1& C,
const MT2& A,
const MT3& B )
1087 using ET1 = ElementType_<MT1>;
1088 using ET2 = ElementType_<MT2>;
1089 using ET3 = ElementType_<MT3>;
1094 mmm( C, A, B, ET1(1), ET1(0) );
// NOTE(review): this listing carries embedded source line numbers and skips lines, so braces,
// loop headers and some statements are not visible. The comments describe only what is shown.
//
// Blocked matrix-matrix multiplication kernel `lmmm` for a target of type `DenseMatrix<MT1,false>`.
// It has the same panel structure as a plain blocked product (K panels of at most KBLOCK, column
// panels of B of at most JBLOCK) but writes through `c`, the derestricted view of C, and limits
// the columns updated for each group of rows with a per-group bound `jend` derived from the row index.
// The visible statements add `alpha * (A * B)` contributions with `+=`; `beta` only appears in an
// `isOne( beta )` test in this excerpt.
1127 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1128 void lmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the three operands; the SIMD accumulators use the SIMD type of C's elements.
1130 using ET1 = ElementType_<MT1>;
1131 using ET2 = ElementType_<MT2>;
1132 using ET3 = ElementType_<MT3>;
1133 using SIMDType = SIMDTrait_<ET1>;
// True if the type of A or of B is not padded. In that case the K loop stops early and the
// scalar tail guarded by `if( remainder && kk < K )` processes the leftover columns.
1155 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Panel sizes: KBLOCK scales MMM_OUTER_BLOCK_SIZE by 16/sizeof(ET1), JBLOCK is MMM_INNER_BLOCK_SIZE.
1157 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1158 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions taken from the operands: A is M x K, B is K x N.
1163 const size_t M( A.rows() );
1164 const size_t N( B.columns() );
1165 const size_t K( A.columns() );
// Scratch matrices for the current panel of A (M x KBLOCK) and of B (KBLOCK x JBLOCK); they are
// declared with opposite storage-order flags (false / true).
1169 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1170 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// All element updates below go through this derestricted view of the target.
1172 decltype(
auto) c( derestrict( ~C ) );
// Only the tail of the beta handling is visible: a branch taken when beta is not one.
1177 else if( !
isOne( beta ) ) {
// Main loop over K panels starting at kk. With `remainder` set it ends as soon as fewer than
// SIMDSIZE columns are left.
1182 size_t kblock( 0UL );
1184 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel width: KBLOCK, or for the last panel the rest of K rounded down to a multiple of SIMDSIZE
// by masking with -SIMDSIZE (assumes SIMDSIZE is a power of two - TODO confirm).
1187 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Unmasked variant of the panel width (the condition selecting between the two is not visible).
1190 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A that takes part: it starts at kk for a lower MT2 and ends at kk+kblock for an upper MT2.
1193 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
1194 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
1195 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2.
1197 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
// Width of the column panel of B that starts at jj (the enclosing jj loop header is not visible).
1200 size_t jblock( 0UL );
1204 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Detects a B panel lying strictly above the diagonal of a lower MT3 or strictly below the
// diagonal of an upper MT3 (the branch body is not visible).
1206 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
1207 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
1212 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
// Floating point gate in front of the 5-row tiles (its extent is not visible).
1216 if( IsFloatingPoint<ET1>::value )
// Five rows per step. A group is skipped when the panel's first column jj lies beyond the
// group's last row index; otherwise columns are limited to those up to that last row index.
1218 for( ; (i+5UL) <= isize; i+=5UL )
1220 if( jj > ibegin+i+4UL )
continue;
1222 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// 5x2 tile: ten SIMD accumulators over the k loop; each is reduced with sum(), scaled by alpha
// and added to the matching element of c.
1225 for( ; (j+2UL) <= jend; j+=2UL )
1227 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1229 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1231 const SIMDType a1( A2.load(i ,k) );
1232 const SIMDType a2( A2.load(i+1UL,k) );
1233 const SIMDType a3( A2.load(i+2UL,k) );
1234 const SIMDType a4( A2.load(i+3UL,k) );
1235 const SIMDType a5( A2.load(i+4UL,k) );
1237 const SIMDType b1( B2.load(k,j ) );
1238 const SIMDType b2( B2.load(k,j+1UL) );
1252 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1253 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1254 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1255 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1256 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1257 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1258 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1259 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
1260 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
1261 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5x1 tile for a single column.
1266 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1268 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1270 const SIMDType a1( A2.load(i ,k) );
1271 const SIMDType a2( A2.load(i+1UL,k) );
1272 const SIMDType a3( A2.load(i+2UL,k) );
1273 const SIMDType a4( A2.load(i+3UL,k) );
1274 const SIMDType a5( A2.load(i+4UL,k) );
1276 const SIMDType b1( B2.load(k,j) );
1285 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1286 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1287 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1288 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
1289 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Four rows per step, with the same skip test and column bound for a group of four.
1295 for( ; (i+4UL) <= isize; i+=4UL )
1297 if( jj > ibegin+i+3UL )
continue;
1299 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
// 4x2 tile with eight accumulators.
1302 for( ; (j+2UL) <= jend; j+=2UL )
1304 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1306 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1308 const SIMDType a1( A2.load(i ,k) );
1309 const SIMDType a2( A2.load(i+1UL,k) );
1310 const SIMDType a3( A2.load(i+2UL,k) );
1311 const SIMDType a4( A2.load(i+3UL,k) );
1313 const SIMDType b1( B2.load(k,j ) );
1314 const SIMDType b2( B2.load(k,j+1UL) );
1326 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1327 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1328 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1329 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1330 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1331 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1332 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1333 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4x1 tile for a single column.
1338 SIMDType xmm1, xmm2, xmm3, xmm4;
1340 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1342 const SIMDType a1( A2.load(i ,k) );
1343 const SIMDType a2( A2.load(i+1UL,k) );
1344 const SIMDType a3( A2.load(i+2UL,k) );
1345 const SIMDType a4( A2.load(i+3UL,k) );
1347 const SIMDType b1( B2.load(k,j) );
1355 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1356 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1357 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1358 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Two rows per step, with the same skip test and column bound for a group of two.
1363 for( ; (i+2UL) <= isize; i+=2UL )
1365 if( jj > ibegin+i+1UL )
continue;
1367 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// 2x4 tile with eight accumulators.
1370 for( ; (j+4UL) <= jend; j+=4UL )
1372 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1374 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1376 const SIMDType a1( A2.load(i ,k) );
1377 const SIMDType a2( A2.load(i+1UL,k) );
1379 const SIMDType b1( B2.load(k,j ) );
1380 const SIMDType b2( B2.load(k,j+1UL) );
1381 const SIMDType b3( B2.load(k,j+2UL) );
1382 const SIMDType b4( B2.load(k,j+3UL) );
1394 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1395 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1396 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
1397 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
1398 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
1399 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1400 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// 2x2 tile with four accumulators.
1404 for( ; (j+2UL) <= jend; j+=2UL )
1406 SIMDType xmm1, xmm2, xmm3, xmm4;
1408 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1410 const SIMDType a1( A2.load(i ,k) );
1411 const SIMDType a2( A2.load(i+1UL,k) );
1413 const SIMDType b1( B2.load(k,j ) );
1414 const SIMDType b2( B2.load(k,j+1UL) );
1422 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1423 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1424 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1425 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// 2x1 tile for a single column.
1430 SIMDType xmm1, xmm2;
1432 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1434 const SIMDType a1( A2.load(i ,k) );
1435 const SIMDType a2( A2.load(i+1UL,k) );
1437 const SIMDType b1( B2.load(k,j) );
1443 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1444 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Last single row, processed only if the panel's first column is not beyond the row index.
// NOTE(review): the bound ibegin+i-jj+2UL admits column ibegin+i+1, one past the row index,
// for this single row - confirm that this is intended.
1448 if( i<isize && jj <= ibegin+i )
1450 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// 1x2 tile.
1453 for( ; (j+2UL) <= jend; j+=2UL )
1455 SIMDType xmm1, xmm2;
1457 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1459 const SIMDType a1( A2.load(i,k) );
1461 xmm1 += a1 * B2.load(k,j );
1462 xmm2 += a1 * B2.load(k,j+1UL);
1465 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
1466 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// 1x1 tile.
1473 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1475 const SIMDType a1( A2.load(i,k) );
1477 xmm1 += a1 * B2.load(k,j);
1480 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: when `remainder` is set and columns of K are left over, the remaining
// ksize = K - kk columns are multiplied element by element (no SIMD loads).
1490 if( remainder && kk < K )
1492 const size_t ksize( K - kk );
// Rows start at kk for a lower MT2; no upper bound is applied here.
1494 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
1495 const size_t isize ( M - ibegin );
// Width of the column panel of B that starts at jj.
1500 size_t jblock( 0UL );
1504 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Detects a B panel lying strictly below the diagonal of an upper MT3 (branch body not visible).
1506 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
// Floating point gate in front of the 5-row tiles (its extent is not visible).
1515 if( IsFloatingPoint<ET1>::value )
// 5-row tiles (5x2, then 5x1) with the same skip test and column bound as in the SIMD part.
1517 for( ; (i+5UL) <= isize; i+=5UL )
1519 if( jj > ibegin+i+4UL )
continue;
1521 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
1524 for( ; (j+2UL) <= jend; j+=2UL ) {
1525 for(
size_t k=0UL; k<ksize; ++k ) {
1526 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1527 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1528 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1529 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1530 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1531 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1532 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1533 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1534 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1535 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
1540 for(
size_t k=0UL; k<ksize; ++k ) {
1541 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1542 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1543 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1544 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// 4-row tiles (4x2, then 4x1).
1552 for( ; (i+4UL) <= isize; i+=4UL )
1554 if( jj > ibegin+i+3UL )
continue;
1556 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
1559 for( ; (j+2UL) <= jend; j+=2UL ) {
1560 for(
size_t k=0UL; k<ksize; ++k ) {
1561 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1562 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1563 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1564 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1565 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1566 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1567 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1568 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1573 for(
size_t k=0UL; k<ksize; ++k ) {
1574 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1575 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1576 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1577 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// 2-row tiles (2x2, then 2x1).
1583 for( ; (i+2UL) <= isize; i+=2UL )
1585 if( jj > ibegin+i+1UL )
continue;
1587 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1590 for( ; (j+2UL) <= jend; j+=2UL ) {
1591 for(
size_t k=0UL; k<ksize; ++k ) {
1592 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1593 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1594 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1595 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1600 for(
size_t k=0UL; k<ksize; ++k ) {
1601 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1602 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Last single row (1x2, then 1x1).
// NOTE(review): as in the SIMD part, the bound ibegin+i-jj+2UL admits column ibegin+i+1 - confirm.
1607 if( i<isize && jj <= ibegin+i )
1609 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1612 for( ; (j+2UL) <= jend; j+=2UL ) {
1613 for(
size_t k=0UL; k<ksize; ++k ) {
1614 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1615 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1620 for(
size_t k=0UL; k<ksize; ++k ) {
1621 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel that accumulates alpha * (A*B) into the lower part of a column-major target.
// Every visible update has the form c(ii+i,jbegin+j) += <dot product> * alpha, and each row
// loop starts at the first row with ii+i >= jbegin+j, i.e. on or below the diagonal.
//   C     : column-major target matrix (accessed through derestrict( ~C )).
//   A, B  : left and right operand.
//   alpha : factor applied to every accumulated dot product.
//   beta  : only the test !isOne( beta ) is visible here; the action taken is not shown.
1653 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1654 void lmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
1656 using ET1 = ElementType_<MT1>;
1657 using ET2 = ElementType_<MT2>;
1658 using ET3 = ElementType_<MT3>;
1659 using SIMDType = SIMDTrait_<ET1>;
// A scalar clean-up pass over k is required as soon as one operand is not padded.
1681 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Blocking factors: KBLOCK along the shared dimension, IBLOCK along the rows of A.
1683 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1684 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1689 const size_t M( A.rows() );
1690 const size_t N( B.columns() );
1691 const size_t K( A.columns() );
// Packing buffers: A2 is row-major (IBLOCK x KBLOCK), B2 is column-major (KBLOCK x N), so
// both are read along k by the load(...,k) / load(k,...) calls in the kernels below.
1695 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1696 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1698 decltype(
auto) c( derestrict( ~C ) );
// Pre-scaling branch for beta (the branch bodies are not part of this excerpt).
1703 else if( !
isOne( beta ) ) {
1708 size_t kblock( 0UL );
// Outer loop over k-blocks; with 'remainder' set it stops once fewer than SIMDSIZE values of
// k are left, which are then handled by the scalar section at the end.
1710 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Depth of the current k-block. The first form masks the last block with size_t(-SIMDSIZE),
// which rounds it down to a multiple of SIMDSIZE when SIMDSIZE is a power of two.
// NOTE(review): the two assignments presumably sit in the remainder / non-remainder branches
// of a conditional that is not visible here.
1713 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
1716 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B for this k-block: an upper B skips the columns before kk, a lower B skips
// the columns from kk+kblock onwards.
1719 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
1720 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
1721 const size_t jsize ( jend - jbegin );
// Pack the current panel of B.
1723 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
1726 size_t iblock( 0UL );
1730 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row blocks of a triangular A that do not intersect the current k-range (body not shown).
1732 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
1733 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
// Pack the current panel of A.
1738 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// NOTE(review): this test uses ET3 while the matching test in the remainder section uses ET1;
// confirm that the difference is intended.
1742 if( IsFloatingPoint<ET3>::value )
// Strips of five columns.
1744 for( ; (j+5UL) <= jsize; j+=5UL )
// The whole row block lies above column jbegin: nothing on or below the diagonal to update.
1746 if( ii+iblock < jbegin )
continue;
// First local row that is on or below the diagonal of column jbegin+j.
1748 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
// Two rows x five columns per step.
1750 for( ; (i+2UL) <= iblock; i+=2UL )
1752 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1754 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1756 const SIMDType a1( A2.load(i ,k) );
1757 const SIMDType a2( A2.load(i+1UL,k) );
1759 const SIMDType b1( B2.load(k,j ) );
1760 const SIMDType b2( B2.load(k,j+1UL) );
1761 const SIMDType b3( B2.load(k,j+2UL) );
1762 const SIMDType b4( B2.load(k,j+3UL) );
1763 const SIMDType b5( B2.load(k,j+4UL) );
// Reduce each SIMD accumulator, scale by alpha and add to the target.
1777 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1778 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1779 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1780 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1781 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
1782 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
1783 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
1784 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// One leftover row x five columns.
1791 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1793 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1795 const SIMDType a1( A2.load(i,k) );
1797 xmm1 += a1 * B2.load(k,j );
1798 xmm2 += a1 * B2.load(k,j+1UL);
1799 xmm3 += a1 * B2.load(k,j+2UL);
1800 xmm4 += a1 * B2.load(k,j+3UL);
1801 xmm5 += a1 * B2.load(k,j+4UL);
1804 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1805 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1806 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1807 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1808 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Strips of four columns.
1814 for( ; (j+4UL) <= jsize; j+=4UL )
1816 if( ii+iblock < jbegin )
continue;
1818 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
// Two rows x four columns per step.
1820 for( ; (i+2UL) <= iblock; i+=2UL )
1822 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1824 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1826 const SIMDType a1( A2.load(i ,k) );
1827 const SIMDType a2( A2.load(i+1UL,k) );
1829 const SIMDType b1( B2.load(k,j ) );
1830 const SIMDType b2( B2.load(k,j+1UL) );
1831 const SIMDType b3( B2.load(k,j+2UL) );
1832 const SIMDType b4( B2.load(k,j+3UL) );
1844 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1845 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1846 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1847 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1848 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1849 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1850 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// One leftover row x four columns.
1856 SIMDType xmm1, xmm2, xmm3, xmm4;
1858 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1860 const SIMDType a1( A2.load(i,k) );
1862 xmm1 += a1 * B2.load(k,j );
1863 xmm2 += a1 * B2.load(k,j+1UL);
1864 xmm3 += a1 * B2.load(k,j+2UL);
1865 xmm4 += a1 * B2.load(k,j+3UL);
1868 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1869 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1870 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1871 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Strips of two columns.
1876 for( ; (j+2UL) <= jsize; j+=2UL )
1878 if( ii+iblock < jbegin )
continue;
1880 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
// Four rows x two columns per step.
1882 for( ; (i+4UL) <= iblock; i+=4UL )
1884 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1886 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1888 const SIMDType a1( A2.load(i ,k) );
1889 const SIMDType a2( A2.load(i+1UL,k) );
1890 const SIMDType a3( A2.load(i+2UL,k) );
1891 const SIMDType a4( A2.load(i+3UL,k) );
1893 const SIMDType b1( B2.load(k,j ) );
1894 const SIMDType b2( B2.load(k,j+1UL) );
1906 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1907 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1908 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1909 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
1910 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1911 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1912 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
1913 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns per step.
1916 for( ; (i+2UL) <= iblock; i+=2UL )
1918 SIMDType xmm1, xmm2, xmm3, xmm4;
1920 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1922 const SIMDType a1( A2.load(i ,k) );
1923 const SIMDType a2( A2.load(i+1UL,k) );
1925 const SIMDType b1( B2.load(k,j ) );
1926 const SIMDType b2( B2.load(k,j+1UL) );
1934 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1935 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1936 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1937 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// One leftover row x two columns.
1942 SIMDType xmm1, xmm2;
1944 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1946 const SIMDType a1( A2.load(i,k) );
1948 xmm1 += a1 * B2.load(k,j );
1949 xmm2 += a1 * B2.load(k,j+1UL);
1952 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1953 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single leftover column.
1957 if( j<jsize && ii+iblock >= jbegin )
1959 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1961 for( ; (i+2UL) <= iblock; i+=2UL )
1963 SIMDType xmm1, xmm2;
1965 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1967 const SIMDType b1( B2.load(k,j) );
1969 xmm1 += A2.load(i ,k) * b1;
1970 xmm2 += A2.load(i+1UL,k) * b1;
1973 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
1974 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
1981 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1983 xmm1 += A2.load(i,k) * B2.load(k,j);
1986 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the last K-kk values of k that did not fill a SIMD-sized block.
// The loops below mirror the strip structure above but use element access A2(i,k) / B2(k,j).
// NOTE(review): A2 and B2 are presumably refilled for this range in lines not shown.
1996 if( remainder && kk < K )
1998 const size_t ksize( K - kk );
2000 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
2001 const size_t jsize ( N - jbegin );
2006 size_t iblock( 0UL );
2010 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row blocks of a lower A that end at or before row kk (body not shown).
2012 if( IsLower<MT2>::value && ii+iblock <= kk ) {
2021 if( IsFloatingPoint<ET1>::value )
// Strips of five columns.
2023 for( ; (j+5UL) <= jsize; j+=5UL )
2025 if( ii+iblock < jbegin )
continue;
2027 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2029 for( ; (i+2UL) <= iblock; i+=2UL ) {
2030 for(
size_t k=0UL; k<ksize; ++k ) {
2031 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2032 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2033 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2034 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2035 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2036 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2037 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2038 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2039 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2045 for(
size_t k=0UL; k<ksize; ++k ) {
2046 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2047 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2048 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2049 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2050 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Strips of four columns.
2057 for( ; (j+4UL) <= jsize; j+=4UL )
2059 if( ii+iblock < jbegin )
continue;
2061 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2063 for( ; (i+2UL) <= iblock; i+=2UL ) {
2064 for(
size_t k=0UL; k<ksize; ++k ) {
2065 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2066 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2067 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2068 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2069 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2070 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2071 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2072 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2077 for(
size_t k=0UL; k<ksize; ++k ) {
2078 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2079 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2080 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2081 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Strips of two columns.
2087 for( ; (j+2UL) <= jsize; j+=2UL )
2089 if( ii+iblock < jbegin )
continue;
2091 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2093 for( ; (i+2UL) <= iblock; i+=2UL ) {
2094 for(
size_t k=0UL; k<ksize; ++k ) {
2095 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2096 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2097 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2098 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2103 for(
size_t k=0UL; k<ksize; ++k ) {
2104 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2105 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single leftover column.
2112 if( ii+iblock < jbegin )
continue;
2114 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2116 for( ; (i+2UL) <= iblock; i+=2UL ) {
2117 for(
size_t k=0UL; k<ksize; ++k ) {
2118 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2119 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2124 for(
size_t k=0UL; k<ksize; ++k ) {
2125 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload without scaling factors: forwards to the five-argument lmmm with
// alpha = ET1(1) and beta = ET1(0), both expressed in the element type of the target C.
2154 template<
typename MT1,
typename MT2,
typename MT3 >
2155 inline void lmmm( MT1& C,
const MT2& A,
const MT3& B )
2157 using ET1 = ElementType_<MT1>;
// ET2 and ET3 are not used by the lines visible here (presumably by checks not shown).
2158 using ET2 = ElementType_<MT2>;
2159 using ET3 = ElementType_<MT3>;
2164 lmmm( C, A, B, ET1(1), ET1(0) );
// Blocked kernel that accumulates alpha * (A*B) into the upper part of a row-major target.
// Every visible update has the form c(ibegin+i,jj+j) += <dot product> * alpha, and each column
// loop starts at the first column with jj+j >= ibegin+i, i.e. on or above the diagonal.
//   C     : row-major target matrix (accessed through derestrict( ~C )).
//   A, B  : left and right operand.
//   alpha : factor applied to every accumulated dot product.
//   beta  : only the test !isOne( beta ) is visible here; the action taken is not shown.
2197 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2198 void ummm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2200 using ET1 = ElementType_<MT1>;
2201 using ET2 = ElementType_<MT2>;
2202 using ET3 = ElementType_<MT3>;
2203 using SIMDType = SIMDTrait_<ET1>;
// A scalar clean-up pass over k is required as soon as one operand is not padded.
2225 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Blocking factors: KBLOCK along the shared dimension, JBLOCK along the columns of B.
2227 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2228 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
2233 const size_t M( A.rows() );
2234 const size_t N( B.columns() );
2235 const size_t K( A.columns() );
// Packing buffers: A2 is row-major (M x KBLOCK), B2 is column-major (KBLOCK x JBLOCK), so
// both are read along k by the load(...,k) / load(k,...) calls in the kernels below.
2239 DynamicMatrix<ET2,false> A2( M, KBLOCK );
2240 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
2242 decltype(
auto) c( derestrict( ~C ) );
// Pre-scaling branch for beta (the branch bodies are not part of this excerpt).
2247 else if( !
isOne( beta ) ) {
2252 size_t kblock( 0UL );
// Outer loop over k-blocks; with 'remainder' set it stops once fewer than SIMDSIZE values of
// k are left, which are then handled by the scalar section at the end.
2254 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Depth of the current k-block. The first form masks the last block with size_t(-SIMDSIZE),
// which rounds it down to a multiple of SIMDSIZE when SIMDSIZE is a power of two.
// NOTE(review): the two assignments presumably sit in the remainder / non-remainder branches
// of a conditional that is not visible here.
2257 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
2260 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A for this k-block: a lower A skips the rows before kk, an upper A skips the
// rows from kk+kblock onwards.
2263 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
2264 const size_t iend ( IsUpper<MT2>::value ? kk+kblock : M );
2265 const size_t isize ( iend - ibegin );
// Pack the current panel of A.
2267 A2 =
serial( submatrix<!remainder>( A, ibegin, kk, isize, kblock ) );
2270 size_t jblock( 0UL );
2274 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Column blocks of a triangular B that do not intersect the current k-range (body not shown).
2276 if( ( IsLower<MT3>::value && kk+kblock <= jj ) ||
2277 ( IsUpper<MT3>::value && jj+jblock <= kk ) ) {
// Pack the current panel of B.
2282 B2 =
serial( submatrix<!remainder>( B, kk, jj, kblock, jblock ) );
2286 if( IsFloatingPoint<ET1>::value )
// Strips of five rows.
2288 for( ; (i+5UL) <= isize; i+=5UL )
// The whole column block lies left of column ibegin: nothing on or above the diagonal here.
2290 if( jj+jblock < ibegin )
continue;
// First local column that is on or above the diagonal of row ibegin+i.
2292 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
// Five rows x two columns per step.
2294 for( ; (j+2UL) <= jblock; j+=2UL )
2296 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2298 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2300 const SIMDType a1( A2.load(i ,k) );
2301 const SIMDType a2( A2.load(i+1UL,k) );
2302 const SIMDType a3( A2.load(i+2UL,k) );
2303 const SIMDType a4( A2.load(i+3UL,k) );
2304 const SIMDType a5( A2.load(i+4UL,k) );
2306 const SIMDType b1( B2.load(k,j ) );
2307 const SIMDType b2( B2.load(k,j+1UL) );
// Reduce each SIMD accumulator, scale by alpha and add to the target.
2321 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2322 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2323 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2324 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2325 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2326 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2327 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2328 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
2329 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
2330 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// Five rows x one leftover column.
2335 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2337 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2339 const SIMDType a1( A2.load(i ,k) );
2340 const SIMDType a2( A2.load(i+1UL,k) );
2341 const SIMDType a3( A2.load(i+2UL,k) );
2342 const SIMDType a4( A2.load(i+3UL,k) );
2343 const SIMDType a5( A2.load(i+4UL,k) );
2345 const SIMDType b1( B2.load(k,j) );
2354 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2355 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2356 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2357 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
2358 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Strips of four rows.
2364 for( ; (i+4UL) <= isize; i+=4UL )
2366 if( jj+jblock < ibegin )
continue;
2368 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
// Four rows x two columns per step.
2370 for( ; (j+2UL) <= jblock; j+=2UL )
2372 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2374 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2376 const SIMDType a1( A2.load(i ,k) );
2377 const SIMDType a2( A2.load(i+1UL,k) );
2378 const SIMDType a3( A2.load(i+2UL,k) );
2379 const SIMDType a4( A2.load(i+3UL,k) );
2381 const SIMDType b1( B2.load(k,j ) );
2382 const SIMDType b2( B2.load(k,j+1UL) );
2394 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2395 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2396 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2397 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2398 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2399 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2400 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2401 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// Four rows x one leftover column.
2406 SIMDType xmm1, xmm2, xmm3, xmm4;
2408 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2410 const SIMDType a1( A2.load(i ,k) );
2411 const SIMDType a2( A2.load(i+1UL,k) );
2412 const SIMDType a3( A2.load(i+2UL,k) );
2413 const SIMDType a4( A2.load(i+3UL,k) );
2415 const SIMDType b1( B2.load(k,j) );
2423 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2424 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2425 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2426 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Strips of two rows.
2431 for( ; (i+2UL) <= isize; i+=2UL )
2433 if( jj+jblock < ibegin )
continue;
2435 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
// Two rows x four columns per step.
2437 for( ; (j+4UL) <= jblock; j+=4UL )
2439 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2441 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2443 const SIMDType a1( A2.load(i ,k) );
2444 const SIMDType a2( A2.load(i+1UL,k) );
2446 const SIMDType b1( B2.load(k,j ) );
2447 const SIMDType b2( B2.load(k,j+1UL) );
2448 const SIMDType b3( B2.load(k,j+2UL) );
2449 const SIMDType b4( B2.load(k,j+3UL) );
2461 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2462 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2463 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
2464 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
2465 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
2466 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2467 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
2468 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns per step.
2471 for( ; (j+2UL) <= jblock; j+=2UL )
2473 SIMDType xmm1, xmm2, xmm3, xmm4;
2475 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2477 const SIMDType a1( A2.load(i ,k) );
2478 const SIMDType a2( A2.load(i+1UL,k) );
2480 const SIMDType b1( B2.load(k,j ) );
2481 const SIMDType b2( B2.load(k,j+1UL) );
2489 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2490 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2491 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2492 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// Two rows x one leftover column.
2497 SIMDType xmm1, xmm2;
2499 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2501 const SIMDType a1( A2.load(i ,k) );
2502 const SIMDType a2( A2.load(i+1UL,k) );
2504 const SIMDType b1( B2.load(k,j) );
2510 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2511 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single leftover row.
2515 if( i<isize && jj+jblock >= ibegin )
2517 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2519 for( ; (j+2UL) <= jblock; j+=2UL )
2521 SIMDType xmm1, xmm2;
2523 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2525 const SIMDType a1( A2.load(i,k) );
2527 xmm1 += a1 * B2.load(k,j );
2528 xmm2 += a1 * B2.load(k,j+1UL);
2531 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
2532 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2539 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2541 const SIMDType a1( A2.load(i,k) );
2543 xmm1 += a1 * B2.load(k,j);
2546 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the last K-kk values of k that did not fill a SIMD-sized block.
// The loops below mirror the strip structure above but use element access A2(i,k) / B2(k,j).
// NOTE(review): A2 and B2 are presumably refilled for this range in lines not shown.
2556 if( remainder && kk < K )
2558 const size_t ksize( K - kk );
2560 const size_t ibegin( IsLower<MT2>::value ? kk : 0UL );
2561 const size_t isize ( M - ibegin );
2566 size_t jblock( 0UL );
2570 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Column blocks of an upper B that end at or before column kk (body not shown).
2572 if( IsUpper<MT3>::value && jj+jblock <= kk ) {
2581 if( IsFloatingPoint<ET1>::value )
// Strips of five rows.
2583 for( ; (i+5UL) <= isize; i+=5UL )
2585 if( jj+jblock < ibegin )
continue;
2587 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2589 for( ; (j+2UL) <= jblock; j+=2UL ) {
2590 for(
size_t k=0UL; k<ksize; ++k ) {
2591 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2592 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2593 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2594 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2595 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2596 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2597 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2598 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2599 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
2600 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
2605 for(
size_t k=0UL; k<ksize; ++k ) {
2606 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2607 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2608 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2609 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2610 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Strips of four rows.
2617 for( ; (i+4UL) <= isize; i+=4UL )
2619 if( jj+jblock < ibegin )
continue;
2621 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2623 for( ; (j+2UL) <= jblock; j+=2UL ) {
2624 for(
size_t k=0UL; k<ksize; ++k ) {
2625 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2626 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2627 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2628 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2629 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2630 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2631 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2632 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2637 for(
size_t k=0UL; k<ksize; ++k ) {
2638 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2639 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2640 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2641 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Strips of two rows.
2647 for( ; (i+2UL) <= isize; i+=2UL )
2649 if( jj+jblock < ibegin )
continue;
2651 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2653 for( ; (j+2UL) <= jblock; j+=2UL ) {
2654 for(
size_t k=0UL; k<ksize; ++k ) {
2655 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2656 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2657 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2658 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2663 for(
size_t k=0UL; k<ksize; ++k ) {
2664 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2665 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single leftover row.
2670 if( i<isize && jj+jblock >= ibegin )
2672 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2674 for( ; (j+2UL) <= jblock; j+=2UL ) {
2675 for(
size_t k=0UL; k<ksize; ++k ) {
2676 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
2677 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2682 for(
size_t k=0UL; k<ksize; ++k ) {
2683 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel that accumulates alpha * (A*B) into the upper part of a column-major target.
// Every visible update has the form c(ii+i,jbegin+j) += <dot product> * alpha, and for each
// strip of columns the row loops are capped by 'iend', which ends at the row matching the
// last column of the strip.
//   C     : column-major target matrix (accessed through derestrict( ~C )).
//   A, B  : left and right operand.
//   alpha : factor applied to every accumulated dot product.
//   beta  : only the test !isOne( beta ) is visible here; the action taken is not shown.
2715 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2716 void ummm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2718 using ET1 = ElementType_<MT1>;
2719 using ET2 = ElementType_<MT2>;
2720 using ET3 = ElementType_<MT3>;
2721 using SIMDType = SIMDTrait_<ET1>;
// A scalar clean-up pass over k is required as soon as one operand is not padded.
2743 constexpr
bool remainder( !IsPadded<MT2>::value || !IsPadded<MT3>::value );
// Blocking factors: KBLOCK along the shared dimension, IBLOCK along the rows of A.
2745 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2746 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2751 const size_t M( A.rows() );
2752 const size_t N( B.columns() );
2753 const size_t K( A.columns() );
// Packing buffers: A2 is row-major (IBLOCK x KBLOCK), B2 is column-major (KBLOCK x N), so
// both are read along k by the load(...,k) / load(k,...) calls in the kernels below.
2757 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2758 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2760 decltype(
auto) c( derestrict( ~C ) );
// Pre-scaling branch for beta (the branch bodies are not part of this excerpt).
2765 else if( !
isOne( beta ) ) {
2770 size_t kblock( 0UL );
// Outer loop over k-blocks; with 'remainder' set it stops once fewer than SIMDSIZE values of
// k are left, which are then handled by the scalar section at the end.
2772 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Depth of the current k-block. The first form masks the last block with size_t(-SIMDSIZE),
// which rounds it down to a multiple of SIMDSIZE when SIMDSIZE is a power of two.
// NOTE(review): the two assignments presumably sit in the remainder / non-remainder branches
// of a conditional that is not visible here.
2775 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
2778 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B for this k-block: an upper B skips the columns before kk, a lower B skips
// the columns from kk+kblock onwards.
2781 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
2782 const size_t jend ( IsLower<MT3>::value ? kk+kblock : N );
2783 const size_t jsize ( jend - jbegin );
// Pack the current panel of B.
2785 B2 =
serial( submatrix<!remainder>( B, kk, jbegin, kblock, jsize ) );
2788 size_t iblock( 0UL );
2792 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row blocks of a triangular A that do not intersect the current k-range (body not shown).
2794 if( ( IsLower<MT2>::value && ii+iblock <= kk ) ||
2795 ( IsUpper<MT2>::value && kk+kblock <= ii ) ) {
// Pack the current panel of A.
2800 A2 =
serial( submatrix<!remainder>( A, ii, kk, iblock, kblock ) );
// NOTE(review): this test uses ET3 while the matching test in the remainder section uses ET1;
// confirm that the difference is intended.
2804 if( IsFloatingPoint<ET3>::value )
// Strips of five columns.
2806 for( ; (j+5UL) <= jsize; j+=5UL )
// The row block starts below the last column of the strip: nothing to update.
2808 if( ii > jbegin+j+4UL )
continue;
// Rows are processed up to and including the row matching the last column of the strip.
2810 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
// Two rows x five columns per step.
2813 for( ; (i+2UL) <= iend; i+=2UL )
2815 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2817 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2819 const SIMDType a1( A2.load(i ,k) );
2820 const SIMDType a2( A2.load(i+1UL,k) );
2822 const SIMDType b1( B2.load(k,j ) );
2823 const SIMDType b2( B2.load(k,j+1UL) );
2824 const SIMDType b3( B2.load(k,j+2UL) );
2825 const SIMDType b4( B2.load(k,j+3UL) );
2826 const SIMDType b5( B2.load(k,j+4UL) );
// Reduce each SIMD accumulator, scale by alpha and add to the target.
2840 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2841 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2842 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2843 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2844 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
2845 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
2846 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
2847 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// One leftover row x five columns.
2854 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2856 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2858 const SIMDType a1( A2.load(i,k) );
2860 xmm1 += a1 * B2.load(k,j );
2861 xmm2 += a1 * B2.load(k,j+1UL);
2862 xmm3 += a1 * B2.load(k,j+2UL);
2863 xmm4 += a1 * B2.load(k,j+3UL);
2864 xmm5 += a1 * B2.load(k,j+4UL);
2867 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2868 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2869 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2870 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2871 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Strips of four columns.
2877 for( ; (j+4UL) <= jsize; j+=4UL )
2879 if( ii > jbegin+j+3UL )
continue;
2881 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
// Two rows x four columns per step.
2884 for( ; (i+2UL) <= iend; i+=2UL )
2886 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2888 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2890 const SIMDType a1( A2.load(i ,k) );
2891 const SIMDType a2( A2.load(i+1UL,k) );
2893 const SIMDType b1( B2.load(k,j ) );
2894 const SIMDType b2( B2.load(k,j+1UL) );
2895 const SIMDType b3( B2.load(k,j+2UL) );
2896 const SIMDType b4( B2.load(k,j+3UL) );
2908 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2909 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2910 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2911 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2912 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2913 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2914 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// One leftover row x four columns.
2920 SIMDType xmm1, xmm2, xmm3, xmm4;
2922 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2924 const SIMDType a1( A2.load(i,k) );
2926 xmm1 += a1 * B2.load(k,j );
2927 xmm2 += a1 * B2.load(k,j+1UL);
2928 xmm3 += a1 * B2.load(k,j+2UL);
2929 xmm4 += a1 * B2.load(k,j+3UL);
2932 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2933 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2934 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2935 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Strips of two columns.
2940 for( ; (j+2UL) <= jsize; j+=2UL )
2942 if( ii > jbegin+j+1UL )
continue;
2944 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
// Four rows x two columns per step.
2947 for( ; (i+4UL) <= iend; i+=4UL )
2949 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2951 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2953 const SIMDType a1( A2.load(i ,k) );
2954 const SIMDType a2( A2.load(i+1UL,k) );
2955 const SIMDType a3( A2.load(i+2UL,k) );
2956 const SIMDType a4( A2.load(i+3UL,k) );
2958 const SIMDType b1( B2.load(k,j ) );
2959 const SIMDType b2( B2.load(k,j+1UL) );
2971 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2972 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2973 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
2974 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
2975 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2976 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2977 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
2978 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns per step.
2981 for( ; (i+2UL) <= iend; i+=2UL )
2983 SIMDType xmm1, xmm2, xmm3, xmm4;
2985 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2987 const SIMDType a1( A2.load(i ,k) );
2988 const SIMDType a2( A2.load(i+1UL,k) );
2990 const SIMDType b1( B2.load(k,j ) );
2991 const SIMDType b2( B2.load(k,j+1UL) );
2999 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
3000 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
3001 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
3002 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// One leftover row x two columns.
3007 SIMDType xmm1, xmm2;
3009 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3011 const SIMDType a1( A2.load(i,k) );
3013 xmm1 += a1 * B2.load(k,j );
3014 xmm2 += a1 * B2.load(k,j+1UL);
3017 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
3018 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single leftover column.
3022 if( j<jsize && ii <= jbegin+j )
// NOTE(review): with a single column the bound jbegin+j-ii+2UL admits row jbegin+j+1, one row
// past the diagonal; +1UL would stop at the diagonal. Confirm whether this is intended.
3024 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3027 for( ; (i+2UL) <= iend; i+=2UL )
3029 SIMDType xmm1, xmm2;
3031 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3033 const SIMDType b1( B2.load(k,j) );
3035 xmm1 += A2.load(i ,k) * b1;
3036 xmm2 += A2.load(i+1UL,k) * b1;
3039 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
3040 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
3047 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3049 xmm1 += A2.load(i,k) * B2.load(k,j);
3052 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar clean-up for the last K-kk values of k that did not fill a SIMD-sized block.
// The loops below mirror the strip structure above but use element access A2(i,k) / B2(k,j).
// NOTE(review): A2 and B2 are presumably refilled for this range in lines not shown.
3062 if( remainder && kk < K )
3064 const size_t ksize( K - kk );
3066 const size_t jbegin( IsUpper<MT3>::value ? kk : 0UL );
3067 const size_t jsize ( N - jbegin );
3072 size_t iblock( 0UL );
3076 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row blocks of a lower A that end at or before row kk (body not shown).
3078 if( IsLower<MT2>::value && ii+iblock <= kk ) {
3087 if( IsFloatingPoint<ET1>::value )
// Strips of five columns.
3089 for( ; (j+5UL) <= jsize; j+=5UL )
3091 if( ii > jbegin+j+4UL )
continue;
3093 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
3096 for( ; (i+2UL) <= iend; i+=2UL ) {
3097 for(
size_t k=0UL; k<ksize; ++k ) {
3098 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3099 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3100 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3101 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3102 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3103 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3104 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3105 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3106 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3112 for(
size_t k=0UL; k<ksize; ++k ) {
3113 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3114 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3115 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3116 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3117 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Strips of four columns.
3124 for( ; (j+4UL) <= jsize; j+=4UL )
3126 if( ii > jbegin+j+3UL )
continue;
3128 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
3131 for( ; (i+2UL) <= iend; i+=2UL ) {
3132 for(
size_t k=0UL; k<ksize; ++k ) {
3133 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3134 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3135 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3136 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3137 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3138 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3139 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3140 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3145 for(
size_t k=0UL; k<ksize; ++k ) {
3146 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3147 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3148 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3149 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Strips of two columns.
3155 for( ; (j+2UL) <= jsize; j+=2UL )
3157 if( ii > jbegin+j+1UL )
continue;
3159 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3162 for( ; (i+2UL) <= iend; i+=2UL ) {
3163 for(
size_t k=0UL; k<ksize; ++k ) {
3164 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3165 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3166 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3167 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3172 for(
size_t k=0UL; k<ksize; ++k ) {
3173 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3174 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single leftover column.
3179 if( j<jsize && ii <= jbegin+j )
// NOTE(review): as in the SIMD path, the +2UL bound admits one row past the diagonal for a
// single column; confirm whether this is intended.
3181 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3184 for( ; (i+2UL) <= iend; i+=2UL ) {
3185 for(
size_t k=0UL; k<ksize; ++k ) {
3186 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3187 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3192 for(
size_t k=0UL; k<ksize; ++k ) {
3193 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload without scaling factors: forwards to the five-argument ummm with
// alpha = ET1(1) and beta = ET1(0), both expressed in the element type of the target C.
3222 template<
typename MT1,
typename MT2,
typename MT3 >
3223 inline void ummm( MT1& C,
const MT2& A,
const MT3& B )
3225 using ET1 = ElementType_<MT1>;
// ET2 and ET3 are not used by the lines visible here (presumably by checks not shown).
3226 using ET2 = ElementType_<MT2>;
3227 using ET3 = ElementType_<MT3>;
3232 ummm( C, A, B, ET1(1), ET1(0) );
// Symmetric product for a row-major target: lmmm is called with beta = ST(0) to compute the
// product into C, then each element above the diagonal is filled with its mirror element
// from below the diagonal, (~C)(i,j) = (~C)(j,i) with j > i.
//   C     : row-major target matrix.
//   A, B  : left and right operand.
//   alpha : scaling factor forwarded to lmmm.
3264 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3265 void smmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
3267 using ET1 = ElementType_<MT1>;
3268 using ET2 = ElementType_<MT2>;
3269 using ET3 = ElementType_<MT3>;
3285 const size_t M( A.rows() );
3286 const size_t N( B.columns() );
3290 lmmm( C, A, B, alpha, ST(0) );
// Mirror block-wise: the copy is tiled in BLOCK_SIZE x BLOCK_SIZE blocks.
3292 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3294 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal block: only the strictly upper entries (j > i) are written.
3296 for(
size_t i=ii; i<iend; ++i ) {
3297 for(
size_t j=i+1UL; j<iend; ++j ) {
3298 (~C)(i,j) = (~C)(j,i);
// Blocks to the right of the diagonal block: every entry is written.
3302 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3303 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3304 for(
size_t i=ii; i<iend; ++i ) {
3305 for(
size_t j=jj; j<jend; ++j ) {
3306 (~C)(i,j) = (~C)(j,i);
3334 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3335 void smmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
3337 using ET1 = ElementType_<MT1>;
3338 using ET2 = ElementType_<MT2>;
3339 using ET3 = ElementType_<MT3>;
3355 const size_t M( A.rows() );
3356 const size_t N( B.columns() );
3360 ummm( C, A, B, alpha, ST(0) );
3362 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3364 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3366 for(
size_t j=jj; j<jend; ++j ) {
3367 for(
size_t i=jj+1UL; i<jend; ++i ) {
3368 (~C)(i,j) = (~C)(j,i);
3372 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3373 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3374 for(
size_t j=jj; j<jend; ++j ) {
3375 for(
size_t i=ii; i<iend; ++i ) {
3376 (~C)(i,j) = (~C)(j,i);
3402 template<
typename MT1,
typename MT2,
typename MT3 >
3403 inline void smmm( MT1& C,
const MT2& A,
const MT3& B )
3405 using ET1 = ElementType_<MT1>;
3406 using ET2 = ElementType_<MT2>;
3407 using ET3 = ElementType_<MT3>;
3412 smmm( C, A, B, ET1(1) );
3444 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3445 void hmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
3447 using ET1 = ElementType_<MT1>;
3448 using ET2 = ElementType_<MT2>;
3449 using ET3 = ElementType_<MT3>;
3465 const size_t M( A.rows() );
3466 const size_t N( B.columns() );
3470 lmmm( C, A, B, alpha, ST(0) );
3472 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3474 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3476 for(
size_t i=ii; i<iend; ++i ) {
3477 for(
size_t j=i+1UL; j<iend; ++j ) {
3478 (~C)(i,j) =
conj( (~C)(j,i) );
3482 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3483 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3484 for(
size_t i=ii; i<iend; ++i ) {
3485 for(
size_t j=jj; j<jend; ++j ) {
3486 (~C)(i,j) =
conj( (~C)(j,i) );
3514 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3515 void hmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
3517 using ET1 = ElementType_<MT1>;
3518 using ET2 = ElementType_<MT2>;
3519 using ET3 = ElementType_<MT3>;
3535 const size_t M( A.rows() );
3536 const size_t N( B.columns() );
3540 ummm( C, A, B, alpha, ST(0) );
3542 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3544 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3546 for(
size_t j=jj; j<jend; ++j ) {
3547 for(
size_t i=jj+1UL; i<jend; ++i ) {
3548 (~C)(i,j) =
conj( (~C)(j,i) );
3552 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3553 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3554 for(
size_t j=jj; j<jend; ++j ) {
3555 for(
size_t i=ii; i<iend; ++i ) {
3556 (~C)(i,j) =
conj( (~C)(j,i) );
3582 template<
typename MT1,
typename MT2,
typename MT3 >
3583 inline void hmmm( MT1& C,
const MT2& A,
const MT3& B )
3585 using ET1 = ElementType_<MT1>;
3586 using ET2 = ElementType_<MT2>;
3587 using ET3 = ElementType_<MT3>;
3592 hmmm( C, A, B, ET1(1) );
Header file for the implementation of the Submatrix view.
Constraint on the data type.
Header file for auxiliary alias declarations.
Header file for the generic min algorithm.
Header file for kernel specific block sizes.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly upper triangular matrix type...
Definition: StrictlyUpper.h:81
Header file for basic type definitions.
Header file for the serial shim.
BLAZE_ALWAYS_INLINE size_t size(const Vector< VT, TF > &vector) noexcept
Returns the current size/dimension of the vector.
Definition: Vector.h:265
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i...
Definition: Computation.h:81
#define BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a dense, N-dimensional matrix type...
Definition: DenseMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper unitriangular matrix type...
Definition: UniUpper.h:81
void reset(const DiagonalProxy< MT > &proxy)
Resetting the represented element to the default initial values.
Definition: DiagonalProxy.h:560
Submatrix< MT, AF > submatrix(Matrix< MT, SO > &matrix, size_t row, size_t column, size_t m, size_t n)
Creating a view on a specific submatrix of the given matrix.
Definition: Submatrix.h:352
const ElementType_< MT > min(const DenseMatrix< MT, SO > &dm)
Returns the smallest element of the dense matrix.
Definition: DenseMatrix.h:1762
Constraint on the data type.
Constraints on the storage order of matrix types.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE(T)
Constraint on the data type. In case the given data type T is an adaptor type (as for instance LowerMa...
Definition: Adaptor.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly lower triangular matrix type...
Definition: StrictlyLower.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES(T1, T2)
Constraint on the data type. In case the given data types T1 and T2 are not SIMD combinable (i...
Definition: SIMDCombinable.h:61
Namespace of the Blaze C++ math library.
Definition: Blaze.h:57
Header file for the IsFloatingPoint type trait.
#define BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a column-major dense or sparse matri...
Definition: ColumnMajorMatrix.h:61
Header file for the DenseMatrix base class.
Header file for all SIMD functionality.
Header file for the IsLower type trait.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper triangular matrix type...
Definition: Upper.h:81
Header file for the implementation of a dynamic MxN matrix.
Constraint on the data type.
Header file for the IsPadded type trait.
Header file for the isOne shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a symmetric matrix type, a compilation error is created.
Definition: Symmetric.h:79
#define BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a row-major dense or sparse matrix t...
Definition: RowMajorMatrix.h:61
BLAZE_ALWAYS_INLINE ValueType_< T > sum(const SIMDi8< T > &a) noexcept
Returns the sum of all elements in the 8-bit integral SIMD vector.
Definition: Reduction.h:65
Header file for run time assertion macros.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower triangular matrix type...
Definition: Lower.h:81
bool isOne(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is 1.
Definition: DiagonalProxy.h:662
Header file for the isDefault shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower unitriangular matrix type...
Definition: UniLower.h:81
Constraint on the data type.
Constraints on the storage order of matrix types.
decltype(auto) serial(const DenseMatrix< MT, SO > &dm)
Forces the serial evaluation of the given dense matrix expression dm.
Definition: DMatSerialExpr.h:819
bool isDefault(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is in default state.
Definition: DiagonalProxy.h:600
#define BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an Hermitian matrix type, a compilation error is created.
Definition: Hermitian.h:79
Header file for the IsUpper type trait.
decltype(auto) conj(const DenseMatrix< MT, SO > &dm)
Returns a matrix containing the complex conjugate of each single element of dm.
Definition: DMatMapExpr.h:1321
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro. In case of an invalid compile time expression, a compilation error is cr...
Definition: StaticAssert.h:112
#define BLAZE_INTERNAL_ASSERT(expr, msg)
Run time assertion macro for internal checks. In case of an invalid run time expression, the program execution is terminated. The BLAZE_INTERNAL_ASSERT macro can be disabled by setting the BLAZE_USER_ASSERTION flag to zero or by defining NDEBUG during the compilation.
Definition: Assert.h:101
Constraint on the data type.