Blaze 3.9
MMM.h
Go to the documentation of this file.
1//=================================================================================================
33//=================================================================================================
34
35#ifndef _BLAZE_MATH_DENSE_MMM_H_
36#define _BLAZE_MATH_DENSE_MMM_H_
37
38
39//*************************************************************************************************
40// Includes
41//*************************************************************************************************
42
43#include <blaze/math/Aliases.h>
64#include <blaze/math/SIMD.h>
72#include <blaze/util/Assert.h>
74#include <blaze/util/Types.h>
76
77
78namespace blaze {
79
80//=================================================================================================
81//
82// GENERAL DENSE MATRIX MULTIPLICATION KERNELS
83//
84//=================================================================================================
85
86//*************************************************************************************************
// Blocked, SIMD-vectorized compute kernel for C = beta*C + alpha*(A*B).
//
// This overload takes the target as DenseMatrix<MT1,false> (the boolean is the storage-order
// flag; presumably 'false' selects row-major storage as elsewhere in Blaze -- confirm).
//
// \param C      Target matrix. It is first scaled by beta and then accumulated into.
// \param A      Left-hand side operand; A.rows() x A.columns() is treated as M x K.
// \param B      Right-hand side operand; B.rows() x B.columns() is treated as K x N.
// \param alpha  Factor applied to every contribution of the product A*B.
// \param beta   Factor applied to the existing content of C.
// \return void
//
// Outline:
//   1. C is reset (isDefault(beta)), left untouched (isOne(beta)) or multiplied by beta.
//   2. The inner dimension K is walked in panels of at most KBLOCK values. For each panel the
//      needed rows of A are copied into the scratch matrix A2, and for each group of at most
//      JBLOCK columns the matching part of B is copied into the scratch matrix B2.
//   3. Small register tiles (5x2/4x2, 2x4, 2x2, ...) accumulate SIMD products over the panel,
//      are reduced with sum() and added to C scaled by alpha.
//   4. If an operand type is not padded, the last K % SIMDSIZE values of the inner dimension
//      are processed by a scalar tail after the vectorized loop.
105template< typename MT1, typename MT2, typename MT3, typename ST >
106void mmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
107{
108 using ET1 = ElementType_t<MT1>;
109 using ET2 = ElementType_t<MT2>;
110 using ET3 = ElementType_t<MT3>;
111 using SIMDType = SIMDTrait_t<ET1>;
112
117
120
123
126
    // Step width of all vectorized k-loops, as reported by SIMDTrait<ET1>::size.
127 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
128
    // True if at least one operand type is not padded. In that case the vectorized loop only
    // consumes multiples of SIMDSIZE and the scalar tail at the end handles what is left.
129 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
130
    // Panel sizes: KBLOCK along the inner dimension, JBLOCK along the columns of B/C.
    // Both must be multiples of SIMDSIZE (checked by the static assertions below).
131 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
132 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
133
134 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
135 BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );
136
137 const size_t M( A.rows() );
138 const size_t N( B.columns() );
139 const size_t K( A.columns() );
140
141 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
142
    // Scratch copies of the current panels: A2 holds up to M x KBLOCK values of A, B2 holds
    // up to KBLOCK x JBLOCK values of B. They use opposite storage-order flags (false/true),
    // presumably so that both A2.load(i,k) and B2.load(k,j) read contiguous runs along k.
143 DynamicMatrix<ET2,false> A2( M, KBLOCK );
144 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
145
    // Apply beta up front; everything below only adds to C.
146 if( isDefault( beta ) ) {
147 reset( *C );
148 }
149 else if( !isOne( beta ) ) {
150 (*C) *= beta;
151 }
152
153 size_t kk( 0UL );
154 size_t kblock( 0UL );
155
    // Vectorized main loop over panels of the inner dimension. With 'remainder' set the loop
    // stops as soon as fewer than SIMDSIZE values are left (K - kk < SIMDSIZE).
156 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
157 {
    // Size of this panel: a full KBLOCK if it fits, otherwise the rest of K (reduced via
    // prevMultiple() -- presumably rounding down to a multiple of SIMDSIZE -- when a scalar
    // tail is going to follow).
158 if( remainder ) {
159 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
160 }
161 else {
162 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
163 }
164
    // Row range of A used for this panel. For lower/upper operand types the range is cut to
    // [kk,M) respectively [0,kk+kblock); the skipped rows presumably contain only zeros.
165 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
166 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
167 const size_t isize ( iend - ibegin );
168
169 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );
170
171 size_t jj( 0UL );
172 size_t jblock( 0UL );
173
    // Loop over groups of at most JBLOCK columns of B and C.
174 while( jj < N )
175 {
176 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
177
    // For lower/upper B, skip column groups that lie completely on the side of the k-panel
    // excluded by the structure of B (presumably all zeros there).
178 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
179 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
180 jj += jblock;
181 continue;
182 }
183
184 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );
185
    // 'i' indexes rows of A2 (row ibegin+i of C), 'j' indexes columns of B2 (column jj+j of C).
186 size_t i( 0UL );
187
    // First tile height: 5 rows for floating point element types, 4 rows otherwise.
