35#ifndef _BLAZE_MATH_DENSE_MMM_H_
36#define _BLAZE_MATH_DENSE_MMM_H_
// Blocked dense matrix multiplication kernel for a target of type DenseMatrix<MT1,false>
// (row-major storage order in Blaze). The visible update statements accumulate alpha-scaled
// partial products of A and B into C: C(i,j) += sum(...) * alpha.
//
// \param C      Target dense matrix, accessed via (*C)(row,column).
// \param A      Left-hand side operand; supplies M = A.rows() and K = A.columns().
// \param B      Right-hand side operand; supplies N = B.columns().
// \param alpha  Scaling factor applied to every accumulated product.
// \param beta   Scaling factor for C. Only the isOne(beta) test is visible here.
//               NOTE(review): presumably C is reset or scaled by beta before accumulating -- confirm.
105template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
106void mmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
108 using ET1 = ElementType_t<MT1>;
109 using ET2 = ElementType_t<MT2>;
110 using ET3 = ElementType_t<MT3>;
111 using SIMDType = SIMDTrait_t<ET1>;
// A separate scalar pass over the trailing K columns is required whenever either operand
// is not padded (see the 'remainder && kk < K' section further below).
129 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K and the column (J) dimension.
131 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
132 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions: (M x K) * (K x N).
137 const size_t M( A.rows() );
138 const size_t N( B.columns() );
139 const size_t K( A.columns() );
// Scratch buffers: A2 receives an (up to) M x KBLOCK panel of A, B2 a KBLOCK x JBLOCK tile of B.
143 DynamicMatrix<ET2,false> A2( M, KBLOCK );
144 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// Handling of the beta factor; the preceding branch and the branch bodies are not shown here.
149 else if( !
isOne( beta ) ) {
154 size_t kblock( 0UL );
// Outer loop over the K dimension. With 'remainder' set, the loop only continues while at
// least SIMDSIZE columns are left, i.e. kk + SIMDSIZE <= K.
156 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative computations of the current block width: the first truncates the final
// block via prevMultiple( ..., SIMDSIZE ), the second takes all remaining columns.
// NOTE(review): the selecting condition is not visible here; presumably 'remainder' -- confirm.
159 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
162 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Restrict the row range of the A panel when A is lower or upper triangular.
165 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
166 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
167 const size_t isize ( iend - ibegin );
// Copy the current isize x kblock panel of A into the scratch buffer (evaluated serially).
169 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
172 size_t jblock( 0UL );
// Width of the current column tile, clipped at N.
176 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// This condition identifies tiles of a triangular B that lie completely outside its
// non-zero part. NOTE(review): the branch body is not visible; presumably the tile is skipped.
178 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
179 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
// Copy the current kblock x jblock tile of B into the scratch buffer (evaluated serially).
184 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
// The 5-row kernel below follows a floating point check on the element type.
188 if( IsFloatingPoint_v<ET1> )
// Kernel: 5 rows of A2 at a time ...
190 for( ; (i+5UL) <= isize; i+=5UL )
// ... combined with 2 columns of B2 at a time (10 SIMD accumulators).
194 for( ; (j+2UL) <= jblock; j+=2UL )
196 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
198 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
200 const SIMDType a1( A2.load(i ,k) );
201 const SIMDType a2( A2.load(i+1UL,k) );
202 const SIMDType a3( A2.load(i+2UL,k) );
203 const SIMDType a4( A2.load(i+3UL,k) );
204 const SIMDType a5( A2.load(i+4UL,k) );
206 const SIMDType b1( B2.load(k,j ) );
207 const SIMDType b2( B2.load(k,j+1UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target element.
221 (*C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
222 (*C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
223 (*C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
224 (*C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
225 (*C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
226 (*C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
227 (*C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
228 (*C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
229 (*C)(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
230 (*C)(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5 rows x 1 column variant (5 accumulators), presumably for a leftover column.
235 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
237 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
239 const SIMDType a1( A2.load(i ,k) );
240 const SIMDType a2( A2.load(i+1UL,k) );
241 const SIMDType a3( A2.load(i+2UL,k) );
242 const SIMDType a4( A2.load(i+3UL,k) );
243 const SIMDType a5( A2.load(i+4UL,k) );
245 const SIMDType b1( B2.load(k,j) );
254 (*C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
255 (*C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
256 (*C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
257 (*C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
258 (*C)(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Kernel: 4 rows of A2 at a time ...
264 for( ; (i+4UL) <= isize; i+=4UL )
// ... combined with 2 columns of B2 at a time (8 accumulators).
268 for( ; (j+2UL) <= jblock; j+=2UL )
270 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
272 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
274 const SIMDType a1( A2.load(i ,k) );
275 const SIMDType a2( A2.load(i+1UL,k) );
276 const SIMDType a3( A2.load(i+2UL,k) );
277 const SIMDType a4( A2.load(i+3UL,k) );
279 const SIMDType b1( B2.load(k,j ) );
280 const SIMDType b2( B2.load(k,j+1UL) );
292 (*C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
293 (*C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
294 (*C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
295 (*C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
296 (*C)(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
297 (*C)(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
298 (*C)(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
299 (*C)(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4 rows x 1 column variant (4 accumulators).
304 SIMDType xmm1, xmm2, xmm3, xmm4;
306 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
308 const SIMDType a1( A2.load(i ,k) );
309 const SIMDType a2( A2.load(i+1UL,k) );
310 const SIMDType a3( A2.load(i+2UL,k) );
311 const SIMDType a4( A2.load(i+3UL,k) );
313 const SIMDType b1( B2.load(k,j) );
321 (*C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
322 (*C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
323 (*C)(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
324 (*C)(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Kernel: 2 rows of A2 at a time ...
329 for( ; (i+2UL) <= isize; i+=2UL )
// ... first with 4 columns of B2 (8 accumulators), ...
333 for( ; (j+4UL) <= jblock; j+=4UL )
335 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
337 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
339 const SIMDType a1( A2.load(i ,k) );
340 const SIMDType a2( A2.load(i+1UL,k) );
342 const SIMDType b1( B2.load(k,j ) );
343 const SIMDType b2( B2.load(k,j+1UL) );
344 const SIMDType b3( B2.load(k,j+2UL) );
345 const SIMDType b4( B2.load(k,j+3UL) );
357 (*C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
358 (*C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
359 (*C)(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
360 (*C)(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
361 (*C)(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
362 (*C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
363 (*C)(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
364 (*C)(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// ... then with 2 columns of B2 (4 accumulators), ...
367 for( ; (j+2UL) <= jblock; j+=2UL )
369 SIMDType xmm1, xmm2, xmm3, xmm4;
371 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
373 const SIMDType a1( A2.load(i ,k) );
374 const SIMDType a2( A2.load(i+1UL,k) );
376 const SIMDType b1( B2.load(k,j ) );
377 const SIMDType b2( B2.load(k,j+1UL) );
385 (*C)(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
386 (*C)(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
387 (*C)(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
388 (*C)(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// ... and finally with a single column of B2 (2 accumulators).
395 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
397 const SIMDType a1( A2.load(i ,k) );
398 const SIMDType a2( A2.load(i+1UL,k) );
400 const SIMDType b1( B2.load(k,j) );
406 (*C)(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
407 (*C)(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single-row kernel: 2 columns of B2 at a time, ...
415 for( ; (j+2UL) <= jblock; j+=2UL )
419 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
421 const SIMDType a1( A2.load(i,k) );
423 xmm1 += a1 * B2.load(k,j );
424 xmm2 += a1 * B2.load(k,j+1UL);
427 (*C)(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
428 (*C)(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// ... followed by the single row x single column case.
435 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
437 const SIMDType a1( A2.load(i,k) );
439 xmm1 += a1 * B2.load(k,j);
442 (*C)(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar remainder pass: handles the trailing K - kk columns that the SIMD loop left over.
// It uses element access A2(i,k) / B2(k,j) instead of SIMD loads.
// NOTE(review): the statements that refill A2 and B2 for this pass are not visible here.
452 if( remainder && kk < K )
454 const size_t ksize( K - kk );
456 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
457 const size_t isize ( M - ibegin );
462 size_t jblock( 0UL );
466 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Upper triangular B: this condition identifies column tiles entirely left of column kk.
468 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
477 if( IsFloatingPoint_v<ET1> )
// Scalar 5-row kernel (2 columns, then 1 column).
479 for( ; (i+5UL) <= isize; i+=5UL )
483 for( ; (j+2UL) <= jblock; j+=2UL ) {
484 for(
size_t k=0UL; k<ksize; ++k ) {
485 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
486 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
487 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
488 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
489 (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
490 (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
491 (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
492 (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
493 (*C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
494 (*C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
499 for(
size_t k=0UL; k<ksize; ++k ) {
500 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
501 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
502 (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
503 (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
504 (*C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Scalar 4-row kernel (2 columns, then 1 column).
511 for( ; (i+4UL) <= isize; i+=4UL )
515 for( ; (j+2UL) <= jblock; j+=2UL ) {
516 for(
size_t k=0UL; k<ksize; ++k ) {
517 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
518 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
519 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
520 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
521 (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
522 (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
523 (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
524 (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
529 for(
size_t k=0UL; k<ksize; ++k ) {
530 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
531 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
532 (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
533 (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Scalar 2-row kernel (2 columns, then 1 column).
539 for( ; (i+2UL) <= isize; i+=2UL )
543 for( ; (j+2UL) <= jblock; j+=2UL ) {
544 for(
size_t k=0UL; k<ksize; ++k ) {
545 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
546 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
547 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
548 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
553 for(
size_t k=0UL; k<ksize; ++k ) {
554 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
555 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Scalar single-row kernel (2 columns, then 1 column).
564 for( ; (j+2UL) <= jblock; j+=2UL ) {
565 for(
size_t k=0UL; k<ksize; ++k ) {
566 (*C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
567 (*C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
572 for(
size_t k=0UL; k<ksize; ++k ) {
573 (*C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked dense matrix multiplication kernel for a target of type DenseMatrix<MT1,true>
// (column-major storage order in Blaze). The visible update statements accumulate alpha-scaled
// partial products of A and B into C: C(i,j) += sum(...) * alpha.
// In this variant the B panel is buffered once per K block and A is tiled by rows.
//
// \param C      Target dense matrix, accessed via (*C)(row,column).
// \param A      Left-hand side operand; supplies M = A.rows() and K = A.columns().
// \param B      Right-hand side operand; supplies N = B.columns().
// \param alpha  Scaling factor applied to every accumulated product.
// \param beta   Scaling factor for C. Only the isOne(beta) test is visible here.
//               NOTE(review): presumably C is reset or scaled by beta before accumulating -- confirm.
605template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
606void mmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
608 using ET1 = ElementType_t<MT1>;
609 using ET2 = ElementType_t<MT2>;
610 using ET3 = ElementType_t<MT3>;
611 using SIMDType = SIMDTrait_t<ET1>;
// A separate scalar pass over the trailing K columns is required whenever either operand
// is not padded (see the 'remainder && kk < K' section further below).
629 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K and the row (I) dimension.
631 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
632 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions: (M x K) * (K x N).
637 const size_t M( A.rows() );
638 const size_t N( B.columns() );
639 const size_t K( A.columns() );
// Scratch buffers: A2 receives an IBLOCK x KBLOCK tile of A, B2 an (up to) KBLOCK x N panel of B.
643 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
644 DynamicMatrix<ET3,true> B2( KBLOCK, N );
// Handling of the beta factor; the preceding branch and the branch bodies are not shown here.
649 else if( !
isOne( beta ) ) {
654 size_t kblock( 0UL );
// Outer loop over the K dimension. With 'remainder' set, the loop only continues while at
// least SIMDSIZE columns are left, i.e. kk + SIMDSIZE <= K.
656 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative computations of the current block width: the first truncates the final
// block via prevMultiple( ..., SIMDSIZE ), the second takes all remaining columns.
// NOTE(review): the selecting condition is not visible here; presumably 'remainder' -- confirm.
659 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
662 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Restrict the column range of the B panel when B is upper or lower triangular.
665 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
666 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
667 const size_t jsize ( jend - jbegin );
// Copy the current kblock x jsize panel of B into the scratch buffer (evaluated serially).
669 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
672 size_t iblock( 0UL );
// Height of the current row tile, clipped at M.
676 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// This condition identifies tiles of a triangular A that lie completely outside its
// non-zero part. NOTE(review): the branch body is not visible; presumably the tile is skipped.
678 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
679 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
// Copy the current iblock x kblock tile of A into the scratch buffer (evaluated serially).
684 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// The 5-column kernel below follows a floating point check on the element type.
// NOTE(review): this pass tests ET3 whereas the scalar remainder pass tests ET1; the two are
// only equivalent if the element types coincide -- confirm.
688 if( IsFloatingPoint_v<ET3> )
// Kernel: 5 columns of B2 at a time ...
690 for( ; (j+5UL) <= jsize; j+=5UL )
// ... combined with 2 rows of A2 at a time (10 SIMD accumulators).
694 for( ; (i+2UL) <= iblock; i+=2UL )
696 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
698 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
700 const SIMDType a1( A2.load(i ,k) );
701 const SIMDType a2( A2.load(i+1UL,k) );
703 const SIMDType b1( B2.load(k,j ) );
704 const SIMDType b2( B2.load(k,j+1UL) );
705 const SIMDType b3( B2.load(k,j+2UL) );
706 const SIMDType b4( B2.load(k,j+3UL) );
707 const SIMDType b5( B2.load(k,j+4UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target element.
721 (*C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
722 (*C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
723 (*C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
724 (*C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
725 (*C)(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
726 (*C)(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
727 (*C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
728 (*C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
729 (*C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
730 (*C)(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// 1 row x 5 columns variant (5 accumulators), presumably for a leftover row.
735 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
737 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
739 const SIMDType a1( A2.load(i,k) );
741 xmm1 += a1 * B2.load(k,j );
742 xmm2 += a1 * B2.load(k,j+1UL);
743 xmm3 += a1 * B2.load(k,j+2UL);
744 xmm4 += a1 * B2.load(k,j+3UL);
745 xmm5 += a1 * B2.load(k,j+4UL);
748 (*C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
749 (*C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
750 (*C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
751 (*C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
752 (*C)(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Kernel: 4 columns of B2 at a time ...
758 for( ; (j+4UL) <= jsize; j+=4UL )
// ... combined with 2 rows of A2 at a time (8 accumulators).
762 for( ; (i+2UL) <= iblock; i+=2UL )
764 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
766 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
768 const SIMDType a1( A2.load(i ,k) );
769 const SIMDType a2( A2.load(i+1UL,k) );
771 const SIMDType b1( B2.load(k,j ) );
772 const SIMDType b2( B2.load(k,j+1UL) );
773 const SIMDType b3( B2.load(k,j+2UL) );
774 const SIMDType b4( B2.load(k,j+3UL) );
786 (*C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
787 (*C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
788 (*C)(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
789 (*C)(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
790 (*C)(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
791 (*C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
792 (*C)(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
793 (*C)(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// 1 row x 4 columns variant (4 accumulators).
798 SIMDType xmm1, xmm2, xmm3, xmm4;
800 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
802 const SIMDType a1( A2.load(i,k) );
804 xmm1 += a1 * B2.load(k,j );
805 xmm2 += a1 * B2.load(k,j+1UL);
806 xmm3 += a1 * B2.load(k,j+2UL);
807 xmm4 += a1 * B2.load(k,j+3UL);
810 (*C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
811 (*C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
812 (*C)(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
813 (*C)(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Kernel: 2 columns of B2 at a time ...
818 for( ; (j+2UL) <= jsize; j+=2UL )
// ... first with 4 rows of A2 (8 accumulators), ...
822 for( ; (i+4UL) <= iblock; i+=4UL )
824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
826 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
828 const SIMDType a1( A2.load(i ,k) );
829 const SIMDType a2( A2.load(i+1UL,k) );
830 const SIMDType a3( A2.load(i+2UL,k) );
831 const SIMDType a4( A2.load(i+3UL,k) );
833 const SIMDType b1( B2.load(k,j ) );
834 const SIMDType b2( B2.load(k,j+1UL) );
846 (*C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
847 (*C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
848 (*C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
849 (*C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
850 (*C)(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
851 (*C)(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
852 (*C)(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
853 (*C)(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// ... then with 2 rows of A2 (4 accumulators), ...
856 for( ; (i+2UL) <= iblock; i+=2UL )
858 SIMDType xmm1, xmm2, xmm3, xmm4;
860 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
862 const SIMDType a1( A2.load(i ,k) );
863 const SIMDType a2( A2.load(i+1UL,k) );
865 const SIMDType b1( B2.load(k,j ) );
866 const SIMDType b2( B2.load(k,j+1UL) );
874 (*C)(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
875 (*C)(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
876 (*C)(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
877 (*C)(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// ... and finally with a single row of A2 (2 accumulators).
884 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
886 const SIMDType a1( A2.load(i,k) );
888 xmm1 += a1 * B2.load(k,j );
889 xmm2 += a1 * B2.load(k,j+1UL);
892 (*C)(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
893 (*C)(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single-column kernel: 2 rows of A2 at a time, ...
901 for( ; (i+2UL) <= iblock; i+=2UL )
905 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
907 const SIMDType b1( B2.load(k,j) );
909 xmm1 += A2.load(i ,k) * b1;
910 xmm2 += A2.load(i+1UL,k) * b1;
913 (*C)(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
914 (*C)(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// ... followed by the single row x single column case.
921 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
923 xmm1 += A2.load(i,k) * B2.load(k,j);
926 (*C)(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar remainder pass: handles the trailing K - kk columns that the SIMD loop left over.
// It uses element access A2(i,k) / B2(k,j) instead of SIMD loads.
// NOTE(review): the statements that refill A2 and B2 for this pass are not visible here.
936 if( remainder && kk < K )
938 const size_t ksize( K - kk );
940 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
941 const size_t jsize ( N - jbegin );
946 size_t iblock( 0UL );
950 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Lower triangular A: this condition identifies row tiles entirely above row kk.
952 if( IsLower_v<MT2> && ii+iblock <= kk ) {
961 if( IsFloatingPoint_v<ET1> )
// Scalar 5-column kernel (2 rows, then 1 row).
963 for( ; (j+5UL) <= jsize; j+=5UL )
967 for( ; (i+2UL) <= iblock; i+=2UL ) {
968 for(
size_t k=0UL; k<ksize; ++k ) {
969 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
970 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
971 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
972 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
973 (*C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
974 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
975 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
976 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
977 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
978 (*C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
983 for(
size_t k=0UL; k<ksize; ++k ) {
984 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
985 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
986 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
987 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
988 (*C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
// Scalar 4-column kernel (2 rows, then 1 row).
995 for( ; (j+4UL) <= jsize; j+=4UL )
999 for( ; (i+2UL) <= iblock; i+=2UL ) {
1000 for(
size_t k=0UL; k<ksize; ++k ) {
1001 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1002 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1003 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1004 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1005 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1006 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1007 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1008 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
1013 for(
size_t k=0UL; k<ksize; ++k ) {
1014 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1015 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1016 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1017 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
// Scalar 2-column kernel (2 rows, then 1 row).
1023 for( ; (j+2UL) <= jsize; j+=2UL )
1027 for( ; (i+2UL) <= iblock; i+=2UL ) {
1028 for(
size_t k=0UL; k<ksize; ++k ) {
1029 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1030 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1031 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1032 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1037 for(
size_t k=0UL; k<ksize; ++k ) {
1038 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1039 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Scalar single-column kernel (2 rows, then 1 row).
1048 for( ; (i+2UL) <= iblock; i+=2UL ) {
1049 for(
size_t k=0UL; k<ksize; ++k ) {
1050 (*C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1051 (*C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1056 for(
size_t k=0UL; k<ksize; ++k ) {
1057 (*C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of mmm() without scaling factors: forwards to the five-argument
// mmm() with alpha = ET1(1) and beta = ET1(0), where ET1 is the element type of C.
//
// \param C  Target matrix of the multiplication.
// \param A  Left-hand side operand.
// \param B  Right-hand side operand.
1086template<
typename MT1,
typename MT2,
typename MT3 >
1087inline void mmm( MT1& C,
const MT2& A,
const MT3& B )
1089 using ET1 = ElementType_t<MT1>;
// NOTE(review): ET2 and ET3 are not used by the visible call; presumably they feed
// compile-time constraints that are not shown here -- confirm.
1090 using ET2 = ElementType_t<MT2>;
1091 using ET3 = ElementType_t<MT3>;
1096 mmm( C, A, B, ET1(1), ET1(0) );
// Blocked dense matrix multiplication kernel for a target of type DenseMatrix<MT1,false>
// (row-major storage order in Blaze) that restricts the computed columns of every row block
// to those up to the diagonal position of the block's last row (see the 'jend' computations).
// NOTE(review): judging by the name and the column bounds this is the kernel for a lower
// triangular result -- confirm against the callers.
// All writes go through 'c', the result of derestrict( *C ).
//
// \param C      Target dense matrix.
// \param A      Left-hand side operand; supplies M = A.rows() and K = A.columns().
// \param B      Right-hand side operand; supplies N = B.columns().
// \param alpha  Scaling factor applied to every accumulated product.
// \param beta   Scaling factor for C. Only the isOne(beta) test is visible here.
//               NOTE(review): presumably C is reset or scaled by beta before accumulating -- confirm.
1129template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1130void lmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
1132 using ET1 = ElementType_t<MT1>;
1133 using ET2 = ElementType_t<MT2>;
1134 using ET3 = ElementType_t<MT3>;
1135 using SIMDType = SIMDTrait_t<ET1>;
// A separate scalar pass over the trailing K columns is required whenever either operand
// is not padded (see the 'remainder && kk < K' section further below).
1157 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Compile-time block sizes along the K and the column (J) dimension.
1159 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1160 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
// Problem dimensions: (M x K) * (K x N).
1165 const size_t M( A.rows() );
1166 const size_t N( B.columns() );
1167 const size_t K( A.columns() );
// Scratch buffers: A2 receives an (up to) M x KBLOCK panel of A, B2 a KBLOCK x JBLOCK tile of B.
1171 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1172 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
// Handle used for all element updates below.
// NOTE(review): derestrict() presumably lifts the access restrictions of an adapted matrix -- confirm.
1174 decltype(
auto) c( derestrict( *C ) );
// Handling of the beta factor; the preceding branch and the branch bodies are not shown here.
1179 else if( !
isOne( beta ) ) {
1184 size_t kblock( 0UL );
// Outer loop over the K dimension. With 'remainder' set, the loop only continues while at
// least SIMDSIZE columns are left, i.e. kk + SIMDSIZE <= K.
1186 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Two alternative computations of the current block width: the first truncates the final
// block via prevMultiple( ..., SIMDSIZE ), the second takes all remaining columns.
// NOTE(review): the selecting condition is not visible here; presumably 'remainder' -- confirm.
1189 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
1192 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Restrict the row range of the A panel when A is lower or upper triangular.
1195 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1196 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
1197 const size_t isize ( iend - ibegin );
// Copy the current isize x kblock panel of A into the scratch buffer (evaluated serially).
1199 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
1202 size_t jblock( 0UL );
// Width of the current column tile, clipped at N.
1206 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// This condition identifies tiles of a triangular B that lie completely outside its
// non-zero part. NOTE(review): the branch body is not visible; presumably the tile is skipped.
1208 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
1209 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
// Copy the current kblock x jblock tile of B into the scratch buffer (evaluated serially).
1214 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
// The 5-row kernel below follows a floating point check on the element type.
1218 if( IsFloatingPoint_v<ET1> )
// Kernel: 5 rows of A2 at a time.
1220 for( ; (i+5UL) <= isize; i+=5UL )
// Skip row blocks whose last row (ibegin+i+4) lies above the first column of this tile.
1222 if( jj > ibegin+i+4UL )
continue;
// Tile-relative column bound: up to the diagonal column of the block's last row, clipped to the tile.
1224 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
// 5 rows x 2 columns (10 SIMD accumulators).
1227 for( ; (j+2UL) <= jend; j+=2UL )
1229 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1231 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1233 const SIMDType a1( A2.load(i ,k) );
1234 const SIMDType a2( A2.load(i+1UL,k) );
1235 const SIMDType a3( A2.load(i+2UL,k) );
1236 const SIMDType a4( A2.load(i+3UL,k) );
1237 const SIMDType a5( A2.load(i+4UL,k) );
1239 const SIMDType b1( B2.load(k,j ) );
1240 const SIMDType b2( B2.load(k,j+1UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target element.
1254 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1255 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1256 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1257 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1258 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1259 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1260 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1261 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
1262 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
1263 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// 5 rows x 1 column variant (5 accumulators), presumably for a leftover column.
1268 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1270 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1272 const SIMDType a1( A2.load(i ,k) );
1273 const SIMDType a2( A2.load(i+1UL,k) );
1274 const SIMDType a3( A2.load(i+2UL,k) );
1275 const SIMDType a4( A2.load(i+3UL,k) );
1276 const SIMDType a5( A2.load(i+4UL,k) );
1278 const SIMDType b1( B2.load(k,j) );
1287 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1288 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1289 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1290 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
1291 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Kernel: 4 rows of A2 at a time, with the analogous skip test and column bound.
1297 for( ; (i+4UL) <= isize; i+=4UL )
1299 if( jj > ibegin+i+3UL )
continue;
1301 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
// 4 rows x 2 columns (8 accumulators).
1304 for( ; (j+2UL) <= jend; j+=2UL )
1306 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1308 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1310 const SIMDType a1( A2.load(i ,k) );
1311 const SIMDType a2( A2.load(i+1UL,k) );
1312 const SIMDType a3( A2.load(i+2UL,k) );
1313 const SIMDType a4( A2.load(i+3UL,k) );
1315 const SIMDType b1( B2.load(k,j ) );
1316 const SIMDType b2( B2.load(k,j+1UL) );
1328 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1329 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1330 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1331 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
1332 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
1333 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1334 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
1335 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// 4 rows x 1 column variant (4 accumulators).
1340 SIMDType xmm1, xmm2, xmm3, xmm4;
1342 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1344 const SIMDType a1( A2.load(i ,k) );
1345 const SIMDType a2( A2.load(i+1UL,k) );
1346 const SIMDType a3( A2.load(i+2UL,k) );
1347 const SIMDType a4( A2.load(i+3UL,k) );
1349 const SIMDType b1( B2.load(k,j) );
1357 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1358 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
1359 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
1360 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Kernel: 2 rows of A2 at a time, with the analogous skip test and column bound.
1365 for( ; (i+2UL) <= isize; i+=2UL )
1367 if( jj > ibegin+i+1UL )
continue;
1369 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
// 2 rows x 4 columns (8 accumulators), ...
1372 for( ; (j+4UL) <= jend; j+=4UL )
1374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1376 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1378 const SIMDType a1( A2.load(i ,k) );
1379 const SIMDType a2( A2.load(i+1UL,k) );
1381 const SIMDType b1( B2.load(k,j ) );
1382 const SIMDType b2( B2.load(k,j+1UL) );
1383 const SIMDType b3( B2.load(k,j+2UL) );
1384 const SIMDType b4( B2.load(k,j+3UL) );
1396 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1397 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1398 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
1399 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
1400 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
1402 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
1403 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// ... then 2 rows x 2 columns (4 accumulators), ...
1406 for( ; (j+2UL) <= jend; j+=2UL )
1408 SIMDType xmm1, xmm2, xmm3, xmm4;
1410 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1412 const SIMDType a1( A2.load(i ,k) );
1413 const SIMDType a2( A2.load(i+1UL,k) );
1415 const SIMDType b1( B2.load(k,j ) );
1416 const SIMDType b2( B2.load(k,j+1UL) );
1424 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
1425 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
1426 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
1427 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// ... and finally 2 rows x 1 column (2 accumulators).
1432 SIMDType xmm1, xmm2;
1434 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1436 const SIMDType a1( A2.load(i ,k) );
1437 const SIMDType a2( A2.load(i+1UL,k) );
1439 const SIMDType b1( B2.load(k,j) );
1445 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
1446 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Leftover single row, processed only if the tile starts at or left of the row's diagonal column.
// NOTE(review): the bound below uses +2UL (as in the two-row kernel), although a single row's
// diagonal column would correspond to +1UL -- confirm that the extra column is intended.
1450 if( i<isize && jj <= ibegin+i )
1452 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1455 for( ; (j+2UL) <= jend; j+=2UL )
1457 SIMDType xmm1, xmm2;
1459 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1461 const SIMDType a1( A2.load(i,k) );
1463 xmm1 += a1 * B2.load(k,j );
1464 xmm2 += a1 * B2.load(k,j+1UL);
1467 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
1468 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// Single row x single column case.
1475 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1477 const SIMDType a1( A2.load(i,k) );
1479 xmm1 += a1 * B2.load(k,j);
1482 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar remainder pass: handles the trailing K - kk columns that the SIMD loop left over.
// It uses element access A2(i,k) / B2(k,j) instead of SIMD loads and applies the same
// row-block skip tests and column bounds as the vectorized pass.
// NOTE(review): the statements that refill A2 and B2 for this pass are not visible here.
1492 if( remainder && kk < K )
1494 const size_t ksize( K - kk );
1496 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1497 const size_t isize ( M - ibegin );
1502 size_t jblock( 0UL );
1506 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Upper triangular B: this condition identifies column tiles entirely left of column kk.
1508 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
1517 if( IsFloatingPoint_v<ET1> )
// Scalar 5-row kernel (2 columns, then 1 column).
1519 for( ; (i+5UL) <= isize; i+=5UL )
1521 if( jj > ibegin+i+4UL )
continue;
1523 const size_t jend(
min( ibegin+i-jj+5UL, jblock ) );
1526 for( ; (j+2UL) <= jend; j+=2UL ) {
1527 for(
size_t k=0UL; k<ksize; ++k ) {
1528 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1529 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1530 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1531 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1532 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1533 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1534 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1535 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1536 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1537 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
1542 for(
size_t k=0UL; k<ksize; ++k ) {
1543 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1544 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1546 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1547 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
// Scalar 4-row kernel (2 columns, then 1 column).
1554 for( ; (i+4UL) <= isize; i+=4UL )
1556 if( jj > ibegin+i+3UL )
continue;
1558 const size_t jend(
min( ibegin+i-jj+4UL, jblock ) );
1561 for( ; (j+2UL) <= jend; j+=2UL ) {
1562 for(
size_t k=0UL; k<ksize; ++k ) {
1563 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1564 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1565 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1566 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1567 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1568 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1569 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1570 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1575 for(
size_t k=0UL; k<ksize; ++k ) {
1576 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1577 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1578 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1579 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
// Scalar 2-row kernel (2 columns, then 1 column).
1585 for( ; (i+2UL) <= isize; i+=2UL )
1587 if( jj > ibegin+i+1UL )
continue;
1589 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1592 for( ; (j+2UL) <= jend; j+=2UL ) {
1593 for(
size_t k=0UL; k<ksize; ++k ) {
1594 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1595 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1596 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1597 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1602 for(
size_t k=0UL; k<ksize; ++k ) {
1603 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1604 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Scalar leftover single row; same +2UL column bound as in the vectorized pass.
1609 if( i<isize && jj <= ibegin+i )
1611 const size_t jend(
min( ibegin+i-jj+2UL, jblock ) );
1614 for( ; (j+2UL) <= jend; j+=2UL ) {
1615 for(
size_t k=0UL; k<ksize; ++k ) {
1616 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1617 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1622 for(
size_t k=0UL; k<ksize; ++k ) {
1623 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel `lmmm` for a target passed as DenseMatrix<MT1,true>.
// The shared dimension K is walked in panels. For every panel a slice of B is copied
// into B2 and, per row block, a slice of A into A2; dot products of A2 rows and B2
// columns are then scaled by alpha and added to c.
// In every column group the local row index i starts at max(0, jbegin+j-ii), i.e. at
// the row whose global index equals the global index of the group's first column.
//
// C      target matrix, updated through derestrict( *C )
// A, B   operands; M = A.rows(), K = A.columns(), N = B.columns()
// alpha  factor applied to every accumulated product
// beta   tested with isOne() below
1655template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
1656void lmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
1658 using ET1 = ElementType_t<MT1>;
1659 using ET2 = ElementType_t<MT2>;
1660 using ET3 = ElementType_t<MT3>;
1661 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one operand type is unpadded. It selects unaligned panel copies,
// shortens the SIMD loop over K and enables the scalar tail pass at the end.
1683 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel extents: KBLOCK along the shared dimension, IBLOCK along the rows of A.
1685 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
1686 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1691 const size_t M( A.rows() );
1692 const size_t N( B.columns() );
1693 const size_t K( A.columns() );
// Scratch buffers for the current A block (IBLOCK x KBLOCK) and B panel (KBLOCK x N).
1697 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1698 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1700 decltype(
auto) c( derestrict( *C ) );
// NOTE(review): the leading `if` and the bodies of this chain are not visible here;
// presumably C is reset or scaled by beta before accumulation — confirm.
1705 else if( !
isOne( beta ) ) {
1710 size_t kblock( 0UL );
// Walk the shared dimension panel by panel. With `remainder` set the loop ends as
// soon as fewer than SIMDSIZE entries are left.
1712 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel depth: KBLOCK if a full panel fits. Otherwise the leftover, rounded down to a
// multiple of SIMDSIZE (first assignment) or taken whole (second assignment).
// NOTE(review): the condition choosing between the two is not visible; presumably `remainder`.
1715 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
1718 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B for this panel: begins at kk for an upper B, ends at kk+kblock
// for a lower B, and spans all N columns otherwise.
1721 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
1722 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
1723 const size_t jsize ( jend - jbegin );
// Serial copy of the B panel; aligned access only when both operands are padded.
1725 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
1728 size_t iblock( 0UL );
// Row block height: IBLOCK, or whatever is left of M.
1732 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// A blocks lying completely above the diagonal of a lower A or completely below the
// diagonal of an upper A (presumably skipped — the guarded body is not visible).
1734 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
1735 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
1740 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// SIMD kernels guarded by ET3 being a floating point type.
1744 if( IsFloatingPoint_v<ET3> )
// Column groups of five.
1746 for( ; (j+5UL) <= jsize; j+=5UL )
// Row block ends before row jbegin: nothing to update for this group.
1748 if( ii+iblock < jbegin )
continue;
// First local row: the one matching global column jbegin+j, or 0 if the block starts below it.
1750 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
// Two rows x five columns, ten SIMD accumulators.
1752 for( ; (i+2UL) <= iblock; i+=2UL )
1754 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1756 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1758 const SIMDType a1( A2.load(i ,k) );
1759 const SIMDType a2( A2.load(i+1UL,k) );
1761 const SIMDType b1( B2.load(k,j ) );
1762 const SIMDType b2( B2.load(k,j+1UL) );
1763 const SIMDType b3( B2.load(k,j+2UL) );
1764 const SIMDType b4( B2.load(k,j+3UL) );
1765 const SIMDType b5( B2.load(k,j+4UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target.
1779 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1780 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1781 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1782 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1783 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
1784 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
1787 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
1788 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// One-row variant for the same five columns.
1793 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1795 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1797 const SIMDType a1( A2.load(i,k) );
1799 xmm1 += a1 * B2.load(k,j );
1800 xmm2 += a1 * B2.load(k,j+1UL);
1801 xmm3 += a1 * B2.load(k,j+2UL);
1802 xmm4 += a1 * B2.load(k,j+3UL);
1803 xmm5 += a1 * B2.load(k,j+4UL);
1806 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1807 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1808 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1809 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1810 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Column groups of four.
1816 for( ; (j+4UL) <= jsize; j+=4UL )
1818 if( ii+iblock < jbegin )
continue;
1820 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
// Two rows x four columns, eight SIMD accumulators.
1822 for( ; (i+2UL) <= iblock; i+=2UL )
1824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1826 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1828 const SIMDType a1( A2.load(i ,k) );
1829 const SIMDType a2( A2.load(i+1UL,k) );
1831 const SIMDType b1( B2.load(k,j ) );
1832 const SIMDType b2( B2.load(k,j+1UL) );
1833 const SIMDType b3( B2.load(k,j+2UL) );
1834 const SIMDType b4( B2.load(k,j+3UL) );
1846 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1847 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1848 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1849 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
1850 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1852 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
1853 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// One-row variant for the same four columns.
1858 SIMDType xmm1, xmm2, xmm3, xmm4;
1860 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1862 const SIMDType a1( A2.load(i,k) );
1864 xmm1 += a1 * B2.load(k,j );
1865 xmm2 += a1 * B2.load(k,j+1UL);
1866 xmm3 += a1 * B2.load(k,j+2UL);
1867 xmm4 += a1 * B2.load(k,j+3UL);
1870 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1871 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1872 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
1873 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Column groups of two: rows are taken four at a time, then two, then one.
1878 for( ; (j+2UL) <= jsize; j+=2UL )
1880 if( ii+iblock < jbegin )
continue;
1882 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1884 for( ; (i+4UL) <= iblock; i+=4UL )
1886 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1888 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1890 const SIMDType a1( A2.load(i ,k) );
1891 const SIMDType a2( A2.load(i+1UL,k) );
1892 const SIMDType a3( A2.load(i+2UL,k) );
1893 const SIMDType a4( A2.load(i+3UL,k) );
1895 const SIMDType b1( B2.load(k,j ) );
1896 const SIMDType b2( B2.load(k,j+1UL) );
1908 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1909 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1910 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1911 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
1912 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
1913 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
1914 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
1915 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns.
1918 for( ; (i+2UL) <= iblock; i+=2UL )
1920 SIMDType xmm1, xmm2, xmm3, xmm4;
1922 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1924 const SIMDType a1( A2.load(i ,k) );
1925 const SIMDType a2( A2.load(i+1UL,k) );
1927 const SIMDType b1( B2.load(k,j ) );
1928 const SIMDType b2( B2.load(k,j+1UL) );
1936 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
1937 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
1938 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
1939 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// One row x two columns.
1944 SIMDType xmm1, xmm2;
1946 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1948 const SIMDType a1( A2.load(i,k) );
1950 xmm1 += a1 * B2.load(k,j );
1951 xmm2 += a1 * B2.load(k,j+1UL);
1954 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
1955 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single remaining column, under the same row-block test as the groups above.
1959 if( j<jsize && ii+iblock >= jbegin )
1961 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1963 for( ; (i+2UL) <= iblock; i+=2UL )
1965 SIMDType xmm1, xmm2;
1967 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1969 const SIMDType b1( B2.load(k,j) );
1971 xmm1 += A2.load(i ,k) * b1;
1972 xmm2 += A2.load(i+1UL,k) * b1;
1975 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
1976 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// One row x one column.
1983 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
1985 xmm1 += A2.load(i,k) * B2.load(k,j);
1988 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: entered only if an operand is unpadded and the SIMD loop stopped short
// of K. It covers the last ksize = K - kk entries of the shared dimension element-wise.
1998 if( remainder && kk < K )
2000 const size_t ksize( K - kk );
2002 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2003 const size_t jsize ( N - jbegin );
2008 size_t iblock( 0UL );
2012 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row block entirely above the diagonal of a lower A.
2014 if( IsLower_v<MT2> && ii+iblock <= kk ) {
// NOTE(review): this guard tests ET1 whereas the SIMD section above tests ET3.
2023 if( IsFloatingPoint_v<ET1> )
// Same 5/4/2/1 column grouping as above, with plain element access instead of SIMD loads.
2025 for( ; (j+5UL) <= jsize; j+=5UL )
2027 if( ii+iblock < jbegin )
continue;
2029 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2031 for( ; (i+2UL) <= iblock; i+=2UL ) {
2032 for(
size_t k=0UL; k<ksize; ++k ) {
2033 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2034 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2035 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2036 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2037 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2038 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2039 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2041 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2042 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2047 for(
size_t k=0UL; k<ksize; ++k ) {
2048 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2049 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2050 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2051 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2052 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
2059 for( ; (j+4UL) <= jsize; j+=4UL )
2061 if( ii+iblock < jbegin )
continue;
2063 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2065 for( ; (i+2UL) <= iblock; i+=2UL ) {
2066 for(
size_t k=0UL; k<ksize; ++k ) {
2067 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2068 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2069 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2070 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2071 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2072 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2073 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2074 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2079 for(
size_t k=0UL; k<ksize; ++k ) {
2080 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2081 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2082 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2083 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2089 for( ; (j+2UL) <= jsize; j+=2UL )
2091 if( ii+iblock < jbegin )
continue;
2093 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2095 for( ; (i+2UL) <= iblock; i+=2UL ) {
2096 for(
size_t k=0UL; k<ksize; ++k ) {
2097 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2098 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2099 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2100 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2105 for(
size_t k=0UL; k<ksize; ++k ) {
2106 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2107 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single-column case of the scalar tail.
2114 if( ii+iblock < jbegin )
continue;
2116 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2118 for( ; (i+2UL) <= iblock; i+=2UL ) {
2119 for(
size_t k=0UL; k<ksize; ++k ) {
2120 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2121 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2126 for(
size_t k=0UL; k<ksize; ++k ) {
2127 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of `lmmm` without scaling factors: forwards to the
// five-argument form with alpha = ET1(1) and beta = ET1(0), where ET1 is the
// element type of the target C.
2156template<
typename MT1,
typename MT2,
typename MT3 >
2157inline void lmmm( MT1& C,
const MT2& A,
const MT3& B )
2159 using ET1 = ElementType_t<MT1>;
2160 using ET2 = ElementType_t<MT2>;
2161 using ET3 = ElementType_t<MT3>;
2166 lmmm( C, A, B, ET1(1), ET1(0) );
// Blocked kernel `ummm` for a target passed as DenseMatrix<MT1,false>.
// The shared dimension K is walked in panels. For every panel a slice of A is copied
// into A2 and, per column block, a slice of B into B2; dot products of A2 rows and B2
// columns are then scaled by alpha and added to c.
// In every row group the local column index j starts at max(0, ibegin+i-jj), i.e. at
// the column whose global index equals the global index of the group's first row.
//
// C      target matrix, updated through derestrict( *C )
// A, B   operands; M = A.rows(), K = A.columns(), N = B.columns()
// alpha  factor applied to every accumulated product
// beta   tested with isOne() below
2199template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2200void ummm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2202 using ET1 = ElementType_t<MT1>;
2203 using ET2 = ElementType_t<MT2>;
2204 using ET3 = ElementType_t<MT3>;
2205 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one operand type is unpadded. It selects unaligned panel copies,
// shortens the SIMD loop over K and enables the scalar tail pass at the end.
2227 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel extents: KBLOCK along the shared dimension, JBLOCK along the columns of B.
2229 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2230 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
2235 const size_t M( A.rows() );
2236 const size_t N( B.columns() );
2237 const size_t K( A.columns() );
// Scratch buffers for the current A panel (M x KBLOCK) and B block (KBLOCK x JBLOCK).
2241 DynamicMatrix<ET2,false> A2( M, KBLOCK );
2242 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
2244 decltype(
auto) c( derestrict( *C ) );
// NOTE(review): the leading `if` and the bodies of this chain are not visible here;
// presumably C is reset or scaled by beta before accumulation — confirm.
2249 else if( !
isOne( beta ) ) {
2254 size_t kblock( 0UL );
// Walk the shared dimension panel by panel. With `remainder` set the loop ends as
// soon as fewer than SIMDSIZE entries are left.
2256 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel depth: KBLOCK if a full panel fits. Otherwise the leftover, rounded down to a
// multiple of SIMDSIZE (first assignment) or taken whole (second assignment).
// NOTE(review): the condition choosing between the two is not visible; presumably `remainder`.
2259 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
2262 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Row range of A for this panel: begins at kk for a lower A, ends at kk+kblock for
// an upper A, and spans all M rows otherwise.
2265 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2266 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
2267 const size_t isize ( iend - ibegin );
// Serial copy of the A panel; aligned access only when both operands are padded.
2269 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock,
unchecked ) );
2272 size_t jblock( 0UL );
// Column block width: JBLOCK, or whatever is left of N.
2276 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// B blocks lying completely above the diagonal of a lower B or completely below the
// diagonal of an upper B (presumably skipped — the guarded body is not visible).
2278 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
2279 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
2284 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock,
unchecked ) );
// SIMD kernels guarded by ET1 being a floating point type.
2288 if( IsFloatingPoint_v<ET1> )
// Row groups of five.
2290 for( ; (i+5UL) <= isize; i+=5UL )
// Column block ends before column ibegin: nothing to update for this group.
2292 if( jj+jblock < ibegin )
continue;
// First local column: the one matching global row ibegin+i, or 0 if the block starts right of it.
2294 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
// Five rows x two columns, ten SIMD accumulators.
2296 for( ; (j+2UL) <= jblock; j+=2UL )
2298 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2300 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2302 const SIMDType a1( A2.load(i ,k) );
2303 const SIMDType a2( A2.load(i+1UL,k) );
2304 const SIMDType a3( A2.load(i+2UL,k) );
2305 const SIMDType a4( A2.load(i+3UL,k) );
2306 const SIMDType a5( A2.load(i+4UL,k) );
2308 const SIMDType b1( B2.load(k,j ) );
2309 const SIMDType b2( B2.load(k,j+1UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target.
2323 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2324 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2325 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2326 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2327 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2328 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2329 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2330 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
2331 c(ibegin+i+4UL,jj+j ) +=
sum( xmm9 ) * alpha;
2332 c(ibegin+i+4UL,jj+j+1UL) +=
sum( xmm10 ) * alpha;
// One-column variant for the same five rows.
2337 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2339 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2341 const SIMDType a1( A2.load(i ,k) );
2342 const SIMDType a2( A2.load(i+1UL,k) );
2343 const SIMDType a3( A2.load(i+2UL,k) );
2344 const SIMDType a4( A2.load(i+3UL,k) );
2345 const SIMDType a5( A2.load(i+4UL,k) );
2347 const SIMDType b1( B2.load(k,j) );
2356 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2357 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2358 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2359 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
2360 c(ibegin+i+4UL,jj+j) +=
sum( xmm5 ) * alpha;
// Row groups of four.
2366 for( ; (i+4UL) <= isize; i+=4UL )
2368 if( jj+jblock < ibegin )
continue;
2370 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
// Four rows x two columns, eight SIMD accumulators.
2372 for( ; (j+2UL) <= jblock; j+=2UL )
2374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2376 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2378 const SIMDType a1( A2.load(i ,k) );
2379 const SIMDType a2( A2.load(i+1UL,k) );
2380 const SIMDType a3( A2.load(i+2UL,k) );
2381 const SIMDType a4( A2.load(i+3UL,k) );
2383 const SIMDType b1( B2.load(k,j ) );
2384 const SIMDType b2( B2.load(k,j+1UL) );
2396 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2397 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2398 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2399 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
2400 c(ibegin+i+2UL,jj+j ) +=
sum( xmm5 ) * alpha;
2401 c(ibegin+i+2UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2402 c(ibegin+i+3UL,jj+j ) +=
sum( xmm7 ) * alpha;
2403 c(ibegin+i+3UL,jj+j+1UL) +=
sum( xmm8 ) * alpha;
// One-column variant for the same four rows.
2408 SIMDType xmm1, xmm2, xmm3, xmm4;
2410 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2412 const SIMDType a1( A2.load(i ,k) );
2413 const SIMDType a2( A2.load(i+1UL,k) );
2414 const SIMDType a3( A2.load(i+2UL,k) );
2415 const SIMDType a4( A2.load(i+3UL,k) );
2417 const SIMDType b1( B2.load(k,j) );
2425 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2426 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
2427 c(ibegin+i+2UL,jj+j) +=
sum( xmm3 ) * alpha;
2428 c(ibegin+i+3UL,jj+j) +=
sum( xmm4 ) * alpha;
// Row groups of two: columns are taken four at a time, then two, then one.
2433 for( ; (i+2UL) <= isize; i+=2UL )
2435 if( jj+jblock < ibegin )
continue;
2437 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2439 for( ; (j+4UL) <= jblock; j+=4UL )
2441 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2443 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2445 const SIMDType a1( A2.load(i ,k) );
2446 const SIMDType a2( A2.load(i+1UL,k) );
2448 const SIMDType b1( B2.load(k,j ) );
2449 const SIMDType b2( B2.load(k,j+1UL) );
2450 const SIMDType b3( B2.load(k,j+2UL) );
2451 const SIMDType b4( B2.load(k,j+3UL) );
2463 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2464 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2465 c(ibegin+i ,jj+j+2UL) +=
sum( xmm3 ) * alpha;
2466 c(ibegin+i ,jj+j+3UL) +=
sum( xmm4 ) * alpha;
2467 c(ibegin+i+1UL,jj+j ) +=
sum( xmm5 ) * alpha;
2468 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm6 ) * alpha;
2469 c(ibegin+i+1UL,jj+j+2UL) +=
sum( xmm7 ) * alpha;
2470 c(ibegin+i+1UL,jj+j+3UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns.
2473 for( ; (j+2UL) <= jblock; j+=2UL )
2475 SIMDType xmm1, xmm2, xmm3, xmm4;
2477 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2479 const SIMDType a1( A2.load(i ,k) );
2480 const SIMDType a2( A2.load(i+1UL,k) );
2482 const SIMDType b1( B2.load(k,j ) );
2483 const SIMDType b2( B2.load(k,j+1UL) );
2491 c(ibegin+i ,jj+j ) +=
sum( xmm1 ) * alpha;
2492 c(ibegin+i ,jj+j+1UL) +=
sum( xmm2 ) * alpha;
2493 c(ibegin+i+1UL,jj+j ) +=
sum( xmm3 ) * alpha;
2494 c(ibegin+i+1UL,jj+j+1UL) +=
sum( xmm4 ) * alpha;
// Two rows x one column.
2499 SIMDType xmm1, xmm2;
2501 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2503 const SIMDType a1( A2.load(i ,k) );
2504 const SIMDType a2( A2.load(i+1UL,k) );
2506 const SIMDType b1( B2.load(k,j) );
2512 c(ibegin+i ,jj+j) +=
sum( xmm1 ) * alpha;
2513 c(ibegin+i+1UL,jj+j) +=
sum( xmm2 ) * alpha;
// Single remaining row, under the same column-block test as the groups above.
2517 if( i<isize && jj+jblock >= ibegin )
2519 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2521 for( ; (j+2UL) <= jblock; j+=2UL )
2523 SIMDType xmm1, xmm2;
2525 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2527 const SIMDType a1( A2.load(i,k) );
2529 xmm1 += a1 * B2.load(k,j );
2530 xmm2 += a1 * B2.load(k,j+1UL);
2533 c(ibegin+i,jj+j ) +=
sum( xmm1 ) * alpha;
2534 c(ibegin+i,jj+j+1UL) +=
sum( xmm2 ) * alpha;
// One row x one column.
2541 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2543 const SIMDType a1( A2.load(i,k) );
2545 xmm1 += a1 * B2.load(k,j);
2548 c(ibegin+i,jj+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: entered only if an operand is unpadded and the SIMD loop stopped short
// of K. It covers the last ksize = K - kk entries of the shared dimension element-wise.
2558 if( remainder && kk < K )
2560 const size_t ksize( K - kk );
2562 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2563 const size_t isize ( M - ibegin );
2568 size_t jblock( 0UL );
2572 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
// Column block entirely left of the diagonal of an upper B.
2574 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
2583 if( IsFloatingPoint_v<ET1> )
// Same 5/4/2/1 row grouping as above, with plain element access instead of SIMD loads.
2585 for( ; (i+5UL) <= isize; i+=5UL )
2587 if( jj+jblock < ibegin )
continue;
2589 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2591 for( ; (j+2UL) <= jblock; j+=2UL ) {
2592 for(
size_t k=0UL; k<ksize; ++k ) {
2593 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2594 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2595 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2596 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2597 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2598 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2599 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2600 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2601 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
2602 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
2607 for(
size_t k=0UL; k<ksize; ++k ) {
2608 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2609 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2610 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2611 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2612 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
2619 for( ; (i+4UL) <= isize; i+=4UL )
2621 if( jj+jblock < ibegin )
continue;
2623 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2625 for( ; (j+2UL) <= jblock; j+=2UL ) {
2626 for(
size_t k=0UL; k<ksize; ++k ) {
2627 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2628 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2629 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2630 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2631 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2632 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2633 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2634 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2639 for(
size_t k=0UL; k<ksize; ++k ) {
2640 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2641 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2642 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2643 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2649 for( ; (i+2UL) <= isize; i+=2UL )
2651 if( jj+jblock < ibegin )
continue;
2653 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2655 for( ; (j+2UL) <= jblock; j+=2UL ) {
2656 for(
size_t k=0UL; k<ksize; ++k ) {
2657 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2658 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2659 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2660 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2665 for(
size_t k=0UL; k<ksize; ++k ) {
2666 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2667 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
// Single remaining row of the scalar tail.
2672 if( i<isize && jj+jblock >= ibegin )
2674 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2676 for( ; (j+2UL) <= jblock; j+=2UL ) {
2677 for(
size_t k=0UL; k<ksize; ++k ) {
2678 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
2679 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2684 for(
size_t k=0UL; k<ksize; ++k ) {
2685 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
// Blocked kernel `ummm` for a target passed as DenseMatrix<MT1,true>.
// The shared dimension K is walked in panels. For every panel a slice of B is copied
// into B2 and, per row block, a slice of A into A2; dot products of A2 rows and B2
// columns are then scaled by alpha and added to c.
// For every column group the local row range is capped by
// iend = min( iblock, jbegin+j-ii+<limit> ), so rows beyond that limit are not visited.
//
// C      target matrix, updated through derestrict( *C )
// A, B   operands; M = A.rows(), K = A.columns(), N = B.columns()
// alpha  factor applied to every accumulated product
// beta   tested with isOne() below
2717template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
2718void ummm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha, ST beta )
2720 using ET1 = ElementType_t<MT1>;
2721 using ET2 = ElementType_t<MT2>;
2722 using ET3 = ElementType_t<MT3>;
2723 using SIMDType = SIMDTrait_t<ET1>;
// True if at least one operand type is unpadded. It selects unaligned panel copies,
// shortens the SIMD loop over K and enables the scalar tail pass at the end.
2745 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
// Panel extents: KBLOCK along the shared dimension, IBLOCK along the rows of A.
2747 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/
sizeof(ET1) ) );
2748 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2753 const size_t M( A.rows() );
2754 const size_t N( B.columns() );
2755 const size_t K( A.columns() );
// Scratch buffers for the current A block (IBLOCK x KBLOCK) and B panel (KBLOCK x N).
2759 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2760 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2762 decltype(
auto) c( derestrict( *C ) );
// NOTE(review): the leading `if` and the bodies of this chain are not visible here;
// presumably C is reset or scaled by beta before accumulation — confirm.
2767 else if( !
isOne( beta ) ) {
2772 size_t kblock( 0UL );
// Walk the shared dimension panel by panel. With `remainder` set the loop ends as
// soon as fewer than SIMDSIZE entries are left.
2774 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
// Panel depth: KBLOCK if a full panel fits. Otherwise the leftover, rounded down to a
// multiple of SIMDSIZE (first assignment) or taken whole (second assignment).
// NOTE(review): the condition choosing between the two is not visible; presumably `remainder`.
2777 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):(
prevMultiple( K - kk, SIMDSIZE ) ) );
2780 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
// Column range of B for this panel: begins at kk for an upper B, ends at kk+kblock
// for a lower B, and spans all N columns otherwise.
2783 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2784 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
2785 const size_t jsize ( jend - jbegin );
// Serial copy of the B panel; aligned access only when both operands are padded.
2787 B2 =
serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize,
unchecked ) );
2790 size_t iblock( 0UL );
// Row block height: IBLOCK, or whatever is left of M.
2794 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// A blocks lying completely above the diagonal of a lower A or completely below the
// diagonal of an upper A (presumably skipped — the guarded body is not visible).
2796 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
2797 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
2802 A2 =
serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock,
unchecked ) );
// SIMD kernels guarded by ET3 being a floating point type.
2806 if( IsFloatingPoint_v<ET3> )
// Column groups of five.
2808 for( ; (j+5UL) <= jsize; j+=5UL )
// Row block starts below the last column of the group: nothing to update.
2810 if( ii > jbegin+j+4UL )
continue;
// Rows stop at the block end or at global row jbegin+j+5, whichever comes first.
2812 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
// Two rows x five columns, ten SIMD accumulators.
2815 for( ; (i+2UL) <= iend; i+=2UL )
2817 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2819 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2821 const SIMDType a1( A2.load(i ,k) );
2822 const SIMDType a2( A2.load(i+1UL,k) );
2824 const SIMDType b1( B2.load(k,j ) );
2825 const SIMDType b2( B2.load(k,j+1UL) );
2826 const SIMDType b3( B2.load(k,j+2UL) );
2827 const SIMDType b4( B2.load(k,j+3UL) );
2828 const SIMDType b5( B2.load(k,j+4UL) );
// Horizontal reduction of each accumulator, scaled by alpha and added to the target.
2842 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2843 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2844 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2845 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2846 c(ii+i ,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
2847 c(ii+i+1UL,jbegin+j ) +=
sum( xmm6 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm7 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm8 ) * alpha;
2850 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm9 ) * alpha;
2851 c(ii+i+1UL,jbegin+j+4UL) +=
sum( xmm10 ) * alpha;
// One-row variant for the same five columns.
2856 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2858 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2860 const SIMDType a1( A2.load(i,k) );
2862 xmm1 += a1 * B2.load(k,j );
2863 xmm2 += a1 * B2.load(k,j+1UL);
2864 xmm3 += a1 * B2.load(k,j+2UL);
2865 xmm4 += a1 * B2.load(k,j+3UL);
2866 xmm5 += a1 * B2.load(k,j+4UL);
2869 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2870 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2871 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2872 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2873 c(ii+i,jbegin+j+4UL) +=
sum( xmm5 ) * alpha;
// Column groups of four.
2879 for( ; (j+4UL) <= jsize; j+=4UL )
2881 if( ii > jbegin+j+3UL )
continue;
2883 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
// Two rows x four columns, eight SIMD accumulators.
2886 for( ; (i+2UL) <= iend; i+=2UL )
2888 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2890 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2892 const SIMDType a1( A2.load(i ,k) );
2893 const SIMDType a2( A2.load(i+1UL,k) );
2895 const SIMDType b1( B2.load(k,j ) );
2896 const SIMDType b2( B2.load(k,j+1UL) );
2897 const SIMDType b3( B2.load(k,j+2UL) );
2898 const SIMDType b4( B2.load(k,j+3UL) );
2910 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2911 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2912 c(ii+i ,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2913 c(ii+i ,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
2914 c(ii+i+1UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2916 c(ii+i+1UL,jbegin+j+2UL) +=
sum( xmm7 ) * alpha;
2917 c(ii+i+1UL,jbegin+j+3UL) +=
sum( xmm8 ) * alpha;
// One-row variant for the same four columns.
2922 SIMDType xmm1, xmm2, xmm3, xmm4;
2924 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2926 const SIMDType a1( A2.load(i,k) );
2928 xmm1 += a1 * B2.load(k,j );
2929 xmm2 += a1 * B2.load(k,j+1UL);
2930 xmm3 += a1 * B2.load(k,j+2UL);
2931 xmm4 += a1 * B2.load(k,j+3UL);
2934 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
2935 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2936 c(ii+i,jbegin+j+2UL) +=
sum( xmm3 ) * alpha;
2937 c(ii+i,jbegin+j+3UL) +=
sum( xmm4 ) * alpha;
// Column groups of two: rows are taken four at a time, then two, then one.
2942 for( ; (j+2UL) <= jsize; j+=2UL )
2944 if( ii > jbegin+j+1UL )
continue;
2946 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
2949 for( ; (i+4UL) <= iend; i+=4UL )
2951 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2953 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2955 const SIMDType a1( A2.load(i ,k) );
2956 const SIMDType a2( A2.load(i+1UL,k) );
2957 const SIMDType a3( A2.load(i+2UL,k) );
2958 const SIMDType a4( A2.load(i+3UL,k) );
2960 const SIMDType b1( B2.load(k,j ) );
2961 const SIMDType b2( B2.load(k,j+1UL) );
2973 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
2974 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
2975 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
2976 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
2977 c(ii+i+2UL,jbegin+j ) +=
sum( xmm5 ) * alpha;
2978 c(ii+i+2UL,jbegin+j+1UL) +=
sum( xmm6 ) * alpha;
2979 c(ii+i+3UL,jbegin+j ) +=
sum( xmm7 ) * alpha;
2980 c(ii+i+3UL,jbegin+j+1UL) +=
sum( xmm8 ) * alpha;
// Two rows x two columns.
2983 for( ; (i+2UL) <= iend; i+=2UL )
2985 SIMDType xmm1, xmm2, xmm3, xmm4;
2987 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
2989 const SIMDType a1( A2.load(i ,k) );
2990 const SIMDType a2( A2.load(i+1UL,k) );
2992 const SIMDType b1( B2.load(k,j ) );
2993 const SIMDType b2( B2.load(k,j+1UL) );
3001 c(ii+i ,jbegin+j ) +=
sum( xmm1 ) * alpha;
3002 c(ii+i ,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
3003 c(ii+i+1UL,jbegin+j ) +=
sum( xmm3 ) * alpha;
3004 c(ii+i+1UL,jbegin+j+1UL) +=
sum( xmm4 ) * alpha;
// One row x two columns.
3009 SIMDType xmm1, xmm2;
3011 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3013 const SIMDType a1( A2.load(i,k) );
3015 xmm1 += a1 * B2.load(k,j );
3016 xmm2 += a1 * B2.load(k,j+1UL);
3019 c(ii+i,jbegin+j ) +=
sum( xmm1 ) * alpha;
3020 c(ii+i,jbegin+j+1UL) +=
sum( xmm2 ) * alpha;
// Single remaining column.
3024 if( j<jsize && ii <= jbegin+j )
// NOTE(review): the limit is +2UL although only one column is handled here; the 5-, 4-
// and 2-column groups use their own width. This admits global row jbegin+j+1 when it
// lies inside the block — confirm this is intended.
3026 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3029 for( ; (i+2UL) <= iend; i+=2UL )
3031 SIMDType xmm1, xmm2;
3033 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3035 const SIMDType b1( B2.load(k,j) );
3037 xmm1 += A2.load(i ,k) * b1;
3038 xmm2 += A2.load(i+1UL,k) * b1;
3041 c(ii+i ,jbegin+j) +=
sum( xmm1 ) * alpha;
3042 c(ii+i+1UL,jbegin+j) +=
sum( xmm2 ) * alpha;
// One row x one column.
3049 for(
size_t k=0UL; k<kblock; k+=SIMDSIZE )
3051 xmm1 += A2.load(i,k) * B2.load(k,j);
3054 c(ii+i,jbegin+j) +=
sum( xmm1 ) * alpha;
// Scalar tail: entered only if an operand is unpadded and the SIMD loop stopped short
// of K. It covers the last ksize = K - kk entries of the shared dimension element-wise.
3064 if( remainder && kk < K )
3066 const size_t ksize( K - kk );
3068 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
3069 const size_t jsize ( N - jbegin );
3074 size_t iblock( 0UL );
3078 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
// Row block entirely above the diagonal of a lower A.
3080 if( IsLower_v<MT2> && ii+iblock <= kk ) {
// NOTE(review): this guard tests ET1 whereas the SIMD section above tests ET3.
3089 if( IsFloatingPoint_v<ET1> )
// Same 5/4/2/1 column grouping as above, with plain element access instead of SIMD loads.
3091 for( ; (j+5UL) <= jsize; j+=5UL )
3093 if( ii > jbegin+j+4UL )
continue;
3095 const size_t iend(
min( iblock, jbegin+j-ii+5UL ) );
3098 for( ; (i+2UL) <= iend; i+=2UL ) {
3099 for(
size_t k=0UL; k<ksize; ++k ) {
3100 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3101 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3102 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3103 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3104 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3105 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3106 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3108 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3109 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3114 for(
size_t k=0UL; k<ksize; ++k ) {
3115 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3116 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3117 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3118 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3119 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
3126 for( ; (j+4UL) <= jsize; j+=4UL )
3128 if( ii > jbegin+j+3UL )
continue;
3130 const size_t iend(
min( iblock, jbegin+j-ii+4UL ) );
3133 for( ; (i+2UL) <= iend; i+=2UL ) {
3134 for(
size_t k=0UL; k<ksize; ++k ) {
3135 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3136 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3137 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3138 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3139 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3140 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3141 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3142 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3147 for(
size_t k=0UL; k<ksize; ++k ) {
3148 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3149 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3150 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3151 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3157 for( ; (j+2UL) <= jsize; j+=2UL )
3159 if( ii > jbegin+j+1UL )
continue;
3161 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3164 for( ; (i+2UL) <= iend; i+=2UL ) {
3165 for(
size_t k=0UL; k<ksize; ++k ) {
3166 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3167 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3168 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3169 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3174 for(
size_t k=0UL; k<ksize; ++k ) {
3175 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3176 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
// Single remaining column of the scalar tail.
3181 if( j<jsize && ii <= jbegin+j )
// NOTE(review): same +2UL limit for a single column as in the SIMD section — confirm.
3183 const size_t iend(
min( iblock, jbegin+j-ii+2UL ) );
3186 for( ; (i+2UL) <= iend; i+=2UL ) {
3187 for(
size_t k=0UL; k<ksize; ++k ) {
3188 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3189 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3194 for(
size_t k=0UL; k<ksize; ++k ) {
3195 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
// Convenience overload of `ummm` without scaling factors: forwards to the
// five-argument form with alpha = ET1(1) and beta = ET1(0), where ET1 is the
// element type of the target C.
3224template<
typename MT1,
typename MT2,
typename MT3 >
3225inline void ummm( MT1& C,
const MT2& A,
const MT3& B )
3227 using ET1 = ElementType_t<MT1>;
3228 using ET2 = ElementType_t<MT2>;
3229 using ET3 = ElementType_t<MT3>;
3234 ummm( C, A, B, ET1(1), ET1(0) );
// Scaled matrix product whose result is completed by mirroring: overload for targets
// of type DenseMatrix<MT1,false> (row-major, per Blaze's storage-order convention).
//
// C      target matrix, accessed element-wise through (*C)(i,j)
// A, B   left and right operand of the product
// alpha  scaling factor handed on to lmmm()
//
// The multiplication itself is delegated to lmmm(); afterwards every element above
// the diagonal is overwritten with the element at the transposed position, so the
// upper triangle becomes a copy of the lower one.
// NOTE(review): this relies on lmmm() having produced valid values on and below the
// diagonal; lmmm() is defined outside this view - confirm.
3266template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3267void smmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element types of the operands; not referenced by the statements visible in this block.
3269 using ET1 = ElementType_t<MT1>;
3270 using ET2 = ElementType_t<MT2>;
3271 using ET3 = ElementType_t<MT3>;
// Result dimensions: M rows (from A) and N columns (from B).
3287 const size_t M( A.rows() );
3288 const size_t N( B.columns() );
// Compute the product via lmmm(), passing alpha and ST(0) as the trailing scalars.
3292 lmmm( C, A, B, alpha, ST(0) );
// Process the rows in strips of BLOCK_SIZE rows.
3294 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
// One past the last row of the current strip, clamped to M.
3296 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal block: for each row i, fill the columns right of the diagonal (j > i)
// from the transposed position (j,i).
3298 for(
size_t i=ii; i<iend; ++i ) {
3299 for(
size_t j=i+1UL; j<iend; ++j ) {
3300 (*C)(i,j) = (*C)(j,i);
// Column blocks to the right of the diagonal block: here j >= ii+BLOCK_SIZE > i,
// so every element lies above the diagonal and is copied from (j,i).
3304 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
// One past the last column of the current column block, clamped to N.
3305 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3306 for(
size_t i=ii; i<iend; ++i ) {
3307 for(
size_t j=jj; j<jend; ++j ) {
3308 (*C)(i,j) = (*C)(j,i);
// Scaled matrix product whose result is completed by mirroring: overload for targets
// of type DenseMatrix<MT1,true> (column-major, per Blaze's storage-order convention).
//
// C      target matrix, accessed element-wise through (*C)(i,j)
// A, B   left and right operand of the product
// alpha  scaling factor handed on to ummm()
//
// The multiplication itself is delegated to ummm(); afterwards every element below
// the diagonal is overwritten with the element at the transposed position, so the
// lower triangle becomes a copy of the upper one.
// NOTE(review): this relies on ummm() having produced valid values on and above the
// diagonal; the scaled ummm() signature is outside this view - confirm.
3336template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3337void smmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element types of the operands; not referenced by the statements visible in this block.
3339 using ET1 = ElementType_t<MT1>;
3340 using ET2 = ElementType_t<MT2>;
3341 using ET3 = ElementType_t<MT3>;
// Result dimensions: M rows (from A) and N columns (from B).
3357 const size_t M( A.rows() );
3358 const size_t N( B.columns() );
// Compute the product via ummm(), passing alpha and ST(0) as the trailing scalars.
3362 ummm( C, A, B, alpha, ST(0) );
// Process the columns in strips of BLOCK_SIZE columns.
3364 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
// One past the last column of the current strip, clamped to N.
3366 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
// Diagonal block: for each column j, fill only the rows strictly below the diagonal
// (i > j) from the transposed position (j,i). Starting at j+1 rather than at the
// strip start avoids re-assigning the diagonal and the upper entries, which already
// hold their final values, and matches the row-major overload's j=i+1 loop.
3368 for(
size_t j=jj; j<jend; ++j ) {
3369 for(
size_t i=j+1UL; i<jend; ++i ) {
3370 (*C)(i,j) = (*C)(j,i);
// Row blocks below the diagonal block: here i >= jj+BLOCK_SIZE > j, so every element
// lies below the diagonal and is copied from (j,i).
3374 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
// One past the last row of the current row block, clamped to M.
3375 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3376 for(
size_t j=jj; j<jend; ++j ) {
3377 for(
size_t i=ii; i<iend; ++i ) {
3378 (*C)(i,j) = (*C)(j,i);
// Convenience overload of smmm() that takes no explicit scaling factor.
//
// Forwards C, A and B unchanged to the four-argument smmm() and passes ET1(1) - the
// value 1 converted to the element type of MT1 - as the scaling factor alpha.
3404template<
typename MT1,
typename MT2,
typename MT3 >
3405inline void smmm( MT1& C,
const MT2& A,
const MT3& B )
// Element types of the three operands. ET2 and ET3 are not referenced by the
// statements visible in this block.
3407 using ET1 = ElementType_t<MT1>;
3408 using ET2 = ElementType_t<MT2>;
3409 using ET3 = ElementType_t<MT3>;
// Delegate to the scaled overload with alpha = 1.
3414 smmm( C, A, B, ET1(1) );
// Scaled matrix product whose result is completed by conjugate mirroring: overload for
// targets of type DenseMatrix<MT1,false> (row-major, per Blaze's storage-order
// convention).
//
// C      target matrix, accessed element-wise through (*C)(i,j)
// A, B   left and right operand of the product
// alpha  scaling factor handed on to lmmm()
//
// The multiplication itself is delegated to lmmm(); afterwards every element above
// the diagonal is overwritten with conj() of the element at the transposed position.
// The diagonal itself is never written by the loops below.
// NOTE(review): this relies on lmmm() having produced valid values on and below the
// diagonal; lmmm() is defined outside this view - confirm.
3446template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3447void hmmm( DenseMatrix<MT1,false>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element types of the operands; not referenced by the statements visible in this block.
3449 using ET1 = ElementType_t<MT1>;
3450 using ET2 = ElementType_t<MT2>;
3451 using ET3 = ElementType_t<MT3>;
// Result dimensions: M rows (from A) and N columns (from B).
3467 const size_t M( A.rows() );
3468 const size_t N( B.columns() );
// Compute the product via lmmm(), passing alpha and ST(0) as the trailing scalars.
3472 lmmm( C, A, B, alpha, ST(0) );
// Process the rows in strips of BLOCK_SIZE rows.
3474 for(
size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
// One past the last row of the current strip, clamped to M.
3476 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
// Diagonal block: for each row i, fill the columns right of the diagonal (j > i)
// with the conjugate of the transposed element (j,i).
3478 for(
size_t i=ii; i<iend; ++i ) {
3479 for(
size_t j=i+1UL; j<iend; ++j ) {
3480 (*C)(i,j) =
conj( (*C)(j,i) );
// Column blocks to the right of the diagonal block: here j >= ii+BLOCK_SIZE > i,
// so every element lies above the diagonal and receives conj( (j,i) ).
3484 for(
size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
// One past the last column of the current column block, clamped to N.
3485 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
3486 for(
size_t i=ii; i<iend; ++i ) {
3487 for(
size_t j=jj; j<jend; ++j ) {
3488 (*C)(i,j) =
conj( (*C)(j,i) );
// Scaled matrix product whose result is completed by conjugate mirroring: overload for
// targets of type DenseMatrix<MT1,true> (column-major, per Blaze's storage-order
// convention).
//
// C      target matrix, accessed element-wise through (*C)(i,j)
// A, B   left and right operand of the product
// alpha  scaling factor handed on to ummm()
//
// The multiplication itself is delegated to ummm(); afterwards every element below
// the diagonal is overwritten with conj() of the element at the transposed position.
// The diagonal itself is never written by the loops below.
// NOTE(review): this relies on ummm() having produced valid values on and above the
// diagonal; the scaled ummm() signature is outside this view - confirm.
3516template<
typename MT1,
typename MT2,
typename MT3,
typename ST >
3517void hmmm( DenseMatrix<MT1,true>& C,
const MT2& A,
const MT3& B, ST alpha )
// Element types of the operands; not referenced by the statements visible in this block.
3519 using ET1 = ElementType_t<MT1>;
3520 using ET2 = ElementType_t<MT2>;
3521 using ET3 = ElementType_t<MT3>;
// Result dimensions: M rows (from A) and N columns (from B).
3537 const size_t M( A.rows() );
3538 const size_t N( B.columns() );
// Compute the product via ummm(), passing alpha and ST(0) as the trailing scalars.
3542 ummm( C, A, B, alpha, ST(0) );
// Process the columns in strips of BLOCK_SIZE columns.
3544 for(
size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
// One past the last column of the current strip, clamped to N.
3546 const size_t jend(
min( N, jj+BLOCK_SIZE ) );
// Diagonal block: for each column j, fill only the rows strictly below the diagonal
// (i > j) with the conjugate of the transposed element (j,i). Starting at j+1 rather
// than at the strip start keeps the diagonal exactly as the multiplication produced
// it (a start at jj+1 would replace C(j,j) by its own conjugate for every column but
// the first of a strip) and skips redundant rewrites of the upper entries; it also
// matches the row-major overload's j=i+1 loop.
3548 for(
size_t j=jj; j<jend; ++j ) {
3549 for(
size_t i=j+1UL; i<jend; ++i ) {
3550 (*C)(i,j) =
conj( (*C)(j,i) );
// Row blocks below the diagonal block: here i >= jj+BLOCK_SIZE > j, so every element
// lies below the diagonal and receives conj( (j,i) ).
3554 for(
size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
// One past the last row of the current row block, clamped to M.
3555 const size_t iend(
min( M, ii+BLOCK_SIZE ) );
3556 for(
size_t j=jj; j<jend; ++j ) {
3557 for(
size_t i=ii; i<iend; ++i ) {
3558 (*C)(i,j) =
conj( (*C)(j,i) );
// Convenience overload of hmmm() that takes no explicit scaling factor.
//
// Forwards C, A and B unchanged to the four-argument hmmm() and passes ET1(1) - the
// value 1 converted to the element type of MT1 - as the scaling factor alpha.
3584template<
typename MT1,
typename MT2,
typename MT3 >
3585inline void hmmm( MT1& C,
const MT2& A,
const MT3& B )
// Element types of the three operands. ET2 and ET3 are not referenced by the
// statements visible in this block.
3587 using ET1 = ElementType_t<MT1>;
3588 using ET2 = ElementType_t<MT2>;
3589 using ET3 = ElementType_t<MT3>;
// Delegate to the scaled overload with alpha = 1.
3594 hmmm( C, A, B, ET1(1) );
Constraint on the data type.
Header file for auxiliary alias declarations.
Header file for run time assertion macros.
Header file for kernel specific block sizes.
Header file for the blaze::checked and blaze::unchecked instances.
Constraints on the storage order of matrix types.
Constraint on the data type.
Header file for the isDefault shim.
Header file for the IsFloatingPoint type trait.
Header file for the IsLower type trait.
Header file for the isOne shim.
Header file for the IsPadded type trait.
Header file for the IsUpper type trait.
Constraint on the data type.
Header file for the prevMultiple shim.
Constraints on the storage order of matrix types.
Constraint on the data type.
Header file for all SIMD functionality.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Header file for the implementation of a dynamic MxN matrix.
Header file for the DenseMatrix base class.
decltype(auto) min(const DenseMatrix< MT1, SO1 > &lhs, const DenseMatrix< MT2, SO2 > &rhs)
Computes the componentwise minimum of the dense matrices lhs and rhs.
Definition: DMatDMatMapExpr.h:1339
decltype(auto) conj(const DenseMatrix< MT, SO > &dm)
Returns a matrix containing the complex conjugate of each single element of dm.
Definition: DMatMapExpr.h:1464
decltype(auto) serial(const DenseMatrix< MT, SO > &dm)
Forces the serial evaluation of the given dense matrix expression dm.
Definition: DMatSerialExpr.h:812
decltype(auto) sum(const DenseMatrix< MT, SO > &dm)
Reduces the given dense matrix by means of addition.
Definition: DMatReduceExpr.h:2156
bool isDefault(const DiagonalMatrix< MT, SO, DF > &m)
Returns whether the given diagonal matrix is in default state.
Definition: DiagonalMatrix.h:169
#define BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Symmetric.h:79
#define BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE(T)
Constraint on the data type.
Definition: RowMajorMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Hermitian.h:79
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Upper.h:81
#define BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE(T)
Constraint on the data type.
Definition: DenseMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type.
Definition: Computation.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: UniUpper.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE(T)
Constraint on the data type.
Definition: Adaptor.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Lower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: UniLower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: StrictlyLower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: StrictlyUpper.h:81
#define BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE(T)
Constraint on the data type.
Definition: ColumnMajorMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES(T1, T2)
Constraint on the data type.
Definition: SIMDCombinable.h:61
BLAZE_ALWAYS_INLINE constexpr auto prevMultiple(T1 value, T2 factor) noexcept
Rounds down an integral value to the previous multiple of a given factor.
Definition: PrevMultiple.h:68
bool isOne(const Proxy< PT, RT > &proxy)
Returns whether the represented element is 1.
Definition: Proxy.h:2337
constexpr void reset(Matrix< MT, SO > &matrix)
Resetting the given matrix.
Definition: Matrix.h:806
constexpr size_t size(const Matrix< MT, SO > &matrix) noexcept
Returns the total number of elements of the matrix.
Definition: Matrix.h:676
#define BLAZE_INTERNAL_ASSERT(expr, msg)
Run time assertion macro for internal checks.
Definition: Assert.h:101
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro.
Definition: StaticAssert.h:112
decltype(auto) submatrix(Matrix< MT, SO > &, RSAs...)
Creating a view on a specific submatrix of the given matrix.
Definition: Submatrix.h:181
constexpr Unchecked unchecked
Global Unchecked instance.
Definition: Check.h:146
Header file for the serial shim.
Header file for basic type definitions.
Header file for the generic min algorithm.
Header file for the implementation of the Submatrix view.