// Include guard of this header, directly followed by the first mmm() overload.
//
// mmm( C, A, B, alpha, beta ) for a target C of type DenseMatrix<MT1,false>.
// The visible code walks the inner dimension K in panels of at most KBLOCK
// entries, copies the current panel of A into A2 and a panel of B (at most
// JBLOCK columns wide) into B2, and adds alpha-scaled dot products of A2 rows
// and B2 columns to the matching elements of C.
// NOTE(review): lines between the embedded original line numbers are missing
// from this excerpt (braces, the loop over jj, the declarations of kk/jj/i/j,
// and the xmm accumulation of most tiles); comments describe only what is shown.
// NOTE(review): the <MT1,false> flag presumably selects row-major storage - confirm.
35 #ifndef _BLAZE_MATH_DENSE_MMM_H_ 36 #define _BLAZE_MATH_DENSE_MMM_H_ 104 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
105 void mmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the target and of both operands; SIMD type derived from the target.
107 using ET1 = ElementType_t<MT1>;
108 using ET2 = ElementType_t<MT2>;
109 using ET3 = ElementType_t<MT3>;
110 using SIMDType = SIMDTrait_t<ET1>;
// 'remainder' is true as soon as one of the two operand types is not padded.
128 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel sizes: KBLOCK along the inner dimension, JBLOCK along the columns of B.
130 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
131 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions: C is addressed as M x N, K is the inner dimension.
136 const size_t M( A.rows() );
137 const size_t N( B.columns() );
138 const size_t K( A.columns() );
// Scratch panels: A2 holds all M rows of one K panel, B2 holds one KBLOCK x JBLOCK panel.
142 DynamicMatrix<ET2,false> A2( M, KBLOCK );
143 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// Only the head of the '!isOne( beta )' branch is visible.
// NOTE(review): presumably C is pre-processed with beta here - confirm in the full source.
148 else if( !
isOne( beta ) ) {
153 size_t kblock( 0UL );
// Main panel loop over kk. With 'remainder' set, the loop stops while fewer than
// SIMDSIZE inner entries are left; those are handled by the tail section below.
155 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative panel lengths are visible. The first masks a trailing panel with
// size_t(-SIMDSIZE) (rounds down to a multiple of SIMDSIZE for power-of-two SIMDSIZE),
// the second takes the complete rest. The condition selecting between them is not shown.
158 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
161 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this panel: starts at kk for lower A, ends at kk+kblock for upper A.
164 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
165 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
166 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2 (unaligned view whenever 'remainder' is true).
168 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
171 size_t jblock( 0UL );
// Width of the current column panel, clipped at N.
175 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Test for a B panel lying entirely outside the lower/upper part of a triangular B.
// NOTE(review): the branch body is not visible; presumably the panel is skipped.
177 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
178 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
183 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
// The 5-row tiles are guarded by the floating point check on ET1
// (the extent of the guarded region is not visible in this excerpt).
187 if( IsFloatingPoint_v<ET1> )
189 for( ; (i+5UL) <= isize; i+=5UL )
// Tile: 5 rows of A2 x 2 columns of B2; results are written to 10 elements of C.
193 for( ; (j+2UL) <= jblock; j+=2UL )
195 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
197 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
199 const SIMDType a1( A2.load(i ,k) );
200 const SIMDType a2( A2.load(i+1UL,k) );
201 const SIMDType a3( A2.load(i+2UL,k) );
202 const SIMDType a4( A2.load(i+3UL,k) );
203 const SIMDType a5( A2.load(i+4UL,k) );
205 const SIMDType b1( B2.load(k,j ) );
206 const SIMDType b2( B2.load(k,j+1UL) );
// Reduce every accumulator horizontally, scale by alpha and add to C.
220 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
221 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
222 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
223 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
224 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
225 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
226 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
227 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
228 (~C)(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
229 (~C)(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// Tile: 5 rows x 1 column (the guard for the left-over column is not visible).
234 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
236 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
238 const SIMDType a1( A2.load(i ,k) );
239 const SIMDType a2( A2.load(i+1UL,k) );
240 const SIMDType a3( A2.load(i+2UL,k) );
241 const SIMDType a4( A2.load(i+3UL,k) );
242 const SIMDType a5( A2.load(i+4UL,k) );
244 const SIMDType b1( B2.load(k,j) );
253 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
254 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
255 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
256 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
257 (~C)(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Rows are next consumed four at a time.
263 for( ; (i+4UL) <= isize; i+=4UL )
// Tile: 4 rows x 2 columns.
267 for( ; (j+2UL) <= jblock; j+=2UL )
269 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
271 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
273 const SIMDType a1( A2.load(i ,k) );
274 const SIMDType a2( A2.load(i+1UL,k) );
275 const SIMDType a3( A2.load(i+2UL,k) );
276 const SIMDType a4( A2.load(i+3UL,k) );
278 const SIMDType b1( B2.load(k,j ) );
279 const SIMDType b2( B2.load(k,j+1UL) );
291 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
292 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
293 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
294 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
295 (~C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
296 (~C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
297 (~C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
298 (~C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// Tile: 4 rows x 1 column.
303 SIMDType xmm1, xmm2, xmm3, xmm4;
305 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
307 const SIMDType a1( A2.load(i ,k) );
308 const SIMDType a2( A2.load(i+1UL,k) );
309 const SIMDType a3( A2.load(i+2UL,k) );
310 const SIMDType a4( A2.load(i+3UL,k) );
312 const SIMDType b1( B2.load(k,j) );
320 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
321 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
322 (~C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
323 (~C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Rows are next consumed two at a time.
328 for( ; (i+2UL) <= isize; i+=2UL )
// Tile: 2 rows x 4 columns.
332 for( ; (j+4UL) <= jblock; j+=4UL )
334 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
336 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
338 const SIMDType a1( A2.load(i ,k) );
339 const SIMDType a2( A2.load(i+1UL,k) );
341 const SIMDType b1( B2.load(k,j ) );
342 const SIMDType b2( B2.load(k,j+1UL) );
343 const SIMDType b3( B2.load(k,j+2UL) );
344 const SIMDType b4( B2.load(k,j+3UL) );
356 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
357 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
358 (~C)(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
359 (~C)(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
360 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
361 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
362 (~C)(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
363 (~C)(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// Tile: 2 rows x 2 columns.
366 for( ; (j+2UL) <= jblock; j+=2UL )
368 SIMDType xmm1, xmm2, xmm3, xmm4;
370 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
372 const SIMDType a1( A2.load(i ,k) );
373 const SIMDType a2( A2.load(i+1UL,k) );
375 const SIMDType b1( B2.load(k,j ) );
376 const SIMDType b2( B2.load(k,j+1UL) );
384 (~C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
385 (~C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
386 (~C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
387 (~C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// Tile: 2 rows x 1 column.
394 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
396 const SIMDType a1( A2.load(i ,k) );
397 const SIMDType a2( A2.load(i+1UL,k) );
399 const SIMDType b1( B2.load(k,j) );
405 (~C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
406 (~C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single left-over row: tile of 1 row x 2 columns (the row guard is not visible).
414 for( ; (j+2UL) <= jblock; j+=2UL )
418 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
420 const SIMDType a1( A2.load(i,k) );
422 xmm1 += a1 * B2.load(k,j );
423 xmm2 += a1 * B2.load(k,j+1UL);
426 (~C)(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
427 (~C)(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// Single left-over row and column: plain SIMD dot product.
434 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
436 const SIMDType a1( A2.load(i,k) );
438 xmm1 += a1 * B2.load(k,j);
441 (~C)(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Tail section: taken when 'remainder' is true and inner entries are left (kk < K).
// It processes the last ksize = K - kk entries with element-wise access instead of SIMD loads.
// NOTE(review): the statements refilling A2 and B2 for the tail are not visible here.
451 if( remainder && kk < K )
453 const size_t ksize( K - kk );
455 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
456 const size_t isize ( M - ibegin );
461 size_t jblock( 0UL );
465 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Test for a column panel entirely left of kk for upper B (branch body not visible).
467 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
476 if( IsFloatingPoint_v<ET1> )
478 for( ; (i+5UL) <= isize; i+=5UL )
// Scalar tile: 5 rows x 2 columns.
482 for( ; (j+2UL) <= jblock; j+=2UL ) {
483 for(
size_t k=0UL; k<ksize; ++k ) {
484 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
485 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
486 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
487 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
488 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
489 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
490 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
491 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
492 (~C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
493 (~C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 5 rows x 1 column.
498 for(
size_t k=0UL; k<ksize; ++k ) {
499 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
500 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
501 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
502 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
503 (~C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
510 for( ; (i+4UL) <= isize; i+=4UL )
// Scalar tile: 4 rows x 2 columns.
514 for( ; (j+2UL) <= jblock; j+=2UL ) {
515 for(
size_t k=0UL; k<ksize; ++k ) {
516 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
517 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
518 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
519 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
520 (~C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
521 (~C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
522 (~C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
523 (~C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 4 rows x 1 column.
528 for(
size_t k=0UL; k<ksize; ++k ) {
529 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
530 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
531 (~C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
532 (~C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
538 for( ; (i+2UL) <= isize; i+=2UL )
// Scalar tile: 2 rows x 2 columns.
542 for( ; (j+2UL) <= jblock; j+=2UL ) {
543 for(
size_t k=0UL; k<ksize; ++k ) {
544 (~C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
545 (~C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
546 (~C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
547 (~C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 2 rows x 1 column.
552 for(
size_t k=0UL; k<ksize; ++k ) {
553 (~C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
554 (~C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Scalar tile: 1 row x 2 columns.
563 for( ; (j+2UL) <= jblock; j+=2UL ) {
564 for(
size_t k=0UL; k<ksize; ++k ) {
565 (~C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
566 (~C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 1 row x 1 column.
571 for(
size_t k=0UL; k<ksize; ++k ) {
572 (~C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// mmm( C, A, B, alpha, beta ) for a target C of type DenseMatrix<MT1,true>.
// Counterpart of the DenseMatrix<MT1,false> overload with the roles of rows and
// columns swapped: per K panel the visible code copies a kblock x jsize panel of B
// into B2 once, then copies panels of A with at most IBLOCK rows into A2 and adds
// alpha-scaled dot products of A2 rows and B2 columns to C(ii+i, jbegin+j).
// NOTE(review): lines between the embedded original line numbers are missing
// from this excerpt (braces, the loop over ii, declarations of kk/ii/i/j and the
// xmm accumulation of most tiles); comments describe only what is shown.
// NOTE(review): the <MT1,true> flag presumably selects column-major storage - confirm.
604 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
605 void mmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the target and of both operands; SIMD type derived from the target.
607 using ET1 = ElementType_t<MT1>;
608 using ET2 = ElementType_t<MT2>;
609 using ET3 = ElementType_t<MT3>;
610 using SIMDType = SIMDTrait_t<ET1>;
// 'remainder' is true as soon as one of the two operand types is not padded.
628 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel sizes: KBLOCK along the inner dimension, IBLOCK along the rows of A.
630 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
631 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
636 const size_t M( A.rows() );
637 const size_t N( B.columns() );
638 const size_t K( A.columns() );
// Scratch panels: A2 holds one IBLOCK x KBLOCK panel, B2 holds all N columns of one K panel.
642 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
643 DynamicMatrix<ET3,true> B2( KBLOCK, N );
// Only the head of the '!isOne( beta )' branch is visible.
// NOTE(review): presumably C is pre-processed with beta here - confirm in the full source.
648 else if( !
isOne( beta ) ) {
653 size_t kblock( 0UL );
// Main panel loop over kk. With 'remainder' set, the loop stops while fewer than
// SIMDSIZE inner entries are left; those are handled by the tail section below.
655 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative panel lengths are visible. The first masks a trailing panel with
// size_t(-SIMDSIZE) (rounds down to a multiple of SIMDSIZE for power-of-two SIMDSIZE),
// the second takes the complete rest. The condition selecting between them is not shown.
658 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
661 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this panel: starts at kk for upper B, ends at kk+kblock for lower B.
664 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
665 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
666 const size_t jsize ( jend - jbegin );
// Copy the kblock x jsize panel of B into B2 (unaligned view whenever 'remainder' is true).
668 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
671 size_t iblock( 0UL );
// Height of the current row panel, clipped at M.
675 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Test for an A panel lying entirely outside the lower/upper part of a triangular A.
// NOTE(review): the branch body is not visible; presumably the panel is skipped.
677 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
678 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
// Copy the iblock x kblock panel of A into A2.
683 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// The 5-column tiles are guarded by a floating point check.
// NOTE(review): this test uses ET3, while the tail section below tests ET1 - confirm intent.
687 if( IsFloatingPoint_v<ET3> )
689 for( ; (j+5UL) <= jsize; j+=5UL )
// Tile: 2 rows of A2 x 5 columns of B2; results are written to 10 elements of C.
693 for( ; (i+2UL) <= iblock; i+=2UL )
695 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
697 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
699 const SIMDType a1( A2.load(i ,k) );
700 const SIMDType a2( A2.load(i+1UL,k) );
702 const SIMDType b1( B2.load(k,j ) );
703 const SIMDType b2( B2.load(k,j+1UL) );
704 const SIMDType b3( B2.load(k,j+2UL) );
705 const SIMDType b4( B2.load(k,j+3UL) );
706 const SIMDType b5( B2.load(k,j+4UL) );
// Reduce every accumulator horizontally, scale by alpha and add to C.
720 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
721 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
722 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
723 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
724 (~C)(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
725 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
726 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
727 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
728 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
729 (~C)(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// Tile: 1 row x 5 columns (the guard for the left-over row is not visible).
734 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
736 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
738 const SIMDType a1( A2.load(i,k) );
740 xmm1 += a1 * B2.load(k,j );
741 xmm2 += a1 * B2.load(k,j+1UL);
742 xmm3 += a1 * B2.load(k,j+2UL);
743 xmm4 += a1 * B2.load(k,j+3UL);
744 xmm5 += a1 * B2.load(k,j+4UL);
747 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
748 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
749 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
750 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
751 (~C)(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Columns are next consumed four at a time.
757 for( ; (j+4UL) <= jsize; j+=4UL )
// Tile: 2 rows x 4 columns.
761 for( ; (i+2UL) <= iblock; i+=2UL )
763 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
765 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
767 const SIMDType a1( A2.load(i ,k) );
768 const SIMDType a2( A2.load(i+1UL,k) );
770 const SIMDType b1( B2.load(k,j ) );
771 const SIMDType b2( B2.load(k,j+1UL) );
772 const SIMDType b3( B2.load(k,j+2UL) );
773 const SIMDType b4( B2.load(k,j+3UL) );
785 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
786 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
787 (~C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
788 (~C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
789 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
790 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
791 (~C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
792 (~C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// Tile: 1 row x 4 columns.
797 SIMDType xmm1, xmm2, xmm3, xmm4;
799 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
801 const SIMDType a1( A2.load(i,k) );
803 xmm1 += a1 * B2.load(k,j );
804 xmm2 += a1 * B2.load(k,j+1UL);
805 xmm3 += a1 * B2.load(k,j+2UL);
806 xmm4 += a1 * B2.load(k,j+3UL);
809 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
810 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
811 (~C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
812 (~C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Columns are next consumed two at a time.
817 for( ; (j+2UL) <= jsize; j+=2UL )
// Tile: 4 rows x 2 columns.
821 for( ; (i+4UL) <= iblock; i+=4UL )
823 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
825 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
827 const SIMDType a1( A2.load(i ,k) );
828 const SIMDType a2( A2.load(i+1UL,k) );
829 const SIMDType a3( A2.load(i+2UL,k) );
830 const SIMDType a4( A2.load(i+3UL,k) );
832 const SIMDType b1( B2.load(k,j ) );
833 const SIMDType b2( B2.load(k,j+1UL) );
845 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
846 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
847 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
848 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
849 (~C)(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
850 (~C)(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
851 (~C)(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
852 (~C)(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// Tile: 2 rows x 2 columns.
855 for( ; (i+2UL) <= iblock; i+=2UL )
857 SIMDType xmm1, xmm2, xmm3, xmm4;
859 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
861 const SIMDType a1( A2.load(i ,k) );
862 const SIMDType a2( A2.load(i+1UL,k) );
864 const SIMDType b1( B2.load(k,j ) );
865 const SIMDType b2( B2.load(k,j+1UL) );
873 (~C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
874 (~C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
875 (~C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
876 (~C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// Tile: 1 row x 2 columns.
883 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
885 const SIMDType a1( A2.load(i,k) );
887 xmm1 += a1 * B2.load(k,j );
888 xmm2 += a1 * B2.load(k,j+1UL);
891 (~C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
892 (~C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single left-over column: tile of 2 rows x 1 column (the column guard is not visible).
900 for( ; (i+2UL) <= iblock; i+=2UL )
904 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
906 const SIMDType b1( B2.load(k,j) );
908 xmm1 += A2.load(i ,k) * b1;
909 xmm2 += A2.load(i+1UL,k) * b1;
912 (~C)(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
913 (~C)(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// Single left-over row and column: plain SIMD dot product.
920 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
922 xmm1 += A2.load(i,k) * B2.load(k,j);
925 (~C)(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Tail section: taken when 'remainder' is true and inner entries are left (kk < K).
// It processes the last ksize = K - kk entries with element-wise access instead of SIMD loads.
// NOTE(review): the statements refilling A2 and B2 for the tail are not visible here.
935 if( remainder && kk < K )
937 const size_t ksize( K - kk );
939 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
940 const size_t jsize ( N - jbegin );
945 size_t iblock( 0UL );
949 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Test for a row panel entirely above kk for lower A (branch body not visible).
951 if( IsLower_v<MT2> && ii+iblock <= kk ) {
960 if( IsFloatingPoint_v<ET1> )
962 for( ; (j+5UL) <= jsize; j+=5UL )
// Scalar tile: 2 rows x 5 columns.
966 for( ; (i+2UL) <= iblock; i+=2UL ) {
967 for(
size_t k=0UL; k<ksize; ++k ) {
968 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
969 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
970 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
971 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
972 (~C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
973 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
974 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
975 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
976 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
977 (~C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
// Scalar tile: 1 row x 5 columns.
982 for(
size_t k=0UL; k<ksize; ++k ) {
983 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
984 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
985 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
986 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
987 (~C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
994 for( ; (j+4UL) <= jsize; j+=4UL )
// Scalar tile: 2 rows x 4 columns.
998 for( ; (i+2UL) <= iblock; i+=2UL ) {
999 for(
size_t k=0UL; k<ksize; ++k ) {
1000 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1001 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1002 (~C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1003 (~C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1004 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1005 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1006 (~C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1007 (~C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
// Scalar tile: 1 row x 4 columns.
1012 for(
size_t k=0UL; k<ksize; ++k ) {
1013 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1014 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1015 (~C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1016 (~C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
1022 for( ; (j+2UL) <= jsize; j+=2UL )
// Scalar tile: 2 rows x 2 columns.
1026 for( ; (i+2UL) <= iblock; i+=2UL ) {
1027 for(
size_t k=0UL; k<ksize; ++k ) {
1028 (~C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1029 (~C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1030 (~C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1031 (~C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 1 row x 2 columns.
1036 for(
size_t k=0UL; k<ksize; ++k ) {
1037 (~C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1038 (~C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 2 rows x 1 column.
1047 for( ; (i+2UL) <= iblock; i+=2UL ) {
1048 for(
size_t k=0UL; k<ksize; ++k ) {
1049 (~C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1050 (~C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Scalar tile: 1 row x 1 column.
1055 for(
size_t k=0UL; k<ksize; ++k ) {
1056 (~C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of mmm() without scaling factors.
// Forwards to the five-argument mmm() with alpha = ET1(1) and beta = ET1(0),
// both expressed in the element type of the target matrix C.
// NOTE(review): ET2 and ET3 are declared but their use is not visible in this
// excerpt (lines between the embedded original line numbers are missing).
1085 template<
typename MT1,
typename MT2,
typename MT3 >
1086 inline void mmm( MT1& C,
const MT2& A,
const MT3& B )
1088 using ET1 = ElementType_t<MT1>;
1089 using ET2 = ElementType_t<MT2>;
1090 using ET3 = ElementType_t<MT3>;
1095 mmm( C, A, B, ET1(1), ET1(0) );
// lmmm( C, A, B, alpha, beta ) for a target C of type DenseMatrix<MT1,false>.
// Same panel scheme as the mmm() kernel for DenseMatrix<MT1,false> (A2 holds an
// M x KBLOCK panel of A, B2 a KBLOCK x JBLOCK panel of B), but every row tile
// limits its column range: a tile starting at row ibegin+i with R rows is skipped
// when jj lies beyond its last row, and otherwise only columns below
// min( ibegin+i-jj+R, jblock ) are updated. All writes go through 'c', the
// result of derestrict( ~C ).
// NOTE(review): lines between the embedded original line numbers are missing
// from this excerpt (braces, the loop over jj, declarations of kk/jj/i/j and the
// xmm accumulation of most tiles); comments describe only what is shown.
// NOTE(review): the name suggests a lower-triangular result - confirm with the full header.
1128 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1129 void lmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
// Element types of the target and of both operands; SIMD type derived from the target.
1131 using ET1 = ElementType_t<MT1>;
1132 using ET2 = ElementType_t<MT2>;
1133 using ET3 = ElementType_t<MT3>;
1134 using SIMDType = SIMDTrait_t<ET1>;
// 'remainder' is true as soon as one of the two operand types is not padded.
1156 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel sizes: KBLOCK along the inner dimension, JBLOCK along the columns of B.
1158 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1159 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
1164 const size_t M( A.rows() );
1165 const size_t N( B.columns() );
1166 const size_t K( A.columns() );
// Scratch panels for the current K panel of A and the current panel of B.
1170 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1171 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// 'c' is the derestricted view of the target; every update below is written through it.
1173 decltype(
auto) c( derestrict( ~C ) );
// Only the head of the '!isOne( beta )' branch is visible.
// NOTE(review): presumably C is pre-processed with beta here - confirm in the full source.
1178 else if( !
isOne( beta ) ) {
1183 size_t kblock( 0UL );
// Main panel loop over kk. With 'remainder' set, the loop stops while fewer than
// SIMDSIZE inner entries are left; those are handled by the tail section below.
1185 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative panel lengths are visible. The first masks a trailing panel with
// size_t(-SIMDSIZE) (rounds down to a multiple of SIMDSIZE for power-of-two SIMDSIZE),
// the second takes the complete rest. The condition selecting between them is not shown.
1188 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
1191 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this panel: starts at kk for lower A, ends at kk+kblock for upper A.
1194 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1195 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
1196 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2 (unaligned view whenever 'remainder' is true).
1198 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
1201 size_t jblock( 0UL );
// Width of the current column panel, clipped at N.
1205 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Test for a B panel lying entirely outside the lower/upper part of a triangular B.
// NOTE(review): the branch body is not visible; presumably the panel is skipped.
1207 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
1208 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
// Copy the kblock x jblock panel of B into B2.
1213 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
1217 if( IsFloatingPoint_v<ET1> )
1219 for( ; (i+5UL) <= isize; i+=5UL )
// Skip the 5-row tile when the column panel starts beyond its last row.
1221 if( jj > ibegin+i+4UL )
continue;
// Column limit of the tile: no column index above ibegin+i+4 is touched.
1223 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// Tile: 5 rows x 2 columns.
1226 for( ; (j+2UL) <= jend; j+=2UL )
1228 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1230 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1232 const SIMDType a1( A2.load(i ,k) );
1233 const SIMDType a2( A2.load(i+1UL,k) );
1234 const SIMDType a3( A2.load(i+2UL,k) );
1235 const SIMDType a4( A2.load(i+3UL,k) );
1236 const SIMDType a5( A2.load(i+4UL,k) );
1238 const SIMDType b1( B2.load(k,j ) );
1239 const SIMDType b2( B2.load(k,j+1UL) );
// Reduce every accumulator horizontally, scale by alpha and add to c.
1253 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1254 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1255 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1256 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1257 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1258 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1259 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1260 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
1261 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
1262 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// Tile: 5 rows x 1 column (the guard for the left-over column is not visible).
1267 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1269 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1271 const SIMDType a1( A2.load(i ,k) );
1272 const SIMDType a2( A2.load(i+1UL,k) );
1273 const SIMDType a3( A2.load(i+2UL,k) );
1274 const SIMDType a4( A2.load(i+3UL,k) );
1275 const SIMDType a5( A2.load(i+4UL,k) );
1277 const SIMDType b1( B2.load(k,j) );
1286 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1287 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1288 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1289 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
1290 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
1296 for( ; (i+4UL) <= isize; i+=4UL )
// Skip the 4-row tile when the column panel starts beyond its last row.
1298 if( jj > ibegin+i+3UL )
continue;
1300 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
// Tile: 4 rows x 2 columns.
1303 for( ; (j+2UL) <= jend; j+=2UL )
1305 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1307 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1309 const SIMDType a1( A2.load(i ,k) );
1310 const SIMDType a2( A2.load(i+1UL,k) );
1311 const SIMDType a3( A2.load(i+2UL,k) );
1312 const SIMDType a4( A2.load(i+3UL,k) );
1314 const SIMDType b1( B2.load(k,j ) );
1315 const SIMDType b2( B2.load(k,j+1UL) );
1327 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1328 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1329 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1330 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1331 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1332 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1333 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1334 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// Tile: 4 rows x 1 column.
1339 SIMDType xmm1, xmm2, xmm3, xmm4;
1341 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1343 const SIMDType a1( A2.load(i ,k) );
1344 const SIMDType a2( A2.load(i+1UL,k) );
1345 const SIMDType a3( A2.load(i+2UL,k) );
1346 const SIMDType a4( A2.load(i+3UL,k) );
1348 const SIMDType b1( B2.load(k,j) );
1356 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1357 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1358 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1359 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
1364 for( ; (i+2UL) <= isize; i+=2UL )
// Skip the 2-row tile when the column panel starts beyond its last row.
1366 if( jj > ibegin+i+1UL )
continue;
1368 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// Tile: 2 rows x 4 columns.
1371 for( ; (j+4UL) <= jend; j+=4UL )
1373 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1375 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1377 const SIMDType a1( A2.load(i ,k) );
1378 const SIMDType a2( A2.load(i+1UL,k) );
1380 const SIMDType b1( B2.load(k,j ) );
1381 const SIMDType b2( B2.load(k,j+1UL) );
1382 const SIMDType b3( B2.load(k,j+2UL) );
1383 const SIMDType b4( B2.load(k,j+3UL) );
1395 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1396 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1397 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
1398 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
1399 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
1400 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
1402 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// Tile: 2 rows x 2 columns.
1405 for( ; (j+2UL) <= jend; j+=2UL )
1407 SIMDType xmm1, xmm2, xmm3, xmm4;
1409 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1411 const SIMDType a1( A2.load(i ,k) );
1412 const SIMDType a2( A2.load(i+1UL,k) );
1414 const SIMDType b1( B2.load(k,j ) );
1415 const SIMDType b2( B2.load(k,j+1UL) );
1423 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1424 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1425 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1426 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// Tile: 2 rows x 1 column.
1431 SIMDType xmm1, xmm2;
1433 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1435 const SIMDType a1( A2.load(i ,k) );
1436 const SIMDType a2( A2.load(i+1UL,k) );
1438 const SIMDType b1( B2.load(k,j) );
1444 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1445 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single left-over row: only handled when the column panel starts at or before that row.
1449 if( i<isize && jj <= ibegin+i )
1451 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// Tile: 1 row x 2 columns.
1454 for( ; (j+2UL) <= jend; j+=2UL )
1456 SIMDType xmm1, xmm2;
1458 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1460 const SIMDType a1( A2.load(i,k) );
1462 xmm1 += a1 * B2.load(k,j );
1463 xmm2 += a1 * B2.load(k,j+1UL);
1466 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
1467 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// Tile: 1 row x 1 column.
1474 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1476 const SIMDType a1( A2.load(i,k) );
1478 xmm1 += a1 * B2.load(k,j);
1481 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Tail section: taken when 'remainder' is true and inner entries are left (kk < K).
// It processes the last ksize = K - kk entries with element-wise access instead of SIMD loads,
// using the same per-tile column limits as the SIMD section above.
// NOTE(review): the statements refilling A2 and B2 for the tail are not visible here.
1491 if( remainder && kk < K )
1493 const size_t ksize( K - kk );
1495 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1496 const size_t isize ( M - ibegin );
1501 size_t jblock( 0UL );
1505 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Test for a column panel entirely left of kk for upper B (branch body not visible).
1507 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
1516 if( IsFloatingPoint_v<ET1> )
1518 for( ; (i+5UL) <= isize; i+=5UL )
1520 if( jj > ibegin+i+4UL )
continue;
1522 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// Scalar tile: 5 rows x 2 columns.
1525 for( ; (j+2UL) <= jend; j+=2UL ) {
1526 for(
size_t k=0UL; k<ksize; ++k ) {
1527 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1528 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1529 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1530 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1531 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1532 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1533 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1534 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1535 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1536 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 5 rows x 1 column.
1541 for(
size_t k=0UL; k<ksize; ++k ) {
1542 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1543 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1544 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1546 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
1553 for( ; (i+4UL) <= isize; i+=4UL )
1555 if( jj > ibegin+i+3UL )
continue;
1557 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
// Scalar tile: 4 rows x 2 columns.
1560 for( ; (j+2UL) <= jend; j+=2UL ) {
1561 for(
size_t k=0UL; k<ksize; ++k ) {
1562 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1563 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1564 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1565 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1566 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1567 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1568 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1569 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 4 rows x 1 column.
1574 for(
size_t k=0UL; k<ksize; ++k ) {
1575 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1576 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1577 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1578 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1584 for( ; (i+2UL) <= isize; i+=2UL )
1586 if( jj > ibegin+i+1UL )
continue;
1588 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// Scalar tile: 2 rows x 2 columns.
1591 for( ; (j+2UL) <= jend; j+=2UL ) {
1592 for(
size_t k=0UL; k<ksize; ++k ) {
1593 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1594 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1595 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1596 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 2 rows x 1 column.
1601 for(
size_t k=0UL; k<ksize; ++k ) {
1602 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1603 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single left-over row: only handled when the column panel starts at or before that row.
1608 if( i<isize && jj <= ibegin+i )
1610 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// Scalar tile: 1 row x 2 columns.
1613 for( ; (j+2UL) <= jend; j+=2UL ) {
1614 for(
size_t k=0UL; k<ksize; ++k ) {
1615 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1616 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar tile: 1 row x 1 column.
1621 for(
size_t k=0UL; k<ksize; ++k ) {
1622 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// lmmm overload selected for a target of type DenseMatrix<MT1,true>.
// Adds alpha * (A * B) to C, but for each panel of columns only the rows starting at the
// diagonal position of the panel's first column are updated (see the start index
// computations "size_t i( ... )" below).
//
// \param C     Target matrix; all writes go through the derestricted view c.
// \param A     Left-hand side operand (M = A.rows(), K = A.columns()).
// \param B     Right-hand side operand (N = B.columns()).
// \param alpha Factor applied to every accumulated product before it is added to C.
// \param beta  Factor tested with isOne() below.
//              NOTE(review): presumably scales the previous content of C - confirm.
1654 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1655 void lmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
1657 using ET1 = ElementType_t<MT1>;
1658 using ET2 = ElementType_t<MT2>;
1659 using ET3 = ElementType_t<MT3>;
1660 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one of the operand types is not padded. In that case the K dimension is
// split into a SIMD part and a scalar tail (see "if( remainder && kk < K )" further down).
1682 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K dimension (KBLOCK) and the row dimension (IBLOCK).
1684 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1685 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1690 const size_t M( A.rows() );
1691 const size_t N( B.columns() );
1692 const size_t K( A.columns() );
// Scratch copies of the current A block (IBLOCK x KBLOCK) and B panel (KBLOCK x N).
1696 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1697 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1699 decltype(
auto) c( derestrict( ~C ) );
// NOTE(review): this branch handles beta != 1; presumably C is pre-scaled here - confirm.
1704 else if( !
isOne( beta ) ) {
1709 size_t kblock( 0UL );
// Main loop over the K dimension. With remainder == true the loop stops as soon as fewer
// than SIMDSIZE elements are left.
1711 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a final partial block is rounded down to a multiple of SIMDSIZE by masking
// with size_t(-SIMDSIZE) (assumes SIMDSIZE is a power of two).
// NOTE(review): the selection between the two kblock assignments presumably depends on
// 'remainder' - confirm.
1714 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a final partial block takes all remaining K - kk elements.
1717 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this K block; narrowed via IsUpper_v/IsLower_v of MT3.
1720 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
1721 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
1722 const size_t jsize ( jend - jbegin );
// Copy the kblock x jsize panel of B into B2 (unaligned access if 'remainder' is set).
1724 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
1727 size_t iblock( 0UL );
// Number of rows in the current row block (at most IBLOCK).
1731 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Structure test on the (ii,kk) block of A based on IsLower_v/IsUpper_v of MT2.
// NOTE(review): presumably such blocks are skipped - confirm.
1733 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
1734 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
// Copy the iblock x kblock block of A into A2.
1739 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// NOTE(review): this test uses ET3 whereas the scalar tail below tests ET1; both agree only
// if the element types are identical - confirm.
1743 if( IsFloatingPoint_v<ET3> )
// Panels of 5 columns: rows are processed in pairs, then a single-row variant follows.
1745 for( ; (j+5UL) <= jsize; j+=5UL )
1747 if( ii+iblock < jbegin )
continue;
// First local row: the row whose global index ii+i equals the global column jbegin+j,
// or 0 if the row block starts below that position.
1749 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1751 for( ; (i+2UL) <= iblock; i+=2UL )
1753 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1755 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1757 const SIMDType a1( A2.load(i ,k) );
1758 const SIMDType a2( A2.load(i+1UL,k) );
1760 const SIMDType b1( B2.load(k,j ) );
1761 const SIMDType b2( B2.load(k,j+1UL) );
1762 const SIMDType b3( B2.load(k,j+2UL) );
1763 const SIMDType b4( B2.load(k,j+3UL) );
1764 const SIMDType b5( B2.load(k,j+4UL) );
// Horizontal sums of the accumulators, scaled by alpha, are added to the 2x5 tile of C.
1778 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1779 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1780 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1781 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1782 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
1783 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
1784 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
1787 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// Single-row variant of the 5-column kernel (row i only).
1792 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1794 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1796 const SIMDType a1( A2.load(i,k) );
1798 xmm1 += a1 * B2.load(k,j );
1799 xmm2 += a1 * B2.load(k,j+1UL);
1800 xmm3 += a1 * B2.load(k,j+2UL);
1801 xmm4 += a1 * B2.load(k,j+3UL);
1802 xmm5 += a1 * B2.load(k,j+4UL);
1805 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1806 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1807 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1808 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1809 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Panels of 4 columns: rows in pairs, then a single-row variant.
1815 for( ; (j+4UL) <= jsize; j+=4UL )
1817 if( ii+iblock < jbegin )
continue;
1819 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1821 for( ; (i+2UL) <= iblock; i+=2UL )
1823 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1825 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1827 const SIMDType a1( A2.load(i ,k) );
1828 const SIMDType a2( A2.load(i+1UL,k) );
1830 const SIMDType b1( B2.load(k,j ) );
1831 const SIMDType b2( B2.load(k,j+1UL) );
1832 const SIMDType b3( B2.load(k,j+2UL) );
1833 const SIMDType b4( B2.load(k,j+3UL) );
1845 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1846 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1847 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1848 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1849 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1850 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
1852 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// Single-row variant of the 4-column kernel.
1857 SIMDType xmm1, xmm2, xmm3, xmm4;
1859 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1861 const SIMDType a1( A2.load(i,k) );
1863 xmm1 += a1 * B2.load(k,j );
1864 xmm2 += a1 * B2.load(k,j+1UL);
1865 xmm3 += a1 * B2.load(k,j+2UL);
1866 xmm4 += a1 * B2.load(k,j+3UL);
1869 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1870 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1871 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1872 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Panels of 2 columns: rows in groups of four, then pairs, then a single-row variant.
1877 for( ; (j+2UL) <= jsize; j+=2UL )
1879 if( ii+iblock < jbegin )
continue;
1881 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1883 for( ; (i+4UL) <= iblock; i+=4UL )
1885 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1887 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1889 const SIMDType a1( A2.load(i ,k) );
1890 const SIMDType a2( A2.load(i+1UL,k) );
1891 const SIMDType a3( A2.load(i+2UL,k) );
1892 const SIMDType a4( A2.load(i+3UL,k) );
1894 const SIMDType b1( B2.load(k,j ) );
1895 const SIMDType b2( B2.load(k,j+1UL) );
1907 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1908 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1909 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1910 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
1911 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1912 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1913 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
1914 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
1917 for( ; (i+2UL) <= iblock; i+=2UL )
1919 SIMDType xmm1, xmm2, xmm3, xmm4;
1921 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1923 const SIMDType a1( A2.load(i ,k) );
1924 const SIMDType a2( A2.load(i+1UL,k) );
1926 const SIMDType b1( B2.load(k,j ) );
1927 const SIMDType b2( B2.load(k,j+1UL) );
1935 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1936 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1937 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1938 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// Single-row variant of the 2-column kernel.
1943 SIMDType xmm1, xmm2;
1945 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1947 const SIMDType a1( A2.load(i,k) );
1949 xmm1 += a1 * B2.load(k,j );
1950 xmm2 += a1 * B2.load(k,j+1UL);
1953 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1954 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Last single column (if one is left): rows in pairs, then a single element.
1958 if( j<jsize && ii+iblock >= jbegin )
1960 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1962 for( ; (i+2UL) <= iblock; i+=2UL )
1964 SIMDType xmm1, xmm2;
1966 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1968 const SIMDType b1( B2.load(k,j) );
1970 xmm1 += A2.load(i ,k) * b1;
1971 xmm2 += A2.load(i+1UL,k) * b1;
1974 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
1975 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
1982 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1984 xmm1 += A2.load(i,k) * B2.load(k,j);
1987 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: the last K - kk (< SIMDSIZE) elements of the K dimension that the SIMD loop
// above did not cover. The same panel shapes are used, with element-wise A2(i,k) * B2(k,j).
1997 if( remainder && kk < K )
1999 const size_t ksize( K - kk );
2001 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2002 const size_t jsize ( N - jbegin );
2007 size_t iblock( 0UL );
2011 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2013 if( IsLower_v<MT2> && ii+iblock <= kk ) {
2022 if( IsFloatingPoint_v<ET1> )
// Scalar tail, panels of 5 columns.
2024 for( ; (j+5UL) <= jsize; j+=5UL )
2026 if( ii+iblock < jbegin )
continue;
2028 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2030 for( ; (i+2UL) <= iblock; i+=2UL ) {
2031 for(
size_t k=0UL; k<ksize; ++k ) {
2032 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2033 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2034 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2035 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2036 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2037 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2038 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2039 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2041 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2046 for(
size_t k=0UL; k<ksize; ++k ) {
2047 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2048 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2049 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2050 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2051 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Scalar tail, panels of 4 columns.
2058 for( ; (j+4UL) <= jsize; j+=4UL )
2060 if( ii+iblock < jbegin )
continue;
2062 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2064 for( ; (i+2UL) <= iblock; i+=2UL ) {
2065 for(
size_t k=0UL; k<ksize; ++k ) {
2066 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2067 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2068 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2069 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2070 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2071 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2072 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2073 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2078 for(
size_t k=0UL; k<ksize; ++k ) {
2079 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2080 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2081 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2082 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Scalar tail, panels of 2 columns.
2088 for( ; (j+2UL) <= jsize; j+=2UL )
2090 if( ii+iblock < jbegin )
continue;
2092 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2094 for( ; (i+2UL) <= iblock; i+=2UL ) {
2095 for(
size_t k=0UL; k<ksize; ++k ) {
2096 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2097 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2098 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2099 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2104 for(
size_t k=0UL; k<ksize; ++k ) {
2105 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2106 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar tail, single column.
2113 if( ii+iblock < jbegin )
continue;
2115 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2117 for( ; (i+2UL) <= iblock; i+=2UL ) {
2118 for(
size_t k=0UL; k<ksize; ++k ) {
2119 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2120 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2125 for(
size_t k=0UL; k<ksize; ++k ) {
2126 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of lmmm without scaling factors: forwards to the five-argument
// lmmm with alpha = ET1(1) and beta = ET1(0), where ET1 is the element type of the target.
//
// \param C Target matrix.
// \param A Left-hand side operand.
// \param B Right-hand side operand.
2155 template<
typename MT1,
typename MT2,
typename MT3 >
2156 inline void lmmm( MT1& C,
const MT2& A,
const MT3& B )
2158 using ET1 = ElementType_t<MT1>;
2159 using ET2 = ElementType_t<MT2>;
2160 using ET3 = ElementType_t<MT3>;
2165 lmmm( C, A, B, ET1(1), ET1(0) );
// ummm overload selected for a target of type DenseMatrix<MT1,false>.
// Adds alpha * (A * B) to C, but for each panel of rows only the columns starting at the
// diagonal position of the panel's first row are updated (see the start index
// computations "size_t j( ... )" below).
//
// \param C     Target matrix; all writes go through the derestricted view c.
// \param A     Left-hand side operand (M = A.rows(), K = A.columns()).
// \param B     Right-hand side operand (N = B.columns()).
// \param alpha Factor applied to every accumulated product before it is added to C.
// \param beta  Factor tested with isOne() below.
//              NOTE(review): presumably scales the previous content of C - confirm.
2198 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2199 void ummm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2201 using ET1 = ElementType_t<MT1>;
2202 using ET2 = ElementType_t<MT2>;
2203 using ET3 = ElementType_t<MT3>;
2204 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one of the operand types is not padded. In that case the K dimension is
// split into a SIMD part and a scalar tail (see "if( remainder && kk < K )" further down).
2226 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K dimension (KBLOCK) and the column dimension (JBLOCK).
2228 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2229 constexpr
size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
2234 const size_t M( A.rows() );
2235 const size_t N( B.columns() );
2236 const size_t K( A.columns() );
// Scratch copies of the current A panel (M x KBLOCK) and B block (KBLOCK x JBLOCK).
2240 DynamicMatrix<ET2,false> A2( M, KBLOCK );
2241 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
2243 decltype(
auto) c( derestrict( ~C ) );
// NOTE(review): this branch handles beta != 1; presumably C is pre-scaled here - confirm.
2248 else if( !
isOne( beta ) ) {
2253 size_t kblock( 0UL );
// Main loop over the K dimension. With remainder == true the loop stops as soon as fewer
// than SIMDSIZE elements are left.
2255 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a final partial block is rounded down to a multiple of SIMDSIZE by masking
// with size_t(-SIMDSIZE) (assumes SIMDSIZE is a power of two).
// NOTE(review): the selection between the two kblock assignments presumably depends on
// 'remainder' - confirm.
2258 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a final partial block takes all remaining K - kk elements.
2261 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A used for this K block; narrowed via IsLower_v/IsUpper_v of MT2.
2264 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2265 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
2266 const size_t isize ( iend - ibegin );
// Copy the isize x kblock panel of A into A2 (unaligned access if 'remainder' is set).
2268 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
2271 size_t jblock( 0UL );
// Number of columns in the current column block (at most JBLOCK).
2275 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Structure test on the (kk,jj) block of B based on IsLower_v/IsUpper_v of MT3.
// NOTE(review): presumably such blocks are skipped - confirm.
2277 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
2278 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
// Copy the kblock x jblock block of B into B2.
2283 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
2287 if( IsFloatingPoint_v<ET1> )
// Panels of 5 rows: columns are processed in pairs, then a single-column variant follows.
2289 for( ; (i+5UL) <= isize; i+=5UL )
2291 if( jj+jblock < ibegin )
continue;
// First local column: the column whose global index jj+j equals the global row ibegin+i,
// or 0 if the column block starts to the right of that position.
2293 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2295 for( ; (j+2UL) <= jblock; j+=2UL )
2297 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2299 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2301 const SIMDType a1( A2.load(i ,k) );
2302 const SIMDType a2( A2.load(i+1UL,k) );
2303 const SIMDType a3( A2.load(i+2UL,k) );
2304 const SIMDType a4( A2.load(i+3UL,k) );
2305 const SIMDType a5( A2.load(i+4UL,k) );
2307 const SIMDType b1( B2.load(k,j ) );
2308 const SIMDType b2( B2.load(k,j+1UL) );
// Horizontal sums of the accumulators, scaled by alpha, are added to the 5x2 tile of C.
2322 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2323 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2324 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2325 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2326 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2327 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2328 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2329 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
2330 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
2331 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// Single-column variant of the 5-row kernel (column j only).
2336 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2338 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2340 const SIMDType a1( A2.load(i ,k) );
2341 const SIMDType a2( A2.load(i+1UL,k) );
2342 const SIMDType a3( A2.load(i+2UL,k) );
2343 const SIMDType a4( A2.load(i+3UL,k) );
2344 const SIMDType a5( A2.load(i+4UL,k) );
2346 const SIMDType b1( B2.load(k,j) );
2355 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2356 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2357 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2358 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
2359 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Panels of 4 rows: columns in pairs, then a single-column variant.
2365 for( ; (i+4UL) <= isize; i+=4UL )
2367 if( jj+jblock < ibegin )
continue;
2369 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2371 for( ; (j+2UL) <= jblock; j+=2UL )
2373 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2375 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2377 const SIMDType a1( A2.load(i ,k) );
2378 const SIMDType a2( A2.load(i+1UL,k) );
2379 const SIMDType a3( A2.load(i+2UL,k) );
2380 const SIMDType a4( A2.load(i+3UL,k) );
2382 const SIMDType b1( B2.load(k,j ) );
2383 const SIMDType b2( B2.load(k,j+1UL) );
2395 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2396 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2397 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2398 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2399 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2400 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2401 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2402 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// Single-column variant of the 4-row kernel.
2407 SIMDType xmm1, xmm2, xmm3, xmm4;
2409 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2411 const SIMDType a1( A2.load(i ,k) );
2412 const SIMDType a2( A2.load(i+1UL,k) );
2413 const SIMDType a3( A2.load(i+2UL,k) );
2414 const SIMDType a4( A2.load(i+3UL,k) );
2416 const SIMDType b1( B2.load(k,j) );
2424 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2425 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2426 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2427 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Panels of 2 rows: columns in groups of four, then pairs, then a single-column variant.
2432 for( ; (i+2UL) <= isize; i+=2UL )
2434 if( jj+jblock < ibegin )
continue;
2436 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2438 for( ; (j+4UL) <= jblock; j+=4UL )
2440 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2442 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2444 const SIMDType a1( A2.load(i ,k) );
2445 const SIMDType a2( A2.load(i+1UL,k) );
2447 const SIMDType b1( B2.load(k,j ) );
2448 const SIMDType b2( B2.load(k,j+1UL) );
2449 const SIMDType b3( B2.load(k,j+2UL) );
2450 const SIMDType b4( B2.load(k,j+3UL) );
2462 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2463 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2464 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
2465 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
2466 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
2467 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2468 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
2469 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
2472 for( ; (j+2UL) <= jblock; j+=2UL )
2474 SIMDType xmm1, xmm2, xmm3, xmm4;
2476 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2478 const SIMDType a1( A2.load(i ,k) );
2479 const SIMDType a2( A2.load(i+1UL,k) );
2481 const SIMDType b1( B2.load(k,j ) );
2482 const SIMDType b2( B2.load(k,j+1UL) );
2490 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2491 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2492 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2493 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// Single-column variant of the 2-row kernel.
2498 SIMDType xmm1, xmm2;
2500 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2502 const SIMDType a1( A2.load(i ,k) );
2503 const SIMDType a2( A2.load(i+1UL,k) );
2505 const SIMDType b1( B2.load(k,j) );
2511 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2512 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Last single row (if one is left): columns in pairs, then a single element.
2516 if( i<isize && jj+jblock >= ibegin )
2518 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2520 for( ; (j+2UL) <= jblock; j+=2UL )
2522 SIMDType xmm1, xmm2;
2524 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2526 const SIMDType a1( A2.load(i,k) );
2528 xmm1 += a1 * B2.load(k,j );
2529 xmm2 += a1 * B2.load(k,j+1UL);
2532 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
2533 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2540 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2542 const SIMDType a1( A2.load(i,k) );
2544 xmm1 += a1 * B2.load(k,j);
2547 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: the last K - kk (< SIMDSIZE) elements of the K dimension that the SIMD loop
// above did not cover. The same panel shapes are used, with element-wise A2(i,k) * B2(k,j).
2557 if( remainder && kk < K )
2559 const size_t ksize( K - kk );
2561 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2562 const size_t isize ( M - ibegin );
2567 size_t jblock( 0UL );
2571 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
2573 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
2582 if( IsFloatingPoint_v<ET1> )
// Scalar tail, panels of 5 rows.
2584 for( ; (i+5UL) <= isize; i+=5UL )
2586 if( jj+jblock < ibegin )
continue;
2588 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2590 for( ; (j+2UL) <= jblock; j+=2UL ) {
2591 for(
size_t k=0UL; k<ksize; ++k ) {
2592 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2593 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2594 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2595 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2596 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2597 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2598 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2599 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2600 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
2601 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
2606 for(
size_t k=0UL; k<ksize; ++k ) {
2607 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2608 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2609 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2610 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2611 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Scalar tail, panels of 4 rows.
2618 for( ; (i+4UL) <= isize; i+=4UL )
2620 if( jj+jblock < ibegin )
continue;
2622 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2624 for( ; (j+2UL) <= jblock; j+=2UL ) {
2625 for(
size_t k=0UL; k<ksize; ++k ) {
2626 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2627 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2628 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2629 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2630 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2631 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2632 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2633 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2638 for(
size_t k=0UL; k<ksize; ++k ) {
2639 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2640 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2641 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2642 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Scalar tail, panels of 2 rows.
2648 for( ; (i+2UL) <= isize; i+=2UL )
2650 if( jj+jblock < ibegin )
continue;
2652 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2654 for( ; (j+2UL) <= jblock; j+=2UL ) {
2655 for(
size_t k=0UL; k<ksize; ++k ) {
2656 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2657 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2658 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2659 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2664 for(
size_t k=0UL; k<ksize; ++k ) {
2665 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2666 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Scalar tail, last single row.
2671 if( i<isize && jj+jblock >= ibegin )
2673 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2675 for( ; (j+2UL) <= jblock; j+=2UL ) {
2676 for(
size_t k=0UL; k<ksize; ++k ) {
2677 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
2678 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2683 for(
size_t k=0UL; k<ksize; ++k ) {
2684 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// ummm overload selected for a target of type DenseMatrix<MT1,true>.
// Adds alpha * (A * B) to C, but for each panel of columns only the rows up to the
// diagonal position of the panel's last column are updated (see the 'iend' bounds below).
//
// \param C     Target matrix; all writes go through the derestricted view c.
// \param A     Left-hand side operand (M = A.rows(), K = A.columns()).
// \param B     Right-hand side operand (N = B.columns()).
// \param alpha Factor applied to every accumulated product before it is added to C.
// \param beta  Factor tested with isOne() below.
//              NOTE(review): presumably scales the previous content of C - confirm.
2716 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2717 void ummm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2719 using ET1 = ElementType_t<MT1>;
2720 using ET2 = ElementType_t<MT2>;
2721 using ET3 = ElementType_t<MT3>;
2722 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one of the operand types is not padded. In that case the K dimension is
// split into a SIMD part and a scalar tail (see "if( remainder && kk < K )" further down).
2744 constexpr
bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K dimension (KBLOCK) and the row dimension (IBLOCK).
2746 constexpr
size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2747 constexpr
size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2752 const size_t M( A.rows() );
2753 const size_t N( B.columns() );
2754 const size_t K( A.columns() );
// Scratch copies of the current A block (IBLOCK x KBLOCK) and B panel (KBLOCK x N).
2758 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2759 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2761 decltype(
auto) c( derestrict( ~C ) );
// NOTE(review): this branch handles beta != 1; presumably C is pre-scaled here - confirm.
2766 else if( !
isOne( beta ) ) {
2771 size_t kblock( 0UL );
// Main loop over the K dimension. With remainder == true the loop stops as soon as fewer
// than SIMDSIZE elements are left.
2773 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Variant 1: a final partial block is rounded down to a multiple of SIMDSIZE by masking
// with size_t(-SIMDSIZE) (assumes SIMDSIZE is a power of two).
// NOTE(review): the selection between the two kblock assignments presumably depends on
// 'remainder' - confirm.
2776 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( ( K - kk ) &
size_t(-SIMDSIZE) ) );
// Variant 2: a final partial block takes all remaining K - kk elements.
2779 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B used for this K block; narrowed via IsUpper_v/IsLower_v of MT3.
2782 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2783 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
2784 const size_t jsize ( jend - jbegin );
// Copy the kblock x jsize panel of B into B2 (unaligned access if 'remainder' is set).
2786 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
2789 size_t iblock( 0UL );
// Number of rows in the current row block (at most IBLOCK).
2793 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Structure test on the (ii,kk) block of A based on IsLower_v/IsUpper_v of MT2.
// NOTE(review): presumably such blocks are skipped - confirm.
2795 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
2796 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
// Copy the iblock x kblock block of A into A2.
2801 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// NOTE(review): this test uses ET3 whereas the scalar tail below tests ET1; both agree only
// if the element types are identical - confirm.
2805 if( IsFloatingPoint_v<ET3> )
// Panels of 5 columns: rows are processed in pairs, then a single-row variant follows.
2807 for( ; (j+5UL) <= jsize; j+=5UL )
// Skip the row block if it starts below the row index of the panel's last column.
2809 if( ii > jbegin+j+4UL )
continue;
// Local row bound: global rows up to jbegin+j+4 (row index of the panel's last column).
2811 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
2814 for( ; (i+2UL) <= iend; i+=2UL )
2816 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2818 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2820 const SIMDType a1( A2.load(i ,k) );
2821 const SIMDType a2( A2.load(i+1UL,k) );
2823 const SIMDType b1( B2.load(k,j ) );
2824 const SIMDType b2( B2.load(k,j+1UL) );
2825 const SIMDType b3( B2.load(k,j+2UL) );
2826 const SIMDType b4( B2.load(k,j+3UL) );
2827 const SIMDType b5( B2.load(k,j+4UL) );
// Horizontal sums of the accumulators, scaled by alpha, are added to the 2x5 tile of C.
2841 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2842 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2843 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2844 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2845 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
2846 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
2847 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
2850 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// Single-row variant of the 5-column kernel (row i only).
2855 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2857 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2859 const SIMDType a1( A2.load(i,k) );
2861 xmm1 += a1 * B2.load(k,j );
2862 xmm2 += a1 * B2.load(k,j+1UL);
2863 xmm3 += a1 * B2.load(k,j+2UL);
2864 xmm4 += a1 * B2.load(k,j+3UL);
2865 xmm5 += a1 * B2.load(k,j+4UL);
2868 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2869 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2870 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2871 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2872 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Panels of 4 columns: rows in pairs, then a single-row variant.
2878 for( ; (j+4UL) <= jsize; j+=4UL )
2880 if( ii > jbegin+j+3UL )
continue;
2882 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
2885 for( ; (i+2UL) <= iend; i+=2UL )
2887 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2889 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2891 const SIMDType a1( A2.load(i ,k) );
2892 const SIMDType a2( A2.load(i+1UL,k) );
2894 const SIMDType b1( B2.load(k,j ) );
2895 const SIMDType b2( B2.load(k,j+1UL) );
2896 const SIMDType b3( B2.load(k,j+2UL) );
2897 const SIMDType b4( B2.load(k,j+3UL) );
2909 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2910 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2911 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2912 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2913 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2914 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
2916 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// Single-row variant of the 4-column kernel.
2921 SIMDType xmm1, xmm2, xmm3, xmm4;
2923 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2925 const SIMDType a1( A2.load(i,k) );
2927 xmm1 += a1 * B2.load(k,j );
2928 xmm2 += a1 * B2.load(k,j+1UL);
2929 xmm3 += a1 * B2.load(k,j+2UL);
2930 xmm4 += a1 * B2.load(k,j+3UL);
2933 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2934 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2935 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2936 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Panels of 2 columns: rows in groups of four, then pairs, then a single-row variant.
2941 for( ; (j+2UL) <= jsize; j+=2UL )
2943 if( ii > jbegin+j+1UL )
continue;
2945 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
2948 for( ; (i+4UL) <= iend; i+=4UL )
2950 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2952 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2954 const SIMDType a1( A2.load(i ,k) );
2955 const SIMDType a2( A2.load(i+1UL,k) );
2956 const SIMDType a3( A2.load(i+2UL,k) );
2957 const SIMDType a4( A2.load(i+3UL,k) );
2959 const SIMDType b1( B2.load(k,j ) );
2960 const SIMDType b2( B2.load(k,j+1UL) );
2972 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2973 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2974 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
2975 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
2976 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2977 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2978 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
2979 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
2982 for( ; (i+2UL) <= iend; i+=2UL )
2984 SIMDType xmm1, xmm2, xmm3, xmm4;
2986 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2988 const SIMDType a1( A2.load(i ,k) );
2989 const SIMDType a2( A2.load(i+1UL,k) );
2991 const SIMDType b1( B2.load(k,j ) );
2992 const SIMDType b2( B2.load(k,j+1UL) );
3000 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
3001 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
3002 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
3003 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// Single-row variant of the 2-column kernel.
3008 SIMDType xmm1, xmm2;
3010 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3012 const SIMDType a1( A2.load(i,k) );
3014 xmm1 += a1 * B2.load(k,j );
3015 xmm2 += a1 * B2.load(k,j+1UL);
3018 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
3019 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Last single column (if one is left): rows in pairs, then a single element.
3023 if( j<jsize && ii <= jbegin+j )
// NOTE(review): the bound uses +2UL although only one column is processed, which admits
// the global row jbegin+j+1 (one below the diagonal of this column) - confirm intended.
3025 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3028 for( ; (i+2UL) <= iend; i+=2UL )
3030 SIMDType xmm1, xmm2;
3032 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3034 const SIMDType b1( B2.load(k,j) );
3036 xmm1 += A2.load(i ,k) * b1;
3037 xmm2 += A2.load(i+1UL,k) * b1;
3040 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
3041 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
3048 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3050 xmm1 += A2.load(i,k) * B2.load(k,j);
3053 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: the last K - kk (< SIMDSIZE) elements of the K dimension that the SIMD loop
// above did not cover. The same panel shapes are used, with element-wise A2(i,k) * B2(k,j).
3063 if( remainder && kk < K )
3065 const size_t ksize( K - kk );
3067 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
3068 const size_t jsize ( N - jbegin );
3073 size_t iblock( 0UL );
3077 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
3079 if( IsLower_v<MT2> && ii+iblock <= kk ) {
3088 if( IsFloatingPoint_v<ET1> )
// Scalar tail, panels of 5 columns.
3090 for( ; (j+5UL) <= jsize; j+=5UL )
3092 if( ii > jbegin+j+4UL )
continue;
3094 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
3097 for( ; (i+2UL) <= iend; i+=2UL ) {
3098 for(
size_t k=0UL; k<ksize; ++k ) {
3099 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3100 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3101 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3102 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3103 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3104 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3105 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3106 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3108 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3113 for(
size_t k=0UL; k<ksize; ++k ) {
3114 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3115 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3116 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3117 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3118 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Scalar tail, panels of 4 columns.
3125 for( ; (j+4UL) <= jsize; j+=4UL )
3127 if( ii > jbegin+j+3UL )
continue;
3129 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
3132 for( ; (i+2UL) <= iend; i+=2UL ) {
3133 for(
size_t k=0UL; k<ksize; ++k ) {
3134 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3135 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3136 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3137 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3138 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3139 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3140 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3141 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3146 for(
size_t k=0UL; k<ksize; ++k ) {
3147 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3148 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3149 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3150 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Scalar tail, panels of 2 columns.
3156 for( ; (j+2UL) <= jsize; j+=2UL )
3158 if( ii > jbegin+j+1UL )
continue;
3160 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3163 for( ; (i+2UL) <= iend; i+=2UL ) {
3164 for(
size_t k=0UL; k<ksize; ++k ) {
3165 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3166 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3167 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3168 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3173 for(
size_t k=0UL; k<ksize; ++k ) {
3174 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3175 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar tail, last single column (same +2UL bound as in the SIMD part above).
3180 if( j<jsize && ii <= jbegin+j )
3182 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3185 for( ; (i+2UL) <= iend; i+=2UL ) {
3186 for(
size_t k=0UL; k<ksize; ++k ) {
3187 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3188 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3193 for(
size_t k=0UL; k<ksize; ++k ) {
3194 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of ummm without scaling factors: forwards to the five-argument
// ummm with alpha = ET1(1) and beta = ET1(0), where ET1 is the element type of the target.
//
// \param C Target matrix.
// \param A Left-hand side operand.
// \param B Right-hand side operand.
3223 template<
typename MT1,
typename MT2,
typename MT3 >
3224 inline void ummm( MT1& C,
const MT2& A,
const MT3& B )
3226 using ET1 = ElementType_t<MT1>;
3227 using ET2 = ElementType_t<MT2>;
3228 using ET3 = ElementType_t<MT3>;
3233 ummm( C, A, B, ET1(1), ET1(0) );
// smmm overload selected for a target of type DenseMatrix<MT1,false>.
// Calls lmmm with beta = ST(0) and afterwards assigns C(i,j) = C(j,i) for all j > i,
// i.e. the elements below the diagonal are mirrored into the positions above it.
// The mirroring is performed in tiles of BLOCK_SIZE x BLOCK_SIZE elements.
// NOTE(review): the diagonal tile uses the row bound 'iend' for the column index as well,
// which assumes M == N - confirm against the callers.
//
// \param C     Target matrix.
// \param A     Left-hand side operand (M = A.rows()).
// \param B     Right-hand side operand (N = B.columns()).
// \param alpha Scaling factor passed on to lmmm.
3265 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3266 void smmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
3268 using ET1 = ElementType_t<MT1>;
3269 using ET2 = ElementType_t<MT2>;
3270 using ET3 = ElementType_t<MT3>;
3286 const size_t M( A.rows() );
3287 const size_t N( B.columns() );
3291 lmmm( C, A, B, alpha, ST(0) );
// Loop over the row tiles.
3293 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3295 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal tile: only the strictly upper elements (j > i) are assigned.
3297 for(
size_t i=ii; i<iend; ++i ) {
3298 for(
size_t j=i+1UL; j<iend; ++j ) {
3299 (~C)(i,j) = (~C)(j,i);
// Tiles to the right of the diagonal tile: every element is assigned from its mirror.
3303 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3304 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3305 for(
size_t i=ii; i<iend; ++i ) {
3306 for(
size_t j=jj; j<jend; ++j ) {
3307 (~C)(i,j) = (~C)(j,i);
// smmm() overload for a column-major target matrix C (DenseMatrix<MT1,true>).
//
// Calls ummm( C, A, B, alpha, ST(0) ) and afterwards assigns
// (~C)(i,j) = (~C)(j,i) column by column, in tiles of BLOCK_SIZE, so that
// elements below the diagonal receive the value of their mirrored counterpart
// above the diagonal.
// NOTE(review): the five-argument ummm() is only partially visible in this
// excerpt; presumably it computes the upper part of alpha*A*B -- confirm.
// NOTE(review): the diagonal tile bounds the row index by a limit derived
// from N, so the routine presumably expects M == N -- confirm.
//
// C      target matrix; written by ummm() and by the mirroring loops below
// A, B   operands; only A.rows() and B.columns() are read directly here
// alpha  scalar forwarded to ummm()
3335 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3336 void smmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element type aliases; not referenced in the lines visible in this excerpt.
3338 using ET1 = ElementType_t<MT1>;
3339 using ET2 = ElementType_t<MT2>;
3340 using ET3 = ElementType_t<MT3>;
// N bounds the column loops, M bounds the row-tile loop.
3356 const size_t M( A.rows() );
3357 const size_t N( B.columns() );
// Delegate the actual multiplication; ST(0) is passed as the last scalar.
3361 ummm( C, A, B, alpha, ST(0) );
// Walk over the columns in tiles of BLOCK_SIZE.
3363 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3365 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
// Diagonal tile.
// NOTE(review): the row loop starts at jj+1UL, not j+1UL. For j > jj it
// therefore also visits rows i <= j: the diagonal element is assigned to
// itself and elements above the diagonal are rewritten from positions that
// were filled in earlier j iterations. The row-major overload starts its
// inner loop at i+1UL -- confirm whether j+1UL was intended here.
3367 for(
size_t j=jj; j<jend; ++j ) {
3368 for(
size_t i=jj+1UL; i<jend; ++i ) {
3369 (~C)(i,j) = (~C)(j,i);
// Tiles below the diagonal tile: every element of the tile is filled from
// its transposed position.
3373 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3374 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3375 for(
size_t j=jj; j<jend; ++j ) {
3376 for(
size_t i=ii; i<iend; ++i ) {
3377 (~C)(i,j) = (~C)(j,i);
// Convenience overload of smmm() without an explicit scaling factor.
//
// Forwards C, A and B unchanged to the four-argument smmm() overload and
// passes ET1(1) as the scalar, where ET1 is the element type of C.
3403 template<
typename MT1,
typename MT2,
typename MT3 >
3404 inline void smmm( MT1& C,
const MT2& A,
const MT3& B )
3406 using ET1 = ElementType_t<MT1>;
// ET2 and ET3 are not referenced in the lines visible here; presumably they
// feed compile-time constraints that this listing omits -- TODO confirm.
3407 using ET2 = ElementType_t<MT2>;
3408 using ET3 = ElementType_t<MT3>;
// Delegate to the scaled kernel with a scalar of one.
3413 smmm( C, A, B, ET1(1) );
// hmmm() overload for a row-major target matrix C (DenseMatrix<MT1,false>).
//
// Calls lmmm( C, A, B, alpha, ST(0) ) and afterwards assigns
// (~C)(i,j) = conj( (~C)(j,i) ) for every element with j > i, i.e. each
// element above the diagonal is overwritten with the complex conjugate of its
// mirrored counterpart from below the diagonal. Diagonal elements are not
// touched by the loops. The copy is performed in tiles of BLOCK_SIZE.
// NOTE(review): lmmm() is defined outside this excerpt; presumably it computes
// the lower part of alpha*A*B -- confirm.
// NOTE(review): the diagonal tile bounds the column index by a limit derived
// from M, so the routine presumably expects M == N -- confirm.
//
// C      target matrix; written by lmmm() and by the mirroring loops below
// A, B   operands; only A.rows() and B.columns() are read directly here
// alpha  scalar forwarded to lmmm()
3445 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3446 void hmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element type aliases; not referenced in the lines visible in this excerpt.
3448 using ET1 = ElementType_t<MT1>;
3449 using ET2 = ElementType_t<MT2>;
3450 using ET3 = ElementType_t<MT3>;
// M bounds the row loops, N bounds the column-tile loop.
3466 const size_t M( A.rows() );
3467 const size_t N( B.columns() );
// Delegate the actual multiplication; ST(0) is passed as the last scalar.
3471 lmmm( C, A, B, alpha, ST(0) );
// Walk over the rows in tiles of BLOCK_SIZE.
3473 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3475 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal tile: for each row i only the columns right of the diagonal
// (j = i+1 .. iend-1) are written.
3477 for(
size_t i=ii; i<iend; ++i ) {
3478 for(
size_t j=i+1UL; j<iend; ++j ) {
3479 (~C)(i,j) =
conj( (~C)(j,i) );
// Tiles to the right of the diagonal tile: every element of the tile is
// filled with the conjugate of its transposed position.
3483 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3484 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3485 for(
size_t i=ii; i<iend; ++i ) {
3486 for(
size_t j=jj; j<jend; ++j ) {
3487 (~C)(i,j) =
conj( (~C)(j,i) );
// hmmm() overload for a column-major target matrix C (DenseMatrix<MT1,true>).
//
// Calls ummm( C, A, B, alpha, ST(0) ) and afterwards assigns
// (~C)(i,j) = conj( (~C)(j,i) ) column by column, in tiles of BLOCK_SIZE, so
// that elements below the diagonal receive the complex conjugate of their
// mirrored counterpart above the diagonal.
// NOTE(review): the five-argument ummm() is only partially visible in this
// excerpt; presumably it computes the upper part of alpha*A*B -- confirm.
// NOTE(review): the diagonal tile bounds the row index by a limit derived
// from N, so the routine presumably expects M == N -- confirm.
//
// C      target matrix; written by ummm() and by the mirroring loops below
// A, B   operands; only A.rows() and B.columns() are read directly here
// alpha  scalar forwarded to ummm()
3515 template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3516 void hmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element type aliases; not referenced in the lines visible in this excerpt.
3518 using ET1 = ElementType_t<MT1>;
3519 using ET2 = ElementType_t<MT2>;
3520 using ET3 = ElementType_t<MT3>;
// N bounds the column loops, M bounds the row-tile loop.
3536 const size_t M( A.rows() );
3537 const size_t N( B.columns() );
// Delegate the actual multiplication; ST(0) is passed as the last scalar.
3541 ummm( C, A, B, alpha, ST(0) );
// Walk over the columns in tiles of BLOCK_SIZE.
3543 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3545 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
// Diagonal tile.
// NOTE(review): the row loop starts at jj+1UL, not j+1UL. For j > jj it
// therefore also visits rows i <= j: at i == j the diagonal element is
// replaced by its own conjugate, and elements above the diagonal are
// rewritten from positions filled in earlier j iterations. The row-major
// overload starts its inner loop at i+1UL and never touches the diagonal --
// confirm whether j+1UL was intended here.
3547 for(
size_t j=jj; j<jend; ++j ) {
3548 for(
size_t i=jj+1UL; i<jend; ++i ) {
3549 (~C)(i,j) =
conj( (~C)(j,i) );
// Tiles below the diagonal tile: every element of the tile is filled with
// the conjugate of its transposed position.
3553 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3554 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3555 for(
size_t j=jj; j<jend; ++j ) {
3556 for(
size_t i=ii; i<iend; ++i ) {
3557 (~C)(i,j) =
conj( (~C)(j,i) );
// Convenience overload of hmmm() without an explicit scaling factor.
//
// Forwards C, A and B unchanged to the four-argument hmmm() overload and
// passes ET1(1) as the scalar, where ET1 is the element type of C.
3583 template<
typename MT1,
typename MT2,
typename MT3 >
3584 inline void hmmm( MT1& C,
const MT2& A,
const MT3& B )
3586 using ET1 = ElementType_t<MT1>;
// ET2 and ET3 are not referenced in the lines visible here; presumably they
// feed compile-time constraints that this listing omits -- TODO confirm.
3587 using ET2 = ElementType_t<MT2>;
3588 using ET3 = ElementType_t<MT3>;
// Delegate to the scaled kernel with a scalar of one.
3593 hmmm( C, A, B, ET1(1) );
Header file for the implementation of the Submatrix view.
Constraint on the data type.
Header file for auxiliary alias declarations.
Header file for the generic min algorithm.
Header file for the blaze::checked and blaze::unchecked instances.
Header file for kernel specific block sizes.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly upper triangular matrix type,...
Definition: StrictlyUpper.h:81
Header file for basic type definitions.
decltype(auto) submatrix(Matrix< MT, SO > &, RSAs...)
Creating a view on a specific submatrix of the given matrix.
Definition: Submatrix.h:178
Header file for the serial shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type. In case the given data type T is a computational expression (i....
Definition: Computation.h:81
#define BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a dense, N-dimensional matrix type,...
Definition: DenseMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper unitriangular matrix type,...
Definition: UniUpper.h:81
void reset(const DiagonalProxy< MT > &proxy)
Resetting the represented element to the default initial values.
Definition: DiagonalProxy.h:595
constexpr Unchecked unchecked
Global Unchecked instance. The blaze::unchecked instance is an optional token for the creation of view...
Definition: Check.h:138
Constraint on the data type.
Constraints on the storage order of matrix types.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE(T)
Constraint on the data type. In case the given data type T is an adaptor type (as for instance LowerMa...
Definition: Adaptor.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a strictly lower triangular matrix type,...
Definition: StrictlyLower.h:81
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES(T1, T2)
Constraint on the data type. In case the given data types T1 and T2 are not SIMD combinable (i....
Definition: SIMDCombinable.h:61
Namespace of the Blaze C++ math library.
Definition: Blaze.h:58
Header file for the IsFloatingPoint type trait.
#define BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a column-major dense or sparse matri...
Definition: ColumnMajorMatrix.h:61
decltype(auto) min(const DenseMatrix< MT1, SO1 > &lhs, const DenseMatrix< MT2, SO2 > &rhs)
Computes the componentwise minimum of the dense matrices lhs and rhs.
Definition: DMatDMatMapExpr.h:1162
decltype(auto) sum(const DenseMatrix< MT, SO > &dm)
Reduces the given dense matrix by means of addition.
Definition: DMatReduceExpr.h:2147
Header file for the DenseMatrix base class.
Header file for all SIMD functionality.
Header file for the IsLower type trait.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an upper triangular matrix type,...
Definition: Upper.h:81
Header file for the implementation of a dynamic MxN matrix.
Constraint on the data type.
Header file for the IsPadded type trait.
Header file for the isOne shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a symmetric matrix type,...
Definition: Symmetric.h:79
#define BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is not a row-major dense or sparse matrix t...
Definition: RowMajorMatrix.h:61
Header file for run time assertion macros.
Constraint on the data type.
Constraint on the data type.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower triangular matrix type,...
Definition: Lower.h:81
bool isOne(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is 1.
Definition: DiagonalProxy.h:697
Header file for the isDefault shim.
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is a lower unitriangular matrix type,...
Definition: UniLower.h:81
constexpr size_t size(const Matrix< MT, SO > &matrix) noexcept
Returns the total number of elements of the matrix.
Definition: Matrix.h:530
Constraint on the data type.
Constraints on the storage order of matrix types.
decltype(auto) serial(const DenseMatrix< MT, SO > &dm)
Forces the serial evaluation of the given dense matrix expression dm.
Definition: DMatSerialExpr.h:808
bool isDefault(const DiagonalProxy< MT > &proxy)
Returns whether the represented element is in default state.
Definition: DiagonalProxy.h:635
#define BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE(T)
Constraint on the data type. In case the given data type T is an Hermitian matrix type,...
Definition: Hermitian.h:79
Header file for the IsUpper type trait.
decltype(auto) conj(const DenseMatrix< MT, SO > &dm)
Returns a matrix containing the complex conjugate of each single element of dm.
Definition: DMatMapExpr.h:1324
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro. In case of an invalid compile time expression, a compilation error is cr...
Definition: StaticAssert.h:112
#define BLAZE_INTERNAL_ASSERT(expr, msg)
Run time assertion macro for internal checks. In case of an invalid run time expression,...
Definition: Assert.h:101
Constraint on the data type.