188 if( IsFloatingPoint_v<ET1> )
189 {
190 for( ; (i+5UL) <= isize; i+=5UL )
191 {
192 size_t j( 0UL );
193
    // 5x2 tile: ten accumulators, one per element of C in the tile.
    // NOTE(review): the accumulators are only default-constructed; this relies on SIMDType
    // default construction yielding a zero vector -- confirm.
194 for( ; (j+2UL) <= jblock; j+=2UL )
195 {
196 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
197
198 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
199 {
200 const SIMDType a1( A2.load(i ,k) );
201 const SIMDType a2( A2.load(i+1UL,k) );
202 const SIMDType a3( A2.load(i+2UL,k) );
203 const SIMDType a4( A2.load(i+3UL,k) );
204 const SIMDType a5( A2.load(i+4UL,k) );
205
206 const SIMDType b1( B2.load(k,j ) );
207 const SIMDType b2( B2.load(k,j+1UL) );
208
209 xmm1 += a1 * b1;
210 xmm2 += a1 * b2;
211 xmm3 += a2 * b1;
212 xmm4 += a2 * b2;
213 xmm5 += a3 * b1;
214 xmm6 += a3 * b2;
215 xmm7 += a4 * b1;
216 xmm8 += a4 * b2;
217 xmm9 += a5 * b1;
218 xmm10 += a5 * b2;
219 }
220
    // Reduce each accumulator with sum(), scale by alpha and add to C.
221 (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
222 (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
223 (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
224 (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
225 (*C)(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
226 (*C)(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
227 (*C)(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
228 (*C)(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
229 (*C)(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
230 (*C)(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
231 }
232
    // 5x1 tile: single leftover column when jblock is odd.
233 if( j<jblock )
234 {
235 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
236
237 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
238 {
239 const SIMDType a1( A2.load(i ,k) );
240 const SIMDType a2( A2.load(i+1UL,k) );
241 const SIMDType a3( A2.load(i+2UL,k) );
242 const SIMDType a4( A2.load(i+3UL,k) );
243 const SIMDType a5( A2.load(i+4UL,k) );
244
245 const SIMDType b1( B2.load(k,j) );
246
247 xmm1 += a1 * b1;
248 xmm2 += a2 * b1;
249 xmm3 += a3 * b1;
250 xmm4 += a4 * b1;
251 xmm5 += a5 * b1;
252 }
253
254 (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
255 (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
256 (*C)(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
257 (*C)(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
258 (*C)(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
259 }
260 }
261 }
262 else
263 {
264 for( ; (i+4UL) <= isize; i+=4UL )
265 {
266 size_t j( 0UL );
267
    // 4x2 tile (non floating point element types).
268 for( ; (j+2UL) <= jblock; j+=2UL )
269 {
270 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
271
272 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
273 {
274 const SIMDType a1( A2.load(i ,k) );
275 const SIMDType a2( A2.load(i+1UL,k) );
276 const SIMDType a3( A2.load(i+2UL,k) );
277 const SIMDType a4( A2.load(i+3UL,k) );
278
279 const SIMDType b1( B2.load(k,j ) );
280 const SIMDType b2( B2.load(k,j+1UL) );
281
282 xmm1 += a1 * b1;
283 xmm2 += a1 * b2;
284 xmm3 += a2 * b1;
285 xmm4 += a2 * b2;
286 xmm5 += a3 * b1;
287 xmm6 += a3 * b2;
288 xmm7 += a4 * b1;
289 xmm8 += a4 * b2;
290 }
291
292 (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
293 (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
294 (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
295 (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
296 (*C)(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
297 (*C)(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
298 (*C)(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
299 (*C)(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
300 }
301
    // 4x1 tile: single leftover column.
302 if( j<jblock )
303 {
304 SIMDType xmm1, xmm2, xmm3, xmm4;
305
306 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
307 {
308 const SIMDType a1( A2.load(i ,k) );
309 const SIMDType a2( A2.load(i+1UL,k) );
310 const SIMDType a3( A2.load(i+2UL,k) );
311 const SIMDType a4( A2.load(i+3UL,k) );
312
313 const SIMDType b1( B2.load(k,j) );
314
315 xmm1 += a1 * b1;
316 xmm2 += a2 * b1;
317 xmm3 += a3 * b1;
318 xmm4 += a4 * b1;
319 }
320
321 (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
322 (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
323 (*C)(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
324 (*C)(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
325 }
326 }
327 }
328
    // Rows left over by the 5-/4-row loop are processed in pairs: 2x4, 2x2 and 2x1 tiles.
329 for( ; (i+2UL) <= isize; i+=2UL )
330 {
331 size_t j( 0UL );
332
333 for( ; (j+4UL) <= jblock; j+=4UL )
334 {
335 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
336
337 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
338 {
339 const SIMDType a1( A2.load(i ,k) );
340 const SIMDType a2( A2.load(i+1UL,k) );
341
342 const SIMDType b1( B2.load(k,j ) );
343 const SIMDType b2( B2.load(k,j+1UL) );
344 const SIMDType b3( B2.load(k,j+2UL) );
345 const SIMDType b4( B2.load(k,j+3UL) );
346
347 xmm1 += a1 * b1;
348 xmm2 += a1 * b2;
349 xmm3 += a1 * b3;
350 xmm4 += a1 * b4;
351 xmm5 += a2 * b1;
352 xmm6 += a2 * b2;
353 xmm7 += a2 * b3;
354 xmm8 += a2 * b4;
355 }
356
357 (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
358 (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
359 (*C)(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
360 (*C)(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
361 (*C)(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
362 (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
363 (*C)(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
364 (*C)(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
365 }
366
367 for( ; (j+2UL) <= jblock; j+=2UL )
368 {
369 SIMDType xmm1, xmm2, xmm3, xmm4;
370
371 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
372 {
373 const SIMDType a1( A2.load(i ,k) );
374 const SIMDType a2( A2.load(i+1UL,k) );
375
376 const SIMDType b1( B2.load(k,j ) );
377 const SIMDType b2( B2.load(k,j+1UL) );
378
379 xmm1 += a1 * b1;
380 xmm2 += a1 * b2;
381 xmm3 += a2 * b1;
382 xmm4 += a2 * b2;
383 }
384
385 (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
386 (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
387 (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
388 (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
389 }
390
391 if( j<jblock )
392 {
393 SIMDType xmm1, xmm2;
394
395 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
396 {
397 const SIMDType a1( A2.load(i ,k) );
398 const SIMDType a2( A2.load(i+1UL,k) );
399
400 const SIMDType b1( B2.load(k,j) );
401
402 xmm1 += a1 * b1;
403 xmm2 += a2 * b1;
404 }
405
406 (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
407 (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
408 }
409 }
410
    // At most one row is left at this point: 1x2 tiles plus a final 1x1 tile.
411 if( i<isize )
412 {
413 size_t j( 0UL );
414
415 for( ; (j+2UL) <= jblock; j+=2UL )
416 {
417 SIMDType xmm1, xmm2;
418
419 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
420 {
421 const SIMDType a1( A2.load(i,k) );
422
423 xmm1 += a1 * B2.load(k,j );
424 xmm2 += a1 * B2.load(k,j+1UL);
425 }
426
427 (*C)(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
428 (*C)(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
429 }
430
431 if( j<jblock )
432 {
433 SIMDType xmm1;
434
435 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
436 {
437 const SIMDType a1( A2.load(i,k) );
438
439 xmm1 += a1 * B2.load(k,j);
440 }
441
442 (*C)(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
443 }
444 }
445
446 jj += jblock;
447 }
448
449 kk += kblock;
450 }
451
    // Scalar tail: the vectorized loop above stopped with fewer than SIMDSIZE values of the
    // inner dimension left. They are processed element by element with the same tile layout.
452 if( remainder && kk < K )
453 {
454 const size_t ksize( K - kk );
455
    // Only the lower bound of the row range is restricted here; the range always ends at M.
456 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
457 const size_t isize ( M - ibegin );
458
459 A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );
460
461 size_t jj( 0UL );
462 size_t jblock( 0UL );
463
464 while( jj < N )
465 {
466 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
467
    // For upper B, column groups entirely left of column kk are skipped.
468 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
469 jj += jblock;
470 continue;
471 }
472
473 B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );
474
475 size_t i( 0UL );
476
477 if( IsFloatingPoint_v<ET1> )
478 {
479 for( ; (i+5UL) <= isize; i+=5UL )
480 {
481 size_t j( 0UL );
482
    // Scalar 5x2 tile: each product is scaled by alpha and added directly to C.
483 for( ; (j+2UL) <= jblock; j+=2UL ) {
484 for( size_t k=0UL; k<ksize; ++k ) {
485 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
486 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
487 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
488 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
489 (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
490 (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
491 (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
492 (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
493 (*C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
494 (*C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
495 }
496 }
497
498 if( j<jblock ) {
499 for( size_t k=0UL; k<ksize; ++k ) {
500 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
501 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
502 (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
503 (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
504 (*C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
505 }
506 }
507 }
508 }
509 else
510 {
511 for( ; (i+4UL) <= isize; i+=4UL )
512 {
513 size_t j( 0UL );
514
    // Scalar 4x2 tile (non floating point element types).
515 for( ; (j+2UL) <= jblock; j+=2UL ) {
516 for( size_t k=0UL; k<ksize; ++k ) {
517 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
518 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
519 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
520 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
521 (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
522 (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
523 (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
524 (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
525 }
526 }
527
528 if( j<jblock ) {
529 for( size_t k=0UL; k<ksize; ++k ) {
530 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
531 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
532 (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
533 (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
534 }
535 }
536 }
537 }
538
    // Leftover rows in pairs (2x2 and 2x1 tiles).
539 for( ; (i+2UL) <= isize; i+=2UL )
540 {
541 size_t j( 0UL );
542
543 for( ; (j+2UL) <= jblock; j+=2UL ) {
544 for( size_t k=0UL; k<ksize; ++k ) {
545 (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
546 (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
547 (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
548 (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
549 }
550 }
551
552 if( j<jblock ) {
553 for( size_t k=0UL; k<ksize; ++k ) {
554 (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
555 (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
556 }
557 }
558 }
559
    // Final single row (1x2 and 1x1 tiles).
560 if( i<isize )
561 {
562 size_t j( 0UL );
563
564 for( ; (j+2UL) <= jblock; j+=2UL ) {
565 for( size_t k=0UL; k<ksize; ++k ) {
566 (*C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
567 (*C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
568 }
569 }
570
571 if( j<jblock ) {
572 for( size_t k=0UL; k<ksize; ++k ) {
573 (*C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
574 }
575 }
576 }
577
578 jj += jblock;
579 }
580 }
581}
583//*************************************************************************************************
584
585
586//*************************************************************************************************
// Blocked, SIMD-vectorized compute kernel for C = beta*C + alpha*(A*B).
//
// This overload takes the target as DenseMatrix<MT1,true> (the boolean is the storage-order
// flag; presumably 'true' selects column-major storage as elsewhere in Blaze -- confirm).
//
// \param C      Target matrix. It is first scaled by beta and then accumulated into.
// \param A      Left-hand side operand; A.rows() x A.columns() is treated as M x K.
// \param B      Right-hand side operand; B.rows() x B.columns() is treated as K x N.
// \param alpha  Factor applied to every contribution of the product A*B.
// \param beta   Factor applied to the existing content of C.
// \return void
//
// Outline:
//   1. C is reset (isDefault(beta)), left untouched (isOne(beta)) or multiplied by beta.
//   2. The inner dimension K is walked in panels of at most KBLOCK values. For each panel the
//      needed columns of B are copied into the scratch matrix B2, and for each group of at
//      most IBLOCK rows the matching part of A is copied into the scratch matrix A2.
//   3. Small register tiles (2x5/2x4, 4x2, 2x2, ...) accumulate SIMD products over the panel,
//      are reduced with sum() and added to C scaled by alpha.
//   4. If an operand type is not padded, the last K % SIMDSIZE values of the inner dimension
//      are processed by a scalar tail after the vectorized loop.
605template< typename MT1, typename MT2, typename MT3, typename ST >
606void mmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
607{
608 using ET1 = ElementType_t<MT1>;
609 using ET2 = ElementType_t<MT2>;
610 using ET3 = ElementType_t<MT3>;
611 using SIMDType = SIMDTrait_t<ET1>;
612
617
620
623
626
    // Step width of all vectorized k-loops, as reported by SIMDTrait<ET1>::size.
627 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
628
    // True if at least one operand type is not padded. In that case the vectorized loop only
    // consumes multiples of SIMDSIZE and the scalar tail at the end handles what is left.
629 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
630
    // Panel sizes: KBLOCK along the inner dimension, IBLOCK along the rows of A/C.
    // Both must be multiples of SIMDSIZE (checked by the static assertions below).
631 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
632 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
633
634 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
635 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
636
637 const size_t M( A.rows() );
638 const size_t N( B.columns() );
639 const size_t K( A.columns() );
640
641 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
642
    // Scratch copies of the current panels: A2 holds up to IBLOCK x KBLOCK values of A, B2
    // holds up to KBLOCK x N values of B. They use opposite storage-order flags (false/true),
    // presumably so that both A2.load(i,k) and B2.load(k,j) read contiguous runs along k.
643 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
644 DynamicMatrix<ET3,true> B2( KBLOCK, N );
645
    // Apply beta up front; everything below only adds to C.
646 if( isDefault( beta ) ) {
647 reset( *C );
648 }
649 else if( !isOne( beta ) ) {
650 (*C) *= beta;
651 }
652
653 size_t kk( 0UL );
654 size_t kblock( 0UL );
655
    // Vectorized main loop over panels of the inner dimension. With 'remainder' set the loop
    // stops as soon as fewer than SIMDSIZE values are left (K - kk < SIMDSIZE).
656 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
657 {
    // Size of this panel: a full KBLOCK if it fits, otherwise the rest of K (reduced via
    // prevMultiple() -- presumably rounding down to a multiple of SIMDSIZE -- when a scalar
    // tail is going to follow).
658 if( remainder ) {
659 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
660 }
661 else {
662 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
663 }
664
    // Column range of B used for this panel. For upper/lower operand types the range is cut
    // to [kk,N) respectively [0,kk+kblock); the skipped columns presumably contain only zeros.
665 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
666 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
667 const size_t jsize ( jend - jbegin );
668
669 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
670
671 size_t ii( 0UL );
672 size_t iblock( 0UL );
673
    // Loop over groups of at most IBLOCK rows of A and C.
674 while( ii < M )
675 {
676 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
677
    // For lower/upper A, skip row groups that lie completely on the side of the k-panel
    // excluded by the structure of A (presumably all zeros there).
678 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
679 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
680 ii += iblock;
681 continue;
682 }
683
684 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
685
    // 'j' indexes columns of B2 (column jbegin+j of C), 'i' indexes rows of A2 (row ii+i of C).
686 size_t j( 0UL );
687
    // First tile width: 5 columns for floating point element types, 4 columns otherwise.
    // The selection is keyed on ET1, the element type SIMDType is derived from and the same
    // type the scalar tail further down uses for this decision. Either width yields the same
    // result, since the narrower loops that follow pick up all columns left over here.
688 if( IsFloatingPoint_v<ET1> )
689 {
690 for( ; (j+5UL) <= jsize; j+=5UL )
691 {
692 size_t i( 0UL );
693
    // 2x5 tile: ten accumulators, one per element of C in the tile.
    // NOTE(review): the accumulators are only default-constructed; this relies on SIMDType
    // default construction yielding a zero vector -- confirm.
694 for( ; (i+2UL) <= iblock; i+=2UL )
695 {
696 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
697
698 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
699 {
700 const SIMDType a1( A2.load(i ,k) );
701 const SIMDType a2( A2.load(i+1UL,k) );
702
703 const SIMDType b1( B2.load(k,j ) );
704 const SIMDType b2( B2.load(k,j+1UL) );
705 const SIMDType b3( B2.load(k,j+2UL) );
706 const SIMDType b4( B2.load(k,j+3UL) );
707 const SIMDType b5( B2.load(k,j+4UL) );
708
709 xmm1 += a1 * b1;
710 xmm2 += a1 * b2;
711 xmm3 += a1 * b3;
712 xmm4 += a1 * b4;
713 xmm5 += a1 * b5;
714 xmm6 += a2 * b1;
715 xmm7 += a2 * b2;
716 xmm8 += a2 * b3;
717 xmm9 += a2 * b4;
718 xmm10 += a2 * b5;
719 }
720
    // Reduce each accumulator with sum(), scale by alpha and add to C.
721 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
722 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
723 (*C)(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
724 (*C)(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
725 (*C)(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
726 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
727 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
728 (*C)(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
729 (*C)(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
730 (*C)(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
731 }
732
    // 1x5 tile: single leftover row when iblock is odd.
733 if( i<iblock )
734 {
735 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
736
737 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
738 {
739 const SIMDType a1( A2.load(i,k) );
740
741 xmm1 += a1 * B2.load(k,j );
742 xmm2 += a1 * B2.load(k,j+1UL);
743 xmm3 += a1 * B2.load(k,j+2UL);
744 xmm4 += a1 * B2.load(k,j+3UL);
745 xmm5 += a1 * B2.load(k,j+4UL);
746 }
747
748 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
749 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
750 (*C)(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
751 (*C)(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
752 (*C)(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
753 }
754 }
755 }
756 else
757 {
758 for( ; (j+4UL) <= jsize; j+=4UL )
759 {
760 size_t i( 0UL );
761
    // 2x4 tile (non floating point element types).
762 for( ; (i+2UL) <= iblock; i+=2UL )
763 {
764 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
765
766 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
767 {
768 const SIMDType a1( A2.load(i ,k) );
769 const SIMDType a2( A2.load(i+1UL,k) );
770
771 const SIMDType b1( B2.load(k,j ) );
772 const SIMDType b2( B2.load(k,j+1UL) );
773 const SIMDType b3( B2.load(k,j+2UL) );
774 const SIMDType b4( B2.load(k,j+3UL) );
775
776 xmm1 += a1 * b1;
777 xmm2 += a1 * b2;
778 xmm3 += a1 * b3;
779 xmm4 += a1 * b4;
780 xmm5 += a2 * b1;
781 xmm6 += a2 * b2;
782 xmm7 += a2 * b3;
783 xmm8 += a2 * b4;
784 }
785
786 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
787 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
788 (*C)(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
789 (*C)(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
790 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
791 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
792 (*C)(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
793 (*C)(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
794 }
795
    // 1x4 tile: single leftover row.
796 if( i<iblock )
797 {
798 SIMDType xmm1, xmm2, xmm3, xmm4;
799
800 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
801 {
802 const SIMDType a1( A2.load(i,k) );
803
804 xmm1 += a1 * B2.load(k,j );
805 xmm2 += a1 * B2.load(k,j+1UL);
806 xmm3 += a1 * B2.load(k,j+2UL);
807 xmm4 += a1 * B2.load(k,j+3UL);
808 }
809
810 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
811 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
812 (*C)(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
813 (*C)(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
814 }
815 }
816 }
817
    // Columns left over by the 5-/4-column loop are processed in pairs: 4x2, 2x2, 1x2 tiles.
818 for( ; (j+2UL) <= jsize; j+=2UL )
819 {
820 size_t i( 0UL );
821
822 for( ; (i+4UL) <= iblock; i+=4UL )
823 {
824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
825
826 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
827 {
828 const SIMDType a1( A2.load(i ,k) );
829 const SIMDType a2( A2.load(i+1UL,k) );
830 const SIMDType a3( A2.load(i+2UL,k) );
831 const SIMDType a4( A2.load(i+3UL,k) );
832
833 const SIMDType b1( B2.load(k,j ) );
834 const SIMDType b2( B2.load(k,j+1UL) );
835
836 xmm1 += a1 * b1;
837 xmm2 += a1 * b2;
838 xmm3 += a2 * b1;
839 xmm4 += a2 * b2;
840 xmm5 += a3 * b1;
841 xmm6 += a3 * b2;
842 xmm7 += a4 * b1;
843 xmm8 += a4 * b2;
844 }
845
846 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
847 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
848 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
849 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
850 (*C)(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
851 (*C)(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
852 (*C)(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
853 (*C)(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
854 }
855
856 for( ; (i+2UL) <= iblock; i+=2UL )
857 {
858 SIMDType xmm1, xmm2, xmm3, xmm4;
859
860 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
861 {
862 const SIMDType a1( A2.load(i ,k) );
863 const SIMDType a2( A2.load(i+1UL,k) );
864
865 const SIMDType b1( B2.load(k,j ) );
866 const SIMDType b2( B2.load(k,j+1UL) );
867
868 xmm1 += a1 * b1;
869 xmm2 += a1 * b2;
870 xmm3 += a2 * b1;
871 xmm4 += a2 * b2;
872 }
873
874 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
875 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
876 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
877 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
878 }
879
880 if( i<iblock )
881 {
882 SIMDType xmm1, xmm2;
883
884 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
885 {
886 const SIMDType a1( A2.load(i,k) );
887
888 xmm1 += a1 * B2.load(k,j );
889 xmm2 += a1 * B2.load(k,j+1UL);
890 }
891
892 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
893 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
894 }
895 }
896
    // At most one column is left at this point: 2x1 tiles plus a final 1x1 tile.
897 if( j<jsize )
898 {
899 size_t i( 0UL );
900
901 for( ; (i+2UL) <= iblock; i+=2UL )
902 {
903 SIMDType xmm1, xmm2;
904
905 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
906 {
907 const SIMDType b1( B2.load(k,j) );
908
909 xmm1 += A2.load(i ,k) * b1;
910 xmm2 += A2.load(i+1UL,k) * b1;
911 }
912
913 (*C)(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
914 (*C)(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
915 }
916
917 if( i<iblock )
918 {
919 SIMDType xmm1;
920
921 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
922 {
923 xmm1 += A2.load(i,k) * B2.load(k,j);
924 }
925
926 (*C)(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
927 }
928 }
929
930 ii += iblock;
931 }
932
933 kk += kblock;
934 }
935
    // Scalar tail: the vectorized loop above stopped with fewer than SIMDSIZE values of the
    // inner dimension left. They are processed element by element with the same tile layout.
936 if( remainder && kk < K )
937 {
938 const size_t ksize( K - kk );
939
    // Only the lower bound of the column range is restricted here; the range always ends at N.
940 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
941 const size_t jsize ( N - jbegin );
942
943 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
944
945 size_t ii( 0UL );
946 size_t iblock( 0UL );
947
948 while( ii < M )
949 {
950 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
951
    // For lower A, row groups entirely above row kk are skipped.
952 if( IsLower_v<MT2> && ii+iblock <= kk ) {
953 ii += iblock;
954 continue;
955 }
956
957 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
958
959 size_t j( 0UL );
960
961 if( IsFloatingPoint_v<ET1> )
962 {
963 for( ; (j+5UL) <= jsize; j+=5UL )
964 {
965 size_t i( 0UL );
966
    // Scalar 2x5 tile: each product is scaled by alpha and added directly to C.
967 for( ; (i+2UL) <= iblock; i+=2UL ) {
968 for( size_t k=0UL; k<ksize; ++k ) {
969 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
970 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
971 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
972 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
973 (*C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
974 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
975 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
976 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
977 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
978 (*C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
979 }
980 }
981
982 if( i<iblock ) {
983 for( size_t k=0UL; k<ksize; ++k ) {
984 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
985 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
986 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
987 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
988 (*C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
989 }
990 }
991 }
992 }
993 else
994 {
995 for( ; (j+4UL) <= jsize; j+=4UL )
996 {
997 size_t i( 0UL );
998
    // Scalar 2x4 tile (non floating point element types).
999 for( ; (i+2UL) <= iblock; i+=2UL ) {
1000 for( size_t k=0UL; k<ksize; ++k ) {
1001 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1002 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1003 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1004 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1005 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1006 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1007 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1008 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
1009 }
1010 }
1011
1012 if( i<iblock ) {
1013 for( size_t k=0UL; k<ksize; ++k ) {
1014 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1015 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1016 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1017 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
1018 }
1019 }
1020 }
1021 }
1022
    // Leftover columns in pairs (2x2 and 1x2 tiles).
1023 for( ; (j+2UL) <= jsize; j+=2UL )
1024 {
1025 size_t i( 0UL );
1026
1027 for( ; (i+2UL) <= iblock; i+=2UL ) {
1028 for( size_t k=0UL; k<ksize; ++k ) {
1029 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1030 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1031 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1032 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1033 }
1034 }
1035
1036 if( i<iblock ) {
1037 for( size_t k=0UL; k<ksize; ++k ) {
1038 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1039 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1040 }
1041 }
1042 }
1043
    // Final single column (2x1 and 1x1 tiles).
1044 if( j<jsize )
1045 {
1046 size_t i( 0UL );
1047
1048 for( ; (i+2UL) <= iblock; i+=2UL ) {
1049 for( size_t k=0UL; k<ksize; ++k ) {
1050 (*C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1051 (*C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1052 }
1053 }
1054
1055 if( i<iblock ) {
1056 for( size_t k=0UL; k<ksize; ++k ) {
1057 (*C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
1058 }
1059 }
1060 }
1061
1062 ii += iblock;
1063 }
1064 }
1065}
1067//*************************************************************************************************
1068
1069
1070//*************************************************************************************************
1086template< typename MT1, typename MT2, typename MT3 >
1087inline void mmm( MT1& C, const MT2& A, const MT3& B )
1088{
1089 using ET1 = ElementType_t<MT1>;
1090 using ET2 = ElementType_t<MT2>;
1091 using ET3 = ElementType_t<MT3>;
1092
1095
1096 mmm( C, A, B, ET1(1), ET1(0) );
1097}
1099//*************************************************************************************************
1100
1101
1102
1103
1104//=================================================================================================
1105//
1106// LOWER DENSE MATRIX MULTIPLICATION KERNELS
1107//
1108//=================================================================================================
1109
1110//*************************************************************************************************
1129template< typename MT1, typename MT2, typename MT3, typename ST >
1130void lmmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
1131{
1132 using ET1 = ElementType_t<MT1>;
1133 using ET2 = ElementType_t<MT2>;
1134 using ET3 = ElementType_t<MT3>;
1135 using SIMDType = SIMDTrait_t<ET1>;
1136
1145
1148
1151
1154
1155 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
1156
1157 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
1158
1159 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
1160 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
1161
1162 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
1163 BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );
1164
1165 const size_t M( A.rows() );
1166 const size_t N( B.columns() );
1167 const size_t K( A.columns() );
1168
1169 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
1170
1171 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1172 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
1173
1174 decltype(auto) c( derestrict( *C ) );
1175
1176 if( isDefault( beta ) ) {
1177 reset( c );
1178 }
1179 else if( !isOne( beta ) ) {
1180 c *= beta;
1181 }
1182
1183 size_t kk( 0UL );
1184 size_t kblock( 0UL );
1185
1186 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
1187 {
1188 if( remainder ) {
1189 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
1190 }
1191 else {
1192 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
1193 }
1194
1195 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1196 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
1197 const size_t isize ( iend - ibegin );
1198
1199 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );
1200
1201 size_t jj( 0UL );
1202 size_t jblock( 0UL );
1203
1204 while( jj < N )
1205 {
1206 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
1207
1208 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
1209 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
1210 jj += jblock;
1211 continue;
1212 }
1213
1214 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );
1215
1216 size_t i( 0UL );
1217
1218 if( IsFloatingPoint_v<ET1> )
1219 {
1220 for( ; (i+5UL) <= isize; i+=5UL )
1221 {
1222 if( jj > ibegin+i+4UL ) continue;
1223
1224 const size_t jend( min( ibegin+i-jj+5UL, jblock ) );
1225 size_t j( 0UL );
1226
1227 for( ; (j+2UL) <= jend; j+=2UL )
1228 {
1229 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1230
1231 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1232 {
1233 const SIMDType a1( A2.load(i ,k) );
1234 const SIMDType a2( A2.load(i+1UL,k) );
1235 const SIMDType a3( A2.load(i+2UL,k) );
1236 const SIMDType a4( A2.load(i+3UL,k) );
1237 const SIMDType a5( A2.load(i+4UL,k) );
1238
1239 const SIMDType b1( B2.load(k,j ) );
1240 const SIMDType b2( B2.load(k,j+1UL) );
1241
1242 xmm1 += a1 * b1;
1243 xmm2 += a1 * b2;
1244 xmm3 += a2 * b1;
1245 xmm4 += a2 * b2;
1246 xmm5 += a3 * b1;
1247 xmm6 += a3 * b2;
1248 xmm7 += a4 * b1;
1249 xmm8 += a4 * b2;
1250 xmm9 += a5 * b1;
1251 xmm10 += a5 * b2;
1252 }
1253
1254 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1255 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1256 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1257 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1258 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
1259 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1260 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
1261 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
1262 c(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
1263 c(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
1264 }
1265
1266 if( j<jend )
1267 {
1268 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1269
1270 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1271 {
1272 const SIMDType a1( A2.load(i ,k) );
1273 const SIMDType a2( A2.load(i+1UL,k) );
1274 const SIMDType a3( A2.load(i+2UL,k) );
1275 const SIMDType a4( A2.load(i+3UL,k) );
1276 const SIMDType a5( A2.load(i+4UL,k) );
1277
1278 const SIMDType b1( B2.load(k,j) );
1279
1280 xmm1 += a1 * b1;
1281 xmm2 += a2 * b1;
1282 xmm3 += a3 * b1;
1283 xmm4 += a4 * b1;
1284 xmm5 += a5 * b1;
1285 }
1286
1287 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1288 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1289 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
1290 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
1291 c(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
1292 }
1293 }
1294 }
1295 else
1296 {
1297 for( ; (i+4UL) <= isize; i+=4UL )
1298 {
1299 if( jj > ibegin+i+3UL ) continue;
1300
1301 const size_t jend( min( ibegin+i-jj+4UL, jblock ) );
1302 size_t j( 0UL );
1303
1304 for( ; (j+2UL) <= jend; j+=2UL )
1305 {
1306 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1307
1308 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1309 {
1310 const SIMDType a1( A2.load(i ,k) );
1311 const SIMDType a2( A2.load(i+1UL,k) );
1312 const SIMDType a3( A2.load(i+2UL,k) );
1313 const SIMDType a4( A2.load(i+3UL,k) );
1314
1315 const SIMDType b1( B2.load(k,j ) );
1316 const SIMDType b2( B2.load(k,j+1UL) );
1317
1318 xmm1 += a1 * b1;
1319 xmm2 += a1 * b2;
1320 xmm3 += a2 * b1;
1321 xmm4 += a2 * b2;
1322 xmm5 += a3 * b1;
1323 xmm6 += a3 * b2;
1324 xmm7 += a4 * b1;
1325 xmm8 += a4 * b2;
1326 }
1327
1328 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1329 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1330 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1331 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1332 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
1333 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1334 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
1335 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
1336 }
1337
1338 if( j<jend )
1339 {
1340 SIMDType xmm1, xmm2, xmm3, xmm4;
1341
1342 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1343 {
1344 const SIMDType a1( A2.load(i ,k) );
1345 const SIMDType a2( A2.load(i+1UL,k) );
1346 const SIMDType a3( A2.load(i+2UL,k) );
1347 const SIMDType a4( A2.load(i+3UL,k) );
1348
1349 const SIMDType b1( B2.load(k,j) );
1350
1351 xmm1 += a1 * b1;
1352 xmm2 += a2 * b1;
1353 xmm3 += a3 * b1;
1354 xmm4 += a4 * b1;
1355 }
1356
1357 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1358 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1359 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
1360 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
1361 }
1362 }
1363 }
1364
1365 for( ; (i+2UL) <= isize; i+=2UL )
1366 {
1367 if( jj > ibegin+i+1UL ) continue;
1368
1369 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1370 size_t j( 0UL );
1371
1372 for( ; (j+4UL) <= jend; j+=4UL )
1373 {
1374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1375
1376 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1377 {
1378 const SIMDType a1( A2.load(i ,k) );
1379 const SIMDType a2( A2.load(i+1UL,k) );
1380
1381 const SIMDType b1( B2.load(k,j ) );
1382 const SIMDType b2( B2.load(k,j+1UL) );
1383 const SIMDType b3( B2.load(k,j+2UL) );
1384 const SIMDType b4( B2.load(k,j+3UL) );
1385
1386 xmm1 += a1 * b1;
1387 xmm2 += a1 * b2;
1388 xmm3 += a1 * b3;
1389 xmm4 += a1 * b4;
1390 xmm5 += a2 * b1;
1391 xmm6 += a2 * b2;
1392 xmm7 += a2 * b3;
1393 xmm8 += a2 * b4;
1394 }
1395
1396 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1397 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1398 c(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
1399 c(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
1400 c(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1402 c(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
1403 c(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
1404 }
1405
1406 for( ; (j+2UL) <= jend; j+=2UL )
1407 {
1408 SIMDType xmm1, xmm2, xmm3, xmm4;
1409
1410 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1411 {
1412 const SIMDType a1( A2.load(i ,k) );
1413 const SIMDType a2( A2.load(i+1UL,k) );
1414
1415 const SIMDType b1( B2.load(k,j ) );
1416 const SIMDType b2( B2.load(k,j+1UL) );
1417
1418 xmm1 += a1 * b1;
1419 xmm2 += a1 * b2;
1420 xmm3 += a2 * b1;
1421 xmm4 += a2 * b2;
1422 }
1423
1424 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1425 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1426 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1427 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1428 }
1429
1430 if( j<jend )
1431 {
1432 SIMDType xmm1, xmm2;
1433
1434 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1435 {
1436 const SIMDType a1( A2.load(i ,k) );
1437 const SIMDType a2( A2.load(i+1UL,k) );
1438
1439 const SIMDType b1( B2.load(k,j) );
1440
1441 xmm1 += a1 * b1;
1442 xmm2 += a2 * b1;
1443 }
1444
1445 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1446 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1447 }
1448 }
1449
1450 if( i<isize && jj <= ibegin+i )
1451 {
1452 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1453 size_t j( 0UL );
1454
1455 for( ; (j+2UL) <= jend; j+=2UL )
1456 {
1457 SIMDType xmm1, xmm2;
1458
1459 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1460 {
1461 const SIMDType a1( A2.load(i,k) );
1462
1463 xmm1 += a1 * B2.load(k,j );
1464 xmm2 += a1 * B2.load(k,j+1UL);
1465 }
1466
1467 c(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
1468 c(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
1469 }
1470
1471 if( j<jend )
1472 {
1473 SIMDType xmm1;
1474
1475 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1476 {
1477 const SIMDType a1( A2.load(i,k) );
1478
1479 xmm1 += a1 * B2.load(k,j);
1480 }
1481
1482 c(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
1483 }
1484 }
1485
1486 jj += jblock;
1487 }
1488
1489 kk += kblock;
1490 }
1491
1492 if( remainder && kk < K )
1493 {
1494 const size_t ksize( K - kk );
1495
1496 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1497 const size_t isize ( M - ibegin );
1498
1499 A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );
1500
1501 size_t jj( 0UL );
1502 size_t jblock( 0UL );
1503
1504 while( jj < N )
1505 {
1506 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
1507
1508 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
1509 jj += jblock;
1510 continue;
1511 }
1512
1513 B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );
1514
1515 size_t i( 0UL );
1516
1517 if( IsFloatingPoint_v<ET1> )
1518 {
1519 for( ; (i+5UL) <= isize; i+=5UL )
1520 {
1521 if( jj > ibegin+i+4UL ) continue;
1522
1523 const size_t jend( min( ibegin+i-jj+5UL, jblock ) );
1524 size_t j( 0UL );
1525
1526 for( ; (j+2UL) <= jend; j+=2UL ) {
1527 for( size_t k=0UL; k<ksize; ++k ) {
1528 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1529 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1530 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1531 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1532 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1533 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1534 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1535 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1536 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1537 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
1538 }
1539 }
1540
1541 if( j<jend ) {
1542 for( size_t k=0UL; k<ksize; ++k ) {
1543 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1544 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1546 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1547 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
1548 }
1549 }
1550 }
1551 }
1552 else
1553 {
1554 for( ; (i+4UL) <= isize; i+=4UL )
1555 {
1556 if( jj > ibegin+i+3UL ) continue;
1557
1558 const size_t jend( min( ibegin+i-jj+4UL, jblock ) );
1559 size_t j( 0UL );
1560
1561 for( ; (j+2UL) <= jend; j+=2UL ) {
1562 for( size_t k=0UL; k<ksize; ++k ) {
1563 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1564 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1565 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1566 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1567 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1568 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1569 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1570 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1571 }
1572 }
1573
1574 if( j<jend ) {
1575 for( size_t k=0UL; k<ksize; ++k ) {
1576 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1577 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1578 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1579 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1580 }
1581 }
1582 }
1583 }
1584
1585 for( ; (i+2UL) <= isize; i+=2UL )
1586 {
1587 if( jj > ibegin+i+1UL ) continue;
1588
1589 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1590 size_t j( 0UL );
1591
1592 for( ; (j+2UL) <= jend; j+=2UL ) {
1593 for( size_t k=0UL; k<ksize; ++k ) {
1594 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1595 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1596 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1597 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1598 }
1599 }
1600
1601 if( j<jend ) {
1602 for( size_t k=0UL; k<ksize; ++k ) {
1603 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1604 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1605 }
1606 }
1607 }
1608
1609 if( i<isize && jj <= ibegin+i )
1610 {
1611 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1612 size_t j( 0UL );
1613
1614 for( ; (j+2UL) <= jend; j+=2UL ) {
1615 for( size_t k=0UL; k<ksize; ++k ) {
1616 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1617 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1618 }
1619 }
1620
1621 if( j<jend ) {
1622 for( size_t k=0UL; k<ksize; ++k ) {
1623 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
1624 }
1625 }
1626 }
1627
1628 jj += jblock;
1629 }
1630 }
1631}
1633//*************************************************************************************************
1634
1635
1636//*************************************************************************************************
1655template< typename MT1, typename MT2, typename MT3, typename ST >
1656void lmmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
1657{
1658 using ET1 = ElementType_t<MT1>;
1659 using ET2 = ElementType_t<MT2>;
1660 using ET3 = ElementType_t<MT3>;
1661 using SIMDType = SIMDTrait_t<ET1>;
1662
1671
1674
1677
1680
1681 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
1682
1683 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
1684
1685 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
1686 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1687
1688 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
1689 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
1690
1691 const size_t M( A.rows() );
1692 const size_t N( B.columns() );
1693 const size_t K( A.columns() );
1694
1695 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
1696
1697 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1698 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1699
1700 decltype(auto) c( derestrict( *C ) );
1701
1702 if( isDefault( beta ) ) {
1703 reset( c );
1704 }
1705 else if( !isOne( beta ) ) {
1706 c *= beta;
1707 }
1708
1709 size_t kk( 0UL );
1710 size_t kblock( 0UL );
1711
1712 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
1713 {
1714 if( remainder ) {
1715 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
1716 }
1717 else {
1718 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
1719 }
1720
1721 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
1722 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
1723 const size_t jsize ( jend - jbegin );
1724
1725 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
1726
1727 size_t ii( 0UL );
1728 size_t iblock( 0UL );
1729
1730 while( ii < M )
1731 {
1732 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
1733
1734 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
1735 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
1736 ii += iblock;
1737 continue;
1738 }
1739
1740 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
1741
1742 size_t j( 0UL );
1743
1744 if( IsFloatingPoint_v<ET3> )
1745 {
1746 for( ; (j+5UL) <= jsize; j+=5UL )
1747 {
1748 if( ii+iblock < jbegin ) continue;
1749
1750 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1751
1752 for( ; (i+2UL) <= iblock; i+=2UL )
1753 {
1754 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1755
1756 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1757 {
1758 const SIMDType a1( A2.load(i ,k) );
1759 const SIMDType a2( A2.load(i+1UL,k) );
1760
1761 const SIMDType b1( B2.load(k,j ) );
1762 const SIMDType b2( B2.load(k,j+1UL) );
1763 const SIMDType b3( B2.load(k,j+2UL) );
1764 const SIMDType b4( B2.load(k,j+3UL) );
1765 const SIMDType b5( B2.load(k,j+4UL) );
1766
1767 xmm1 += a1 * b1;
1768 xmm2 += a1 * b2;
1769 xmm3 += a1 * b3;
1770 xmm4 += a1 * b4;
1771 xmm5 += a1 * b5;
1772 xmm6 += a2 * b1;
1773 xmm7 += a2 * b2;
1774 xmm8 += a2 * b3;
1775 xmm9 += a2 * b4;
1776 xmm10 += a2 * b5;
1777 }
1778
1779 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1780 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1781 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1782 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1783 c(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
1784 c(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
1787 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
1788 c(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
1789 }
1790
1791 if( i<iblock )
1792 {
1793 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1794
1795 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1796 {
1797 const SIMDType a1( A2.load(i,k) );
1798
1799 xmm1 += a1 * B2.load(k,j );
1800 xmm2 += a1 * B2.load(k,j+1UL);
1801 xmm3 += a1 * B2.load(k,j+2UL);
1802 xmm4 += a1 * B2.load(k,j+3UL);
1803 xmm5 += a1 * B2.load(k,j+4UL);
1804 }
1805
1806 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1807 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1808 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1809 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1810 c(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
1811 }
1812 }
1813 }
1814 else
1815 {
1816 for( ; (j+4UL) <= jsize; j+=4UL )
1817 {
1818 if( ii+iblock < jbegin ) continue;
1819
1820 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1821
1822 for( ; (i+2UL) <= iblock; i+=2UL )
1823 {
1824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1825
1826 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1827 {
1828 const SIMDType a1( A2.load(i ,k) );
1829 const SIMDType a2( A2.load(i+1UL,k) );
1830
1831 const SIMDType b1( B2.load(k,j ) );
1832 const SIMDType b2( B2.load(k,j+1UL) );
1833 const SIMDType b3( B2.load(k,j+2UL) );
1834 const SIMDType b4( B2.load(k,j+3UL) );
1835
1836 xmm1 += a1 * b1;
1837 xmm2 += a1 * b2;
1838 xmm3 += a1 * b3;
1839 xmm4 += a1 * b4;
1840 xmm5 += a2 * b1;
1841 xmm6 += a2 * b2;
1842 xmm7 += a2 * b3;
1843 xmm8 += a2 * b4;
1844 }
1845
1846 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1847 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1848 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1849 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1850 c(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
1852 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
1853 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
1854 }
1855
1856 if( i<iblock )
1857 {
1858 SIMDType xmm1, xmm2, xmm3, xmm4;
1859
1860 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1861 {
1862 const SIMDType a1( A2.load(i,k) );
1863
1864 xmm1 += a1 * B2.load(k,j );
1865 xmm2 += a1 * B2.load(k,j+1UL);
1866 xmm3 += a1 * B2.load(k,j+2UL);
1867 xmm4 += a1 * B2.load(k,j+3UL);
1868 }
1869
1870 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1871 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1872 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1873 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1874 }
1875 }
1876 }
1877
1878 for( ; (j+2UL) <= jsize; j+=2UL )
1879 {
1880 if( ii+iblock < jbegin ) continue;
1881
1882 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1883
1884 for( ; (i+4UL) <= iblock; i+=4UL )
1885 {
1886 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1887
1888 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1889 {
1890 const SIMDType a1( A2.load(i ,k) );
1891 const SIMDType a2( A2.load(i+1UL,k) );
1892 const SIMDType a3( A2.load(i+2UL,k) );
1893 const SIMDType a4( A2.load(i+3UL,k) );
1894
1895 const SIMDType b1( B2.load(k,j ) );
1896 const SIMDType b2( B2.load(k,j+1UL) );
1897
1898 xmm1 += a1 * b1;
1899 xmm2 += a1 * b2;
1900 xmm3 += a2 * b1;
1901 xmm4 += a2 * b2;
1902 xmm5 += a3 * b1;
1903 xmm6 += a3 * b2;
1904 xmm7 += a4 * b1;
1905 xmm8 += a4 * b2;
1906 }
1907
1908 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1909 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1910 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
1911 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
1912 c(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
1913 c(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
1914 c(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
1915 c(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
1916 }
1917
1918 for( ; (i+2UL) <= iblock; i+=2UL )
1919 {
1920 SIMDType xmm1, xmm2, xmm3, xmm4;
1921
1922 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1923 {
1924 const SIMDType a1( A2.load(i ,k) );
1925 const SIMDType a2( A2.load(i+1UL,k) );
1926
1927 const SIMDType b1( B2.load(k,j ) );
1928 const SIMDType b2( B2.load(k,j+1UL) );
1929
1930 xmm1 += a1 * b1;
1931 xmm2 += a1 * b2;
1932 xmm3 += a2 * b1;
1933 xmm4 += a2 * b2;
1934 }
1935
1936 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1937 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1938 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
1939 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
1940 }
1941
1942 if( i<iblock )
1943 {
1944 SIMDType xmm1, xmm2;
1945
1946 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1947 {
1948 const SIMDType a1( A2.load(i,k) );
1949
1950 xmm1 += a1 * B2.load(k,j );
1951 xmm2 += a1 * B2.load(k,j+1UL);
1952 }
1953
1954 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1955 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1956 }
1957 }
1958
1959 if( j<jsize && ii+iblock >= jbegin )
1960 {
1961 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1962
1963 for( ; (i+2UL) <= iblock; i+=2UL )
1964 {
1965 SIMDType xmm1, xmm2;
1966
1967 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1968 {
1969 const SIMDType b1( B2.load(k,j) );
1970
1971 xmm1 += A2.load(i ,k) * b1;
1972 xmm2 += A2.load(i+1UL,k) * b1;
1973 }
1974
1975 c(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
1976 c(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
1977 }
1978
1979 if( i<iblock )
1980 {
1981 SIMDType xmm1;
1982
1983 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1984 {
1985 xmm1 += A2.load(i,k) * B2.load(k,j);
1986 }
1987
1988 c(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
1989 }
1990 }
1991
1992 ii += iblock;
1993 }
1994
1995 kk += kblock;
1996 }
1997
1998 if( remainder && kk < K )
1999 {
2000 const size_t ksize( K - kk );
2001
2002 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2003 const size_t jsize ( N - jbegin );
2004
2005 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
2006
2007 size_t ii( 0UL );
2008 size_t iblock( 0UL );
2009
2010 while( ii < M )
2011 {
2012 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2013
2014 if( IsLower_v<MT2> && ii+iblock <= kk ) {
2015 ii += iblock;
2016 continue;
2017 }
2018
2019 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
2020
2021 size_t j( 0UL );
2022
2023 if( IsFloatingPoint_v<ET1> )
2024 {
2025 for( ; (j+5UL) <= jsize; j+=5UL )
2026 {
2027 if( ii+iblock < jbegin ) continue;
2028
2029 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2030
2031 for( ; (i+2UL) <= iblock; i+=2UL ) {
2032 for( size_t k=0UL; k<ksize; ++k ) {
2033 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2034 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2035 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2036 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2037 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2038 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2039 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2041 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2042 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2043 }
2044 }
2045
2046 if( i<iblock ) {
2047 for( size_t k=0UL; k<ksize; ++k ) {
2048 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2049 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2050 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2051 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2052 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
2053 }
2054 }
2055 }
2056 }
2057 else
2058 {
2059 for( ; (j+4UL) <= jsize; j+=4UL )
2060 {
2061 if( ii+iblock < jbegin ) continue;
2062
2063 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2064
2065 for( ; (i+2UL) <= iblock; i+=2UL ) {
2066 for( size_t k=0UL; k<ksize; ++k ) {
2067 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2068 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2069 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2070 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2071 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2072 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2073 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2074 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2075 }
2076 }
2077
2078 if( i<iblock ) {
2079 for( size_t k=0UL; k<ksize; ++k ) {
2080 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2081 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2082 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2083 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2084 }
2085 }
2086 }
2087 }
2088
2089 for( ; (j+2UL) <= jsize; j+=2UL )
2090 {
2091 if( ii+iblock < jbegin ) continue;
2092
2093 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2094
2095 for( ; (i+2UL) <= iblock; i+=2UL ) {
2096 for( size_t k=0UL; k<ksize; ++k ) {
2097 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2098 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2099 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2100 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2101 }
2102 }
2103
2104 if( i<iblock ) {
2105 for( size_t k=0UL; k<ksize; ++k ) {
2106 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2107 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2108 }
2109 }
2110 }
2111
2112 if( j<jsize )
2113 {
2114 if( ii+iblock < jbegin ) continue;
2115
2116 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2117
2118 for( ; (i+2UL) <= iblock; i+=2UL ) {
2119 for( size_t k=0UL; k<ksize; ++k ) {
2120 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2121 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2122 }
2123 }
2124
2125 if( i<iblock ) {
2126 for( size_t k=0UL; k<ksize; ++k ) {
2127 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
2128 }
2129 }
2130 }
2131
2132 ii += iblock;
2133 }
2134 }
2135}
2137//*************************************************************************************************
2138
2139
2140//*************************************************************************************************
2156template< typename MT1, typename MT2, typename MT3 >
2157inline void lmmm( MT1& C, const MT2& A, const MT3& B )
2158{
2159 using ET1 = ElementType_t<MT1>;
2160 using ET2 = ElementType_t<MT2>;
2161 using ET3 = ElementType_t<MT3>;
2162
2165
2166 lmmm( C, A, B, ET1(1), ET1(0) );
2167}
2169//*************************************************************************************************
2170
2171
2172
2173
2174//=================================================================================================
2175//
2176// UPPER DENSE MATRIX MULTIPLICATION KERNELS
2177//
2178//=================================================================================================
2179
2180//*************************************************************************************************
2199template< typename MT1, typename MT2, typename MT3, typename ST >
2200void ummm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
2201{
2202 using ET1 = ElementType_t<MT1>;
2203 using ET2 = ElementType_t<MT2>;
2204 using ET3 = ElementType_t<MT3>;
2205 using SIMDType = SIMDTrait_t<ET1>;
2206
2215
2218
2221
2224
2225 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
2226
2227 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
2228
2229 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
2230 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
2231
2232 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
2233 BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );
2234
2235 const size_t M( A.rows() );
2236 const size_t N( B.columns() );
2237 const size_t K( A.columns() );
2238
2239 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
2240
2241 DynamicMatrix<ET2,false> A2( M, KBLOCK );
2242 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
2243
2244 decltype(auto) c( derestrict( *C ) );
2245
2246 if( isDefault( beta ) ) {
2247 reset( c );
2248 }
2249 else if( !isOne( beta ) ) {
2250 c *= beta;
2251 }
2252
2253 size_t kk( 0UL );
2254 size_t kblock( 0UL );
2255
2256 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
2257 {
2258 if( remainder ) {
2259 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
2260 }
2261 else {
2262 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
2263 }
2264
2265 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2266 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
2267 const size_t isize ( iend - ibegin );
2268
2269 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );
2270
2271 size_t jj( 0UL );
2272 size_t jblock( 0UL );
2273
2274 while( jj < N )
2275 {
2276 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
2277
2278 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
2279 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
2280 jj += jblock;
2281 continue;
2282 }
2283
2284 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );
2285
2286 size_t i( 0UL );
2287
2288 if( IsFloatingPoint_v<ET1> )
2289 {
2290 for( ; (i+5UL) <= isize; i+=5UL )
2291 {
2292 if( jj+jblock < ibegin ) continue;
2293
2294 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2295
2296 for( ; (j+2UL) <= jblock; j+=2UL )
2297 {
2298 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2299
2300 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2301 {
2302 const SIMDType a1( A2.load(i ,k) );
2303 const SIMDType a2( A2.load(i+1UL,k) );
2304 const SIMDType a3( A2.load(i+2UL,k) );
2305 const SIMDType a4( A2.load(i+3UL,k) );
2306 const SIMDType a5( A2.load(i+4UL,k) );
2307
2308 const SIMDType b1( B2.load(k,j ) );
2309 const SIMDType b2( B2.load(k,j+1UL) );
2310
2311 xmm1 += a1 * b1;
2312 xmm2 += a1 * b2;
2313 xmm3 += a2 * b1;
2314 xmm4 += a2 * b2;
2315 xmm5 += a3 * b1;
2316 xmm6 += a3 * b2;
2317 xmm7 += a4 * b1;
2318 xmm8 += a4 * b2;
2319 xmm9 += a5 * b1;
2320 xmm10 += a5 * b2;
2321 }
2322
2323 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
2324 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
2325 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
2326 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
2327 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
2328 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
2329 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
2330 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
2331 c(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
2332 c(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
2333 }
2334
2335 if( j<jblock )
2336 {
2337 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2338
2339 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2340 {
2341 const SIMDType a1( A2.load(i ,k) );
2342 const SIMDType a2( A2.load(i+1UL,k) );
2343 const SIMDType a3( A2.load(i+2UL,k) );
2344 const SIMDType a4( A2.load(i+3UL,k) );
2345 const SIMDType a5( A2.load(i+4UL,k) );
2346
2347 const SIMDType b1( B2.load(k,j) );
2348
2349 xmm1 += a1 * b1;
2350 xmm2 += a2 * b1;
2351 xmm3 += a3 * b1;
2352 xmm4 += a4 * b1;
2353 xmm5 += a5 * b1;
2354 }
2355
2356 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
2357 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
2358 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
2359 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
2360 c(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
2361 }
2362 }
2363 }
2364 else
2365 {
2366 for( ; (i+4UL) <= isize; i+=4UL )
2367 {
2368 if( jj+jblock < ibegin ) continue;
2369
2370 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2371
2372 for( ; (j+2UL) <= jblock; j+=2UL )
2373 {
2374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2375
2376 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2377 {
2378 const SIMDType a1( A2.load(i ,k) );
2379 const SIMDType a2( A2.load(i+1UL,k) );
2380 const SIMDType a3( A2.load(i+2UL,k) );
2381 const SIMDType a4( A2.load(i+3UL,k) );
2382
2383 const SIMDType b1( B2.load(k,j ) );
2384 const SIMDType b2( B2.load(k,j+1UL) );
2385
2386 xmm1 += a1 * b1;
2387 xmm2 += a1 * b2;
2388 xmm3 += a2 * b1;
2389 xmm4 += a2 * b2;
2390 xmm5 += a3 * b1;
2391 xmm6 += a3 * b2;
2392 xmm7 += a4 * b1;
2393 xmm8 += a4 * b2;
2394 }
2395
2396 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
2397 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
2398 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
2399 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
2400 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
2401 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
2402 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
2403 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
2404 }
2405
2406 if( j<jblock )
2407 {
2408 SIMDType xmm1, xmm2, xmm3, xmm4;
2409
2410 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2411 {
2412 const SIMDType a1( A2.load(i ,k) );
2413 const SIMDType a2( A2.load(i+1UL,k) );
2414 const SIMDType a3( A2.load(i+2UL,k) );
2415 const SIMDType a4( A2.load(i+3UL,k) );
2416
2417 const SIMDType b1( B2.load(k,j) );
2418
2419 xmm1 += a1 * b1;
2420 xmm2 += a2 * b1;
2421 xmm3 += a3 * b1;
2422 xmm4 += a4 * b1;
2423 }
2424
2425 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
2426 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
2427 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
2428 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
2429 }
2430 }
2431 }
2432
2433 for( ; (i+2UL) <= isize; i+=2UL )
2434 {
2435 if( jj+jblock < ibegin ) continue;
2436
2437 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2438
2439 for( ; (j+4UL) <= jblock; j+=4UL )
2440 {
2441 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2442
2443 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2444 {
2445 const SIMDType a1( A2.load(i ,k) );
2446 const SIMDType a2( A2.load(i+1UL,k) );
2447
2448 const SIMDType b1( B2.load(k,j ) );
2449 const SIMDType b2( B2.load(k,j+1UL) );
2450 const SIMDType b3( B2.load(k,j+2UL) );
2451 const SIMDType b4( B2.load(k,j+3UL) );
2452
2453 xmm1 += a1 * b1;
2454 xmm2 += a1 * b2;
2455 xmm3 += a1 * b3;
2456 xmm4 += a1 * b4;
2457 xmm5 += a2 * b1;
2458 xmm6 += a2 * b2;
2459 xmm7 += a2 * b3;
2460 xmm8 += a2 * b4;
2461 }
2462
2463 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
2464 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
2465 c(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
2466 c(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
2467 c(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
2468 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
2469 c(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
2470 c(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
2471 }
2472
2473 for( ; (j+2UL) <= jblock; j+=2UL )
2474 {
2475 SIMDType xmm1, xmm2, xmm3, xmm4;
2476
2477 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2478 {
2479 const SIMDType a1( A2.load(i ,k) );
2480 const SIMDType a2( A2.load(i+1UL,k) );
2481
2482 const SIMDType b1( B2.load(k,j ) );
2483 const SIMDType b2( B2.load(k,j+1UL) );
2484
2485 xmm1 += a1 * b1;
2486 xmm2 += a1 * b2;
2487 xmm3 += a2 * b1;
2488 xmm4 += a2 * b2;
2489 }
2490
2491 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
2492 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
2493 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
2494 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
2495 }
2496
2497 if( j<jblock )
2498 {
2499 SIMDType xmm1, xmm2;
2500
2501 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2502 {
2503 const SIMDType a1( A2.load(i ,k) );
2504 const SIMDType a2( A2.load(i+1UL,k) );
2505
2506 const SIMDType b1( B2.load(k,j) );
2507
2508 xmm1 += a1 * b1;
2509 xmm2 += a2 * b1;
2510 }
2511
2512 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
2513 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
2514 }
2515 }
2516
2517 if( i<isize && jj+jblock >= ibegin )
2518 {
2519 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2520
2521 for( ; (j+2UL) <= jblock; j+=2UL )
2522 {
2523 SIMDType xmm1, xmm2;
2524
2525 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2526 {
2527 const SIMDType a1( A2.load(i,k) );
2528
2529 xmm1 += a1 * B2.load(k,j );
2530 xmm2 += a1 * B2.load(k,j+1UL);
2531 }
2532
2533 c(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
2534 c(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
2535 }
2536
2537 if( j<jblock )
2538 {
2539 SIMDType xmm1;
2540
2541 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2542 {
2543 const SIMDType a1( A2.load(i,k) );
2544
2545 xmm1 += a1 * B2.load(k,j);
2546 }
2547
2548 c(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
2549 }
2550 }
2551
2552 jj += jblock;
2553 }
2554
2555 kk += kblock;
2556 }
2557
2558 if( remainder && kk < K )
2559 {
2560 const size_t ksize( K - kk );
2561
2562 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
2563 const size_t isize ( M - ibegin );
2564
2565 A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );
2566
2567 size_t jj( 0UL );
2568 size_t jblock( 0UL );
2569
2570 while( jj < N )
2571 {
2572 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
2573
2574 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
2575 jj += jblock;
2576 continue;
2577 }
2578
2579 B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );
2580
2581 size_t i( 0UL );
2582
2583 if( IsFloatingPoint_v<ET1> )
2584 {
2585 for( ; (i+5UL) <= isize; i+=5UL )
2586 {
2587 if( jj+jblock < ibegin ) continue;
2588
2589 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2590
2591 for( ; (j+2UL) <= jblock; j+=2UL ) {
2592 for( size_t k=0UL; k<ksize; ++k ) {
2593 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2594 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2595 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2596 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2597 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2598 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2599 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2600 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2601 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
2602 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
2603 }
2604 }
2605
2606 if( j<jblock ) {
2607 for( size_t k=0UL; k<ksize; ++k ) {
2608 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2609 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2610 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2611 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2612 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
2613 }
2614 }
2615 }
2616 }
2617 else
2618 {
2619 for( ; (i+4UL) <= isize; i+=4UL )
2620 {
2621 if( jj+jblock < ibegin ) continue;
2622
2623 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2624
2625 for( ; (j+2UL) <= jblock; j+=2UL ) {
2626 for( size_t k=0UL; k<ksize; ++k ) {
2627 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2628 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2629 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2630 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2631 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
2632 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
2633 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
2634 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
2635 }
2636 }
2637
2638 if( j<jblock ) {
2639 for( size_t k=0UL; k<ksize; ++k ) {
2640 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2641 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2642 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
2643 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
2644 }
2645 }
2646 }
2647 }
2648
2649 for( ; (i+2UL) <= isize; i+=2UL )
2650 {
2651 if( jj+jblock < ibegin ) continue;
2652
2653 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2654
2655 for( ; (j+2UL) <= jblock; j+=2UL ) {
2656 for( size_t k=0UL; k<ksize; ++k ) {
2657 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
2658 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2659 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2660 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2661 }
2662 }
2663
2664 if( j<jblock ) {
2665 for( size_t k=0UL; k<ksize; ++k ) {
2666 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
2667 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2668 }
2669 }
2670 }
2671
2672 if( i<isize && jj+jblock >= ibegin )
2673 {
2674 size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );
2675
2676 for( ; (j+2UL) <= jblock; j+=2UL ) {
2677 for( size_t k=0UL; k<ksize; ++k ) {
2678 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
2679 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2680 }
2681 }
2682
2683 if( j<jblock ) {
2684 for( size_t k=0UL; k<ksize; ++k ) {
2685 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
2686 }
2687 }
2688 }
2689
2690 jj += jblock;
2691 }
2692 }
2693}
2695//*************************************************************************************************
2696
2697
2698//*************************************************************************************************
//*************************************************************************************************
/*!\brief Blocked kernel that accumulates the upper part of a scaled dense matrix product into a
//        column-major target matrix.
//
// \param C The column-major target matrix.
// \param A The left-hand side operand; copied block-wise into the packing buffer A2.
// \param B The right-hand side operand; copied block-wise into the packing buffer B2.
// \param alpha Scalar multiplied onto every accumulated dot product before it is added to C.
// \param beta Scalar applied to C up front: C is reset if beta is in default state, left as it
//        is if beta is one, and multiplied by beta otherwise.
// \return void
//
// The K dimension (A.columns()) is walked in panels of at most KBLOCK values and the rows of A in
// panels of at most IBLOCK rows. For every group of 5/4/2 columns only the rows up to and
// including the last column of that group are updated; rows further below the diagonal are
// skipped. If one of the operands is unpadded, the vectorized loops only cover whole SIMD vectors
// of the K dimension and a scalar tail at the end of the function handles the rest.
*/
2717template< typename MT1, typename MT2, typename MT3, typename ST >
2718void ummm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
2719{
2720 using ET1 = ElementType_t<MT1>;
2721 using ET2 = ElementType_t<MT2>;
2722 using ET3 = ElementType_t<MT3>;
2723 using SIMDType = SIMDTrait_t<ET1>;
2724
2733
2736
2739
2742
2743 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
2744
// With an unpadded operand the K dimension cannot be over-read, so the SIMD loops stop at the
// previous multiple of SIMDSIZE and the scalar tail at the end of the function takes over.
2745 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
2746
2747 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
2748 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2749
// Both block sizes must be whole multiples of the SIMD width
2750 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
2751 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
2752
2753 const size_t M( A.rows() );
2754 const size_t N( B.columns() );
2755 const size_t K( A.columns() );
2756
2757 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
2758
// Packing buffers: A2 is row-major and B2 is column-major, so A2.load(i,k) and B2.load(k,j)
// both read consecutive values along the K dimension.
2759 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2760 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2761
2762 decltype(auto) c( derestrict( *C ) );
2763
// Apply beta to the target before anything is accumulated
2764 if( isDefault( beta ) ) {
2765 reset( c );
2766 }
2767 else if( !isOne( beta ) ) {
2768 c *= beta;
2769 }
2770
2771 size_t kk( 0UL );
2772 size_t kblock( 0UL );
2773
// Vectorized part: with 'remainder' set the loop ends as soon as fewer than SIMDSIZE values of
// the K dimension are left.
2774 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
2775 {
2776 if( remainder ) {
2777 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
2778 }
2779 else {
2780 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
2781 }
2782
// Column range of B packed for rows [kk,kk+kblock): it starts at kk for an upper B and ends at
// kk+kblock for a lower B.
2783 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2784 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
2785 const size_t jsize ( jend - jbegin );
2786
2787 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
2788
2789 size_t ii( 0UL );
2790 size_t iblock( 0UL );
2791
2792 while( ii < M )
2793 {
2794 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2795
// Skip row blocks of A that lie entirely above (lower A) or entirely below (upper A) the
// diagonal within the column range [kk,kk+kblock).
2796 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
2797 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
2798 ii += iblock;
2799 continue;
2800 }
2801
2802 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
2803
// j indexes columns of B2, i.e. it is relative to jbegin
2804 size_t j( 0UL );
2805
// Floating point elements are processed five columns at a time, all other element types four.
// NOTE(review): this branch tests ET3 while the scalar tail below tests ET1 - presumably the
// element types are constrained to be identical; confirm.
2806 if( IsFloatingPoint_v<ET3> )
2807 {
2808 for( ; (j+5UL) <= jsize; j+=5UL )
2809 {
// The whole row block starts below the last column of this group: no upper elements to update
2810 if( ii > jbegin+j+4UL ) continue;
2811
// Only rows up to the last column of the group (global row jbegin+j+4) are processed
2812 const size_t iend( min( iblock, jbegin+j-ii+5UL ) );
2813 size_t i( 0UL );
2814
2815 for( ; (i+2UL) <= iend; i+=2UL )
2816 {
// 2x5 micro-kernel. NOTE(review): the accumulators are only default-constructed - this assumes
// a default-constructed SIMDType holds zeros; confirm against the SIMD wrapper.
2817 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2818
2819 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2820 {
2821 const SIMDType a1( A2.load(i ,k) );
2822 const SIMDType a2( A2.load(i+1UL,k) );
2823
2824 const SIMDType b1( B2.load(k,j ) );
2825 const SIMDType b2( B2.load(k,j+1UL) );
2826 const SIMDType b3( B2.load(k,j+2UL) );
2827 const SIMDType b4( B2.load(k,j+3UL) );
2828 const SIMDType b5( B2.load(k,j+4UL) );
2829
2830 xmm1 += a1 * b1;
2831 xmm2 += a1 * b2;
2832 xmm3 += a1 * b3;
2833 xmm4 += a1 * b4;
2834 xmm5 += a1 * b5;
2835 xmm6 += a2 * b1;
2836 xmm7 += a2 * b2;
2837 xmm8 += a2 * b3;
2838 xmm9 += a2 * b4;
2839 xmm10 += a2 * b5;
2840 }
2841
// Reduce each accumulator, scale by alpha and add it to the target element
2842 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2843 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2844 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2845 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2846 c(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
2847 c(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
2850 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
2851 c(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
2852 }
2853
// Single leftover row of the 5-column group
2854 if( i<iend )
2855 {
2856 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2857
2858 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2859 {
2860 const SIMDType a1( A2.load(i,k) );
2861
2862 xmm1 += a1 * B2.load(k,j );
2863 xmm2 += a1 * B2.load(k,j+1UL);
2864 xmm3 += a1 * B2.load(k,j+2UL);
2865 xmm4 += a1 * B2.load(k,j+3UL);
2866 xmm5 += a1 * B2.load(k,j+4UL);
2867 }
2868
2869 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
2870 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2871 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2872 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2873 c(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
2874 }
2875 }
2876 }
2877 else
2878 {
// Same scheme as above with groups of four columns
2879 for( ; (j+4UL) <= jsize; j+=4UL )
2880 {
2881 if( ii > jbegin+j+3UL ) continue;
2882
2883 const size_t iend( min( iblock, jbegin+j-ii+4UL ) );
2884 size_t i( 0UL );
2885
2886 for( ; (i+2UL) <= iend; i+=2UL )
2887 {
2888 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2889
2890 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2891 {
2892 const SIMDType a1( A2.load(i ,k) );
2893 const SIMDType a2( A2.load(i+1UL,k) );
2894
2895 const SIMDType b1( B2.load(k,j ) );
2896 const SIMDType b2( B2.load(k,j+1UL) );
2897 const SIMDType b3( B2.load(k,j+2UL) );
2898 const SIMDType b4( B2.load(k,j+3UL) );
2899
2900 xmm1 += a1 * b1;
2901 xmm2 += a1 * b2;
2902 xmm3 += a1 * b3;
2903 xmm4 += a1 * b4;
2904 xmm5 += a2 * b1;
2905 xmm6 += a2 * b2;
2906 xmm7 += a2 * b3;
2907 xmm8 += a2 * b4;
2908 }
2909
2910 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2911 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2912 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2913 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2914 c(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
2916 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
2917 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
2918 }
2919
2920 if( i<iend )
2921 {
2922 SIMDType xmm1, xmm2, xmm3, xmm4;
2923
2924 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2925 {
2926 const SIMDType a1( A2.load(i,k) );
2927
2928 xmm1 += a1 * B2.load(k,j );
2929 xmm2 += a1 * B2.load(k,j+1UL);
2930 xmm3 += a1 * B2.load(k,j+2UL);
2931 xmm4 += a1 * B2.load(k,j+3UL);
2932 }
2933
2934 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
2935 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2936 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2937 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2938 }
2939 }
2940 }
2941
// Columns left over by the 5/4-column loop, two at a time; rows are unrolled 4, then 2, then 1
2942 for( ; (j+2UL) <= jsize; j+=2UL )
2943 {
2944 if( ii > jbegin+j+1UL ) continue;
2945
2946 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
2947 size_t i( 0UL );
2948
2949 for( ; (i+4UL) <= iend; i+=4UL )
2950 {
2951 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2952
2953 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2954 {
2955 const SIMDType a1( A2.load(i ,k) );
2956 const SIMDType a2( A2.load(i+1UL,k) );
2957 const SIMDType a3( A2.load(i+2UL,k) );
2958 const SIMDType a4( A2.load(i+3UL,k) );
2959
2960 const SIMDType b1( B2.load(k,j ) );
2961 const SIMDType b2( B2.load(k,j+1UL) );
2962
2963 xmm1 += a1 * b1;
2964 xmm2 += a1 * b2;
2965 xmm3 += a2 * b1;
2966 xmm4 += a2 * b2;
2967 xmm5 += a3 * b1;
2968 xmm6 += a3 * b2;
2969 xmm7 += a4 * b1;
2970 xmm8 += a4 * b2;
2971 }
2972
2973 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2974 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2975 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
2976 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
2977 c(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
2978 c(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
2979 c(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
2980 c(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
2981 }
2982
2983 for( ; (i+2UL) <= iend; i+=2UL )
2984 {
2985 SIMDType xmm1, xmm2, xmm3, xmm4;
2986
2987 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2988 {
2989 const SIMDType a1( A2.load(i ,k) );
2990 const SIMDType a2( A2.load(i+1UL,k) );
2991
2992 const SIMDType b1( B2.load(k,j ) );
2993 const SIMDType b2( B2.load(k,j+1UL) );
2994
2995 xmm1 += a1 * b1;
2996 xmm2 += a1 * b2;
2997 xmm3 += a2 * b1;
2998 xmm4 += a2 * b2;
2999 }
3000
3001 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
3002 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
3003 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
3004 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
3005 }
3006
3007 if( i<iend )
3008 {
3009 SIMDType xmm1, xmm2;
3010
3011 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3012 {
3013 const SIMDType a1( A2.load(i,k) );
3014
3015 xmm1 += a1 * B2.load(k,j );
3016 xmm2 += a1 * B2.load(k,j+1UL);
3017 }
3018
3019 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
3020 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
3021 }
3022 }
3023
// Final single column. NOTE(review): the row bound reuses the "+2UL" of the two-column case, so
// the row directly below the diagonal is updated as well whenever it lies inside the row block.
3024 if( j<jsize && ii <= jbegin+j )
3025 {
3026 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3027 size_t i( 0UL );
3028
3029 for( ; (i+2UL) <= iend; i+=2UL )
3030 {
3031 SIMDType xmm1, xmm2;
3032
3033 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3034 {
3035 const SIMDType b1( B2.load(k,j) );
3036
3037 xmm1 += A2.load(i ,k) * b1;
3038 xmm2 += A2.load(i+1UL,k) * b1;
3039 }
3040
3041 c(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
3042 c(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
3043 }
3044
3045 if( i<iend )
3046 {
3047 SIMDType xmm1;
3048
3049 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3050 {
3051 xmm1 += A2.load(i,k) * B2.load(k,j);
3052 }
3053
3054 c(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
3055 }
3056 }
3057
3058 ii += iblock;
3059 }
3060
3061 kk += kblock;
3062 }
3063
// Scalar tail: the last K-kk (fewer than SIMDSIZE) values of the K dimension that the vectorized
// loop above had to leave out for unpadded operands.
3064 if( remainder && kk < K )
3065 {
3066 const size_t ksize( K - kk );
3067
3068 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
3069 const size_t jsize ( N - jbegin );
3070
3071 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
3072
3073 size_t ii( 0UL );
3074 size_t iblock( 0UL );
3075
3076 while( ii < M )
3077 {
3078 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
3079
// Row blocks of a lower A that end at or before row kk are skipped
3080 if( IsLower_v<MT2> && ii+iblock <= kk ) {
3081 ii += iblock;
3082 continue;
3083 }
3084
3085 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
3086
3087 size_t j( 0UL );
3088
// Same column grouping (5 or 4, then 2, then 1) and the same row bounds as in the vectorized
// part, but with element-wise multiply-add over the remaining k values.
3089 if( IsFloatingPoint_v<ET1> )
3090 {
3091 for( ; (j+5UL) <= jsize; j+=5UL )
3092 {
3093 if( ii > jbegin+j+4UL ) continue;
3094
3095 const size_t iend( min( iblock, jbegin+j-ii+5UL ) );
3096 size_t i( 0UL );
3097
3098 for( ; (i+2UL) <= iend; i+=2UL ) {
3099 for( size_t k=0UL; k<ksize; ++k ) {
3100 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3101 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3102 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3103 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3104 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3105 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3106 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3108 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3109 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3110 }
3111 }
3112
3113 if( i<iend ) {
3114 for( size_t k=0UL; k<ksize; ++k ) {
3115 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3116 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3117 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3118 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3119 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
3120 }
3121 }
3122 }
3123 }
3124 else
3125 {
3126 for( ; (j+4UL) <= jsize; j+=4UL )
3127 {
3128 if( ii > jbegin+j+3UL ) continue;
3129
3130 const size_t iend( min( iblock, jbegin+j-ii+4UL ) );
3131 size_t i( 0UL );
3132
3133 for( ; (i+2UL) <= iend; i+=2UL ) {
3134 for( size_t k=0UL; k<ksize; ++k ) {
3135 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3136 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3137 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3138 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3139 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3140 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3141 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3142 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3143 }
3144 }
3145
3146 if( i<iend ) {
3147 for( size_t k=0UL; k<ksize; ++k ) {
3148 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3149 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3150 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3151 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3152 }
3153 }
3154 }
3155 }
3156
3157 for( ; (j+2UL) <= jsize; j+=2UL )
3158 {
3159 if( ii > jbegin+j+1UL ) continue;
3160
3161 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3162 size_t i( 0UL );
3163
3164 for( ; (i+2UL) <= iend; i+=2UL ) {
3165 for( size_t k=0UL; k<ksize; ++k ) {
3166 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3167 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3168 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3169 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3170 }
3171 }
3172
3173 if( i<iend ) {
3174 for( size_t k=0UL; k<ksize; ++k ) {
3175 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3176 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3177 }
3178 }
3179 }
3180
3181 if( j<jsize && ii <= jbegin+j )
3182 {
3183 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3184 size_t i( 0UL );
3185
3186 for( ; (i+2UL) <= iend; i+=2UL ) {
3187 for( size_t k=0UL; k<ksize; ++k ) {
3188 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3189 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3190 }
3191 }
3192
3193 if( i<iend ) {
3194 for( size_t k=0UL; k<ksize; ++k ) {
3195 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
3196 }
3197 }
3198 }
3199
3200 ii += iblock;
3201 }
3202 }
3203}
3205//*************************************************************************************************
3206
3207
3208//*************************************************************************************************
3224template< typename MT1, typename MT2, typename MT3 >
3225inline void ummm( MT1& C, const MT2& A, const MT3& B )
3226{
3227 using ET1 = ElementType_t<MT1>;
3228 using ET2 = ElementType_t<MT2>;
3229 using ET3 = ElementType_t<MT3>;
3230
3233
3234 ummm( C, A, B, ET1(1), ET1(0) );
3235}
3237//*************************************************************************************************
3238
3239
3240
3241
3242//=================================================================================================
3243//
3244// SYMMETRIC DENSE MATRIX MULTIPLICATION KERNELS
3245//
3246//=================================================================================================
3247
3248//*************************************************************************************************
3266template< typename MT1, typename MT2, typename MT3, typename ST >
3267void smmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha )
3268{
3269 using ET1 = ElementType_t<MT1>;
3270 using ET2 = ElementType_t<MT2>;
3271 using ET3 = ElementType_t<MT3>;
3272
3277
3280
3283
3286
3287 const size_t M( A.rows() );
3288 const size_t N( B.columns() );
3289
3290 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3291
3292 lmmm( C, A, B, alpha, ST(0) );
3293
3294 for( size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3295 {
3296 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3297
3298 for( size_t i=ii; i<iend; ++i ) {
3299 for( size_t j=i+1UL; j<iend; ++j ) {
3300 (*C)(i,j) = (*C)(j,i);
3301 }
3302 }
3303
3304 for( size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3305 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3306 for( size_t i=ii; i<iend; ++i ) {
3307 for( size_t j=jj; j<jend; ++j ) {
3308 (*C)(i,j) = (*C)(j,i);
3309 }
3310 }
3311 }
3312 }
3313}
3315//*************************************************************************************************
3316
3317
3318//*************************************************************************************************
3336template< typename MT1, typename MT2, typename MT3, typename ST >
3337void smmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha )
3338{
3339 using ET1 = ElementType_t<MT1>;
3340 using ET2 = ElementType_t<MT2>;
3341 using ET3 = ElementType_t<MT3>;
3342
3347
3350
3353
3356
3357 const size_t M( A.rows() );
3358 const size_t N( B.columns() );
3359
3360 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3361
3362 ummm( C, A, B, alpha, ST(0) );
3363
3364 for( size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3365 {
3366 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3367
3368 for( size_t j=jj; j<jend; ++j ) {
3369 for( size_t i=jj+1UL; i<jend; ++i ) {
3370 (*C)(i,j) = (*C)(j,i);
3371 }
3372 }
3373
3374 for( size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3375 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3376 for( size_t j=jj; j<jend; ++j ) {
3377 for( size_t i=ii; i<iend; ++i ) {
3378 (*C)(i,j) = (*C)(j,i);
3379 }
3380 }
3381 }
3382 }
3383}
3385//*************************************************************************************************
3386
3387
3388//*************************************************************************************************
3404template< typename MT1, typename MT2, typename MT3 >
3405inline void smmm( MT1& C, const MT2& A, const MT3& B )
3406{
3407 using ET1 = ElementType_t<MT1>;
3408 using ET2 = ElementType_t<MT2>;
3409 using ET3 = ElementType_t<MT3>;
3410
3413
3414 smmm( C, A, B, ET1(1) );
3415}
3417//*************************************************************************************************
3418
3419
3420
3421
3422//=================================================================================================
3423//
3424// HERMITIAN DENSE MATRIX MULTIPLICATION KERNELS
3425//
3426//=================================================================================================
3427
3428//*************************************************************************************************
3446template< typename MT1, typename MT2, typename MT3, typename ST >
3447void hmmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha )
3448{
3449 using ET1 = ElementType_t<MT1>;
3450 using ET2 = ElementType_t<MT2>;
3451 using ET3 = ElementType_t<MT3>;
3452
3457
3460
3463
3466
3467 const size_t M( A.rows() );
3468 const size_t N( B.columns() );
3469
3470 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3471
3472 lmmm( C, A, B, alpha, ST(0) );
3473
3474 for( size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3475 {
3476 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3477
3478 for( size_t i=ii; i<iend; ++i ) {
3479 for( size_t j=i+1UL; j<iend; ++j ) {
3480 (*C)(i,j) = conj( (*C)(j,i) );
3481 }
3482 }
3483
3484 for( size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3485 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3486 for( size_t i=ii; i<iend; ++i ) {
3487 for( size_t j=jj; j<jend; ++j ) {
3488 (*C)(i,j) = conj( (*C)(j,i) );
3489 }
3490 }
3491 }
3492 }
3493}
3495//*************************************************************************************************
3496
3497
3498//*************************************************************************************************
3516template< typename MT1, typename MT2, typename MT3, typename ST >
3517void hmmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha )
3518{
3519 using ET1 = ElementType_t<MT1>;
3520 using ET2 = ElementType_t<MT2>;
3521 using ET3 = ElementType_t<MT3>;
3522
3527
3530
3533
3536
3537 const size_t M( A.rows() );
3538 const size_t N( B.columns() );
3539
3540 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3541
3542 ummm( C, A, B, alpha, ST(0) );
3543
3544 for( size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3545 {
3546 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3547
3548 for( size_t j=jj; j<jend; ++j ) {
3549 for( size_t i=jj+1UL; i<jend; ++i ) {
3550 (*C)(i,j) = conj( (*C)(j,i) );
3551 }
3552 }
3553
3554 for( size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3555 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3556 for( size_t j=jj; j<jend; ++j ) {
3557 for( size_t i=ii; i<iend; ++i ) {
3558 (*C)(i,j) = conj( (*C)(j,i) );
3559 }
3560 }
3561 }
3562 }
3563}
3565//*************************************************************************************************
3566
3567
3568//*************************************************************************************************
3584template< typename MT1, typename MT2, typename MT3 >
3585inline void hmmm( MT1& C, const MT2& A, const MT3& B )
3586{
3587 using ET1 = ElementType_t<MT1>;
3588 using ET2 = ElementType_t<MT2>;
3589 using ET3 = ElementType_t<MT3>;
3590
3593
3594 hmmm( C, A, B, ET1(1) );
3595}
3597//*************************************************************************************************
3598
3599} // namespace blaze
3600
3601#endif
Constraint on the data type.
Header file for auxiliary alias declarations.
Header file for run time assertion macros.
Header file for kernel specific block sizes.
Header file for the blaze::checked and blaze::unchecked instances.
Constraints on the storage order of matrix types.
Constraint on the data type.
Header file for the isDefault shim.
Header file for the IsFloatingPoint type trait.
Header file for the IsLower type trait.
Header file for the isOne shim.
Header file for the IsPadded type trait.
Header file for the IsUpper type trait.
Constraint on the data type.
Header file for the prevMultiple shim.
Constraints on the storage order of matrix types.
Constraint on the data type.
Header file for all SIMD functionality.
Compile time assertion.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Constraint on the data type.
Header file for the implementation of a dynamic MxN matrix.
Header file for the DenseMatrix base class.
decltype(auto) min(const DenseMatrix< MT1, SO1 > &lhs, const DenseMatrix< MT2, SO2 > &rhs)
Computes the componentwise minimum of the dense matrices lhs and rhs.
Definition: DMatDMatMapExpr.h:1339
decltype(auto) conj(const DenseMatrix< MT, SO > &dm)
Returns a matrix containing the complex conjugate of each single element of dm.
Definition: DMatMapExpr.h:1464
decltype(auto) serial(const DenseMatrix< MT, SO > &dm)
Forces the serial evaluation of the given dense matrix expression dm.
Definition: DMatSerialExpr.h:812
decltype(auto) sum(const DenseMatrix< MT, SO > &dm)
Reduces the given dense matrix by means of addition.
Definition: DMatReduceExpr.h:2156
bool isDefault(const DiagonalMatrix< MT, SO, DF > &m)
Returns whether the given diagonal matrix is in default state.
Definition: DiagonalMatrix.h:169
#define BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Symmetric.h:79
#define BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE(T)
Constraint on the data type.
Definition: RowMajorMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Hermitian.h:79
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Upper.h:81
#define BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE(T)
Constraint on the data type.
Definition: DenseMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE(T)
Constraint on the data type.
Definition: Computation.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: UniUpper.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE(T)
Constraint on the data type.
Definition: Adaptor.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: Lower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: UniLower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: StrictlyLower.h:81
#define BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE(T)
Constraint on the data type.
Definition: StrictlyUpper.h:81
#define BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE(T)
Constraint on the data type.
Definition: ColumnMajorMatrix.h:61
#define BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES(T1, T2)
Constraint on the data type.
Definition: SIMDCombinable.h:61
BLAZE_ALWAYS_INLINE constexpr auto prevMultiple(T1 value, T2 factor) noexcept
Rounds down an integral value to the previous multiple of a given factor.
Definition: PrevMultiple.h:68
bool isOne(const Proxy< PT, RT > &proxy)
Returns whether the represented element is 1.
Definition: Proxy.h:2337
constexpr void reset(Matrix< MT, SO > &matrix)
Resetting the given matrix.
Definition: Matrix.h:806
constexpr size_t size(const Matrix< MT, SO > &matrix) noexcept
Returns the total number of elements of the matrix.
Definition: Matrix.h:676
#define BLAZE_INTERNAL_ASSERT(expr, msg)
Run time assertion macro for internal checks.
Definition: Assert.h:101
#define BLAZE_STATIC_ASSERT(expr)
Compile time assertion macro.
Definition: StaticAssert.h:112
decltype(auto) submatrix(Matrix< MT, SO > &, RSAs...)
Creating a view on a specific submatrix of the given matrix.
Definition: Submatrix.h:181
constexpr Unchecked unchecked
Global Unchecked instance.
Definition: Check.h:146
Header file for the serial shim.
Header file for basic type definitions.
Header file for the generic min algorithm.
Header file for the implementation of the Submatrix view.