LAPACK++  2022.07.00
LAPACK C++ API
flops.hh
1 // Copyright (c) 2017-2022, University of Tennessee. All rights reserved.
2 // SPDX-License-Identifier: BSD-3-Clause
3 // This program is free software: you can redistribute it and/or modify it under
4 // the terms of the BSD 3-Clause license. See the accompanying LICENSE file.
5 
6 #ifndef LAPACK_FLOPS_HH
7 #define LAPACK_FLOPS_HH
8 
9 #include "lapack.hh"
10 #include "blas/flops.hh"
11 
12 #include <complex>
13 
14 namespace lapack {
15 
16 //==============================================================================
17 // Generic formulas come from LAWN 41
18 // BLAS formulas generally assume alpha == 1 or -1, and beta == 1, -1, or 0;
19 // otherwise add some smaller order term.
20 // Some formulas are wrong when m, n, or k == 0; flops should be 0
21 // (e.g., syr2k, unmqr).
22 // Formulas may give negative results for invalid combinations of m, n, k
23 // (e.g., ungqr, unmqr).
24 
//------------------------------------------------------------ getrf
// LU factorization of an m-by-n matrix. LAWN 41 lists only the m >= n case;
// the m < n case is the same polynomial with the roles of m and n exchanged,
// so both are written here in terms of max(m, n) and min(m, n).

/// Multiplication count for getrf.
inline double fmuls_getrf(double m, double n)
{
    const bool tall = (m >= n);
    const double mx = tall ? m : n;  // max(m, n)
    const double mn = tall ? n : m;  // min(m, n)
    return 0.5*mx*mn*mn - 1./6*mn*mn*mn + 0.5*mx*mn - 0.5*mn*mn + 2/3.*mn;
}

/// Addition count for getrf.
inline double fadds_getrf(double m, double n)
{
    const bool tall = (m >= n);
    const double mx = tall ? m : n;  // max(m, n)
    const double mn = tall ? n : m;  // min(m, n)
    return 0.5*mx*mn*mn - 1./6*mn*mn*mn - 0.5*mx*mn + 1./6*mn;
}
40 
//------------------------------------------------------------ getri
// Inverse of an n-by-n matrix from its LU factors (LAWN 41).

/// Multiplication count for getri.
inline double fmuls_getri(double n)
{
    const double cubic     = 2/3.*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 5./6*n;
    return cubic + quadratic + linear;
}

/// Addition count for getri.
inline double fadds_getri(double n)
{
    const double cubic     = 2/3.*n*n*n;
    const double quadratic = 1.5*n*n;
    const double linear    = 5./6*n;
    return cubic - quadratic + linear;
}
47 
//------------------------------------------------------------ getrs
// Solve with the LU factors of an n-by-n matrix for nrhs right-hand sides.

/// Multiplication count for getrs: nrhs * n^2.
inline double fmuls_getrs(double n, double nrhs)
{
    return nrhs*n*n;
}

/// Addition count for getrs: nrhs * n * (n - 1).
inline double fadds_getrs(double n, double nrhs)
{
    return nrhs*n*(n - 1);
}
54 
//------------------------------------------------------------ potrf
// Cholesky factorization of an n-by-n matrix (LAWN 41).

/// Multiplication count for potrf.
inline double fmuls_potrf(double n)
{
    const double cubic     = 1./6*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 1./3.*n;
    return cubic + quadratic + linear;
}

/// Addition count for potrf.
inline double fadds_potrf(double n)
{
    const double cubic  = 1./6*n*n*n;
    const double linear = 1./6*n;
    return cubic - linear;
}
61 
//------------------------------------------------------------ potri
// Inverse of an n-by-n matrix from its Cholesky factor (LAWN 41).

/// Multiplication count for potri.
inline double fmuls_potri(double n)
{
    const double cubic     = 1./3.*n*n*n;
    const double quadratic = n*n;
    const double linear    = 2/3.*n;
    return cubic + quadratic + linear;
}

/// Addition count for potri.
inline double fadds_potri(double n)
{
    const double cubic     = 1./3.*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 1./6*n;
    return cubic - quadratic + linear;
}
68 
//------------------------------------------------------------ potrs
// Solve with the Cholesky factor of an n-by-n matrix for nrhs right-hand sides.

/// Multiplication count for potrs: nrhs * n * (n + 1).
inline double fmuls_potrs(double n, double nrhs)
{
    return nrhs*n*(n + 1);
}

/// Addition count for potrs: nrhs * n * (n - 1).
inline double fadds_potrs(double n, double nrhs)
{
    return nrhs*n*(n - 1);
}
75 
//------------------------------------------------------------ pbtrf
// Band Cholesky factorization of an n-by-n matrix with k off-diagonals
// (kd in LAPACK). LAPACK accepts kd >= n, but at most n - 1 off-diagonals
// exist, so k is clamped to n - 1. With k = n - 1 these formulas reduce
// exactly to fmuls_potrf / fadds_potrf. Without the clamp, k >= n gave
// meaningless counts, e.g. fmuls_pbtrf(2, 5) == -28.

/// Multiplication count for pbtrf; 0 for an empty matrix.
inline double fmuls_pbtrf(double n, double k)
{
    if (n <= 0)
        return 0;
    if (k > n - 1)
        k = n - 1;  // band cannot be wider than the matrix
    return n*(1./2.*k*k + 3./2.*k + 1) - 1./3.*k*k*k - k*k - 2./3.*k;
}

/// Addition count for pbtrf; 0 for an empty matrix.
inline double fadds_pbtrf(double n, double k)
{
    if (n <= 0)
        return 0;
    if (k > n - 1)
        k = n - 1;  // band cannot be wider than the matrix
    return n*(1./2.*k*k + 1./2.*k) - 1./3.*k*k*k - 1./2.*k*k - 1./6.*k;
}
82 
//------------------------------------------------------------ pbtrs
// Solve with the band Cholesky factor of an n-by-n matrix with k
// off-diagonals, for nrhs right-hand sides. LAPACK accepts kd >= n, so k is
// clamped to n - 1; with k = n - 1 the counts equal fmuls_potrs /
// fadds_potrs. Without the clamp, k >= n gave negative counts, e.g.
// fmuls_pbtrs(2, 1, 5) == -6.

/// Multiplication count for pbtrs; 0 for an empty matrix.
inline double fmuls_pbtrs(double n, double nrhs, double k)
{
    if (n <= 0)
        return 0;
    if (k > n - 1)
        k = n - 1;  // band cannot be wider than the matrix
    return nrhs*(2*n*k + 2*n - k*k - k);
}

/// Addition count for pbtrs; 0 for an empty matrix.
inline double fadds_pbtrs(double n, double nrhs, double k)
{
    if (n <= 0)
        return 0;
    if (k > n - 1)
        k = n - 1;  // band cannot be wider than the matrix
    return nrhs*(2*n*k - k*k - k);
}
89 
//------------------------------------------------------------ sytrf
// Symmetric indefinite (LDL^T) factorization of an n-by-n matrix.

/// Multiplication count for sytrf.
inline double fmuls_sytrf(double n)
{
    const double cubic     = 1/6.*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 10/3.*n;
    return cubic + quadratic + linear;
}

/// Addition count for sytrf.
inline double fadds_sytrf(double n)
{
    const double cubic  = 1/6.*n*n*n;
    const double linear = 1/6.*n;
    return cubic - linear;
}
96 
//------------------------------------------------------------ sytri
// Inverse of an n-by-n symmetric matrix from its LDL^T factors.

/// Multiplication count for sytri.
inline double fmuls_sytri(double n)
{
    const double cubic     = 1/3.*n*n*n;
    const double quadratic = n*n;
    const double linear    = 2/3.*n;
    return cubic + quadratic + linear;
}

/// Addition count for sytri.
inline double fadds_sytri(double n)
{
    const double cubic  = 1/3.*n*n*n;
    const double linear = 1/3.*n;
    return cubic - linear;
}
103 
//------------------------------------------------------------ sytrs
// Solve with the LDL^T factors of an n-by-n matrix for nrhs right-hand sides.

/// Multiplication count for sytrs: nrhs * n * (n + 1).
inline double fmuls_sytrs(double n, double nrhs)
{
    return nrhs*n*(n + 1);
}

/// Addition count for sytrs: nrhs * n * (n - 1).
inline double fadds_sytrs(double n, double nrhs)
{
    return nrhs*n*(n - 1);
}
110 
//------------------------------------------------------------ geqrf
// QR factorization of an m-by-n matrix (LAWN 41). The tall case (m > n)
// and the square/wide case (m <= n) have different lower-order terms.

/// Multiplication count for geqrf.
inline double fmuls_geqrf(double m, double n)
{
    if (m > n) {
        return m*n*n - 1./3.*n*n*n + m*n + 0.5*n*n + 23./6*n;
    }
    return n*m*m - 1./3.*m*m*m + 2*n*m - 0.5*m*m + 23./6*m;
}

/// Addition count for geqrf.
inline double fadds_geqrf(double m, double n)
{
    if (m > n) {
        return m*n*n - 1./3.*n*n*n + 0.5*n*n + 5./6*n;
    }
    return n*m*m - 1./3.*m*m*m + n*m - 0.5*m*m + 5./6*m;
}
125 
//------------------------------------------------------------ geqrt
// TODO: these counts are O(m*n), whereas fmuls_geqrf / fadds_geqrf are
// O(m*n^2). Should geqrt match geqrf (plus the cost of forming T)?

/// Multiplication count currently used for geqrt: m*n/2 (see TODO above).
inline double fmuls_geqrt(double m, double n)
{
    return 0.5*m*n;
}

/// Addition count currently used for geqrt: m*n/2 (see TODO above).
inline double fadds_geqrt(double m, double n)
{
    return 0.5*m*n;
}
133 
134 //------------------------------------------------------------ geqlf
135 inline double fmuls_geqlf(double m, double n)
136  { return fmuls_geqrf(m, n); }
137 
138 inline double fadds_geqlf(double m, double n)
139  { return fadds_geqrf(m, n); }
140 
//------------------------------------------------------------ gerqf
// RQ factorization of an m-by-n matrix (LAWN 41). The tall case (m > n)
// and the square/wide case (m <= n) have different lower-order terms.

/// Multiplication count for gerqf.
inline double fmuls_gerqf(double m, double n)
{
    if (m > n) {
        return m*n*n - 1./3.*n*n*n + m*n + 0.5*n*n + 29./6*n;
    }
    return n*m*m - 1./3.*m*m*m + 2*n*m - 0.5*m*m + 29./6*m;
}

/// Addition count for gerqf.
inline double fadds_gerqf(double m, double n)
{
    if (m > n) {
        return m*n*n - 1./3.*n*n*n + m*n - 0.5*n*n + 5./6*n;
    }
    return n*m*m - 1./3.*m*m*m + 0.5*m*m + 5./6*m;
}
155 
156 //------------------------------------------------------------ gelqf
157 inline double fmuls_gelqf(double m, double n)
158  { return fmuls_gerqf(m, n); }
159 
160 inline double fadds_gelqf(double m, double n)
161  { return fadds_gerqf(m, n); }
162 
//------------------------------------------------------------ ungqr
// Generate the m-by-n matrix Q from k Householder reflectors of a QR
// factorization. Invalid (m, n, k) combinations can yield negative counts.

/// Multiplication count for ungqr.
inline double fmuls_ungqr(double m, double n, double k)
{
    // leading part shared with fadds_ungqr
    const double leading = 2*m*n*k - (m + n)*k*k + 2/3.*k*k*k;
    return leading + 2*n*k - k*k - 5./3.*k;
}

/// Addition count for ungqr.
inline double fadds_ungqr(double m, double n, double k)
{
    // leading part shared with fmuls_ungqr
    const double leading = 2*m*n*k - (m + n)*k*k + 2/3.*k*k*k;
    return leading + n*k - m*k + 1./3.*k;
}
169 
170 //------------------------------------------------------------ ungql
171 inline double fmuls_ungql(double m, double n, double k)
172  { return fmuls_ungqr(m, n, k); }
173 
174 inline double fadds_ungql(double m, double n, double k)
175  { return fadds_ungqr(m, n, k); }
176 
//------------------------------------------------------------ ungrq
// Generate the m-by-n matrix Q from k Householder reflectors of an RQ
// factorization. Invalid (m, n, k) combinations can yield negative counts.

/// Multiplication count for ungrq.
inline double fmuls_ungrq(double m, double n, double k)
{
    // leading part shared with fadds_ungrq
    const double leading = 2*m*n*k - (m + n)*k*k + 2/3.*k*k*k;
    return leading + m*k + n*k - k*k - 2/3.*k;
}

/// Addition count for ungrq.
inline double fadds_ungrq(double m, double n, double k)
{
    // leading part shared with fmuls_ungrq
    const double leading = 2*m*n*k - (m + n)*k*k + 2/3.*k*k*k;
    return leading + m*k - n*k + 1./3.*k;
}
183 
184 //------------------------------------------------------------ unglq
185 inline double fmuls_unglq(double m, double n, double k)
186  { return fmuls_ungrq(m, n, k); }
187 
188 inline double fadds_unglq(double m, double n, double k)
189  { return fadds_ungrq(m, n, k); }
190 
191 //------------------------------------------------------------ unmqr
192 inline double fmuls_unmqr(lapack::Side side, double m, double n, double k)
193 {
194  return (side == lapack::Side::Left)
195  ? (2*n*m*k - n*k*k + 2*n*k)
196  : (2*n*m*k - m*k*k + m*k + n*k - 0.5*k*k + 0.5*k);
197 }
198 
199 inline double fadds_unmqr(lapack::Side side, double m, double n, double k)
200 {
201  return (side == lapack::Side::Left)
202  ? (2*n*m*k - n*k*k + n*k)
203  : (2*n*m*k - m*k*k + m*k);
204 }
205 
206 //------------------------------------------------------------ unmql
207 inline double fmuls_unmql(lapack::Side side, double m, double n, double k)
208  { return fmuls_unmqr(side, m, n, k); }
209 
210 inline double fadds_unmql(lapack::Side side, double m, double n, double k)
211  { return fadds_unmqr(side, m, n, k); }
212 
213 //------------------------------------------------------------ unmrq
214 inline double fmuls_unmrq(lapack::Side side, double m, double n, double k)
215  { return fmuls_unmqr(side, m, n, k); }
216 
217 inline double fadds_unmrq(lapack::Side side, double m, double n, double k)
218  { return fadds_unmqr(side, m, n, k); }
219 
220 //------------------------------------------------------------ unmlq
221 inline double fmuls_unmlq(lapack::Side side, double m, double n, double k)
222  { return fmuls_unmqr(side, m, n, k); }
223 
224 inline double fadds_unmlq(lapack::Side side, double m, double n, double k)
225  { return fadds_unmqr(side, m, n, k); }
226 
//------------------------------------------------------------ trtri
// Inverse of an n-by-n triangular matrix (LAWN 41).

/// Multiplication count for trtri.
inline double fmuls_trtri(double n)
{
    const double cubic     = 1./6*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 1./3.*n;
    return cubic + quadratic + linear;
}

/// Addition count for trtri.
inline double fadds_trtri(double n)
{
    const double cubic     = 1./6*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 1./3.*n;
    return cubic - quadratic + linear;
}
233 
//------------------------------------------------------------ gehrd
// Hessenberg reduction of an n-by-n matrix (non-symmetric eigenvalue problem).

/// Multiplication count for gehrd.
inline double fmuls_gehrd(double n)
{
    const double cubic     = 5./3.*n*n*n;
    const double quadratic = 0.5*n*n;
    const double linear    = 7./6*n;
    return cubic + quadratic - linear;
}

/// Addition count for gehrd.
inline double fadds_gehrd(double n)
{
    const double cubic     = 5./3.*n*n*n;
    const double quadratic = n*n;
    const double linear    = 2/3.*n;
    return cubic - quadratic - linear;
}
240 
//------------------------------------------------------------ sytrd
// Tridiagonal reduction of an n-by-n symmetric/Hermitian matrix
// (symmetric eigenvalue problem). hetrd is counted the same as sytrd.

/// Multiplication count for sytrd.
inline double fmuls_sytrd(double n)
{
    const double cubic     = 2/3.*n*n*n;
    const double quadratic = 2.5*n*n;
    const double linear    = 1./6*n;
    return cubic + quadratic - linear;
}

/// Addition count for sytrd.
inline double fadds_sytrd(double n)
{
    const double cubic     = 2/3.*n*n*n;
    const double quadratic = n*n;
    const double linear    = 8./3.*n;
    return cubic + quadratic - linear;
}

/// Multiplication count for hetrd; same as fmuls_sytrd.
inline double fmuls_hetrd(double n)
{
    return fmuls_sytrd(n);
}

/// Addition count for hetrd; same as fadds_sytrd.
inline double fadds_hetrd(double n)
{
    return fadds_sytrd(n);
}
253 
//------------------------------------------------------------ gebrd
// Bidiagonal reduction of an m-by-n matrix (SVD). The m < n case is the
// m >= n polynomial with m and n exchanged, so both are written in terms
// of max(m, n) and min(m, n).

/// Multiplication count for gebrd.
inline double fmuls_gebrd(double m, double n)
{
    const bool tall = (m >= n);
    const double mx = tall ? m : n;  // max(m, n)
    const double mn = tall ? n : m;  // min(m, n)
    return 2*mx*mn*mn - 2/3.*mn*mn*mn + 2*mn*mn + 20./3.*mn;
}

/// Addition count for gebrd.
inline double fadds_gebrd(double m, double n)
{
    const bool tall = (m >= n);
    const double mx = tall ? m : n;  // max(m, n)
    const double mn = tall ? n : m;  // min(m, n)
    return 2*mx*mn*mn - 2/3.*mn*mn*mn + mn*mn - mx*mn + 5./3.*mn;
}
268 
//------------------------------------------------------------ larfg
// Generate a Householder reflector for a length-n vector.

/// Multiplication count for larfg: 2*n.
inline double fmuls_larfg(double n)
{
    return 2*n;
}

/// Addition count for larfg: n.
inline double fadds_larfg(double n)
{
    return n;
}
275 
//------------------------------------------------------------ geadd
// Add two m-by-n matrices. The counts (2 multiplications and 1 addition
// per entry) are consistent with B = alpha*A + beta*B.

/// Multiplication count for geadd: 2*m*n.
inline double fmuls_geadd(double m, double n)
{
    return 2*m*n;
}

/// Addition count for geadd: m*n.
inline double fadds_geadd(double m, double n)
{
    return m*n;
}
282 
283 //------------------------------------------------------------ lauum
284 inline double fmuls_lauum(double n)
285  { return fmuls_potri(n) - fmuls_trtri(n); }
286 
287 inline double fadds_lauum(double n)
288  { return fadds_potri(n) - fadds_trtri(n); }
289 
290 //------------------------------------------------------------ lange
291 inline double fmuls_lange(lapack::Norm norm, double m, double n)
292  { return norm == lapack::Norm::Fro ? m*n : 0; }
293 
294 inline double fadds_lange(lapack::Norm norm, double m, double n)
295 {
296  switch (norm) {
297  case lapack::Norm::One: return (m-1)*n;
298  case lapack::Norm::Inf: return (n-1)*m;
299  case lapack::Norm::Fro: return m*n-1;
300  default: return 0;
301  }
302 }
303 
304 //------------------------------------------------------------ lanhe
305 inline double fmuls_lanhe(lapack::Norm norm, double n)
306  { return norm == lapack::Norm::Fro ? n*(n+1)/2 : 0; }
307 
308 inline double fadds_lanhe(lapack::Norm norm, double n)
309 {
310  switch (norm) {
311  case lapack::Norm::One: return (n-1)*n;
312  case lapack::Norm::Inf: return (n-1)*n;
313  case lapack::Norm::Fro: return n*(n+1)/2-1;
314  default: return 0;
315  }
316 }
317 
318 //==============================================================================
319 // template class. Example:
320 // gbyte< float >::gemv( m, n ) yields bytes transferred for sgemv.
321 // gbyte< std::complex<float> >::gemv( m, n ) yields bytes transferred for cgemv.
322 //==============================================================================
323 template< typename T >
324 class Gbyte:
325  public blas::Gbyte<T>
326 {
327 };
328 
329 //==============================================================================
330 // template class. Example:
331 // gflop< float >::getrf( m, n ) yields flops for sgetrf.
332 // gflop< std::complex<float> >::getrf( m, n ) yields flops for cgetrf.
333 //==============================================================================
334 template< typename T >
335 class Gflop:
336  public blas::Gflop<T>
337 {
338 public:
339  using blas::Gflop<T>::mul_ops;
340  using blas::Gflop<T>::add_ops;
341 
342  // LU
343  static double gesv(double n, double nrhs)
344  { return getrf(n, n) + getrs(n, nrhs); }
345 
346  static double getrf(double m, double n)
347  { return 1e-9 * (mul_ops*fmuls_getrf(m, n) + add_ops*fadds_getrf(m, n)); }
348 
349  static double getri(double n)
350  { return 1e-9 * (mul_ops*fmuls_getri(n) + add_ops*fadds_getri(n)); }
351 
352  static double getrs(double n, double nrhs)
353  { return 1e-9 * (mul_ops*fmuls_getrs(n, nrhs) + add_ops*fadds_getrs(n, nrhs)); }
354 
355  // Cholesky
356  static double posv(double n, double nrhs)
357  { return potrf(n) + potrs(n, nrhs); }
358 
359  static double potrf(double n)
360  { return 1e-9 * (mul_ops*fmuls_potrf(n) + add_ops*fadds_potrf(n)); }
361 
362  static double potri(double n)
363  { return 1e-9 * (mul_ops*fmuls_potri(n) + add_ops*fadds_potri(n)); }
364 
365  static double potrs(double n, double nrhs)
366  { return 1e-9 * (mul_ops*fmuls_potrs(n, nrhs) + add_ops*fadds_potrs(n, nrhs)); }
367 
368  // Band Cholesky
369  static double pbsv(double n, double nrhs, double k)
370  { return pbtrf(n, k) + pbtrs(n, nrhs, k); }
371 
372  static double pbtrf(double n, double k)
373  { return 1e-9 * (mul_ops*fmuls_pbtrf(n, k) + add_ops*fadds_pbtrf(n, k)); }
374 
375  static double pbtrs(double n, double nrhs, double k)
376  { return 1e-9 * (mul_ops*fmuls_pbtrs(n, nrhs, k) + add_ops*fadds_pbtrs(n, nrhs, k)); }
377 
378  // LDL^T
379  static double sysv(double n, double nrhs)
380  { return sytrf(n) + sytrs(n, nrhs); }
381 
382  static double sytrf(double n)
383  { return 1e-9 * (mul_ops*fmuls_sytrf(n) + add_ops*fadds_sytrf(n)); }
384 
385  static double sytri(double n)
386  { return 1e-9 * (mul_ops*fmuls_sytri(n) + add_ops*fadds_sytri(n)); }
387 
388  static double sytrs(double n, double nrhs)
389  { return 1e-9 * (mul_ops*fmuls_sytrs(n, nrhs) + add_ops*fadds_sytrs(n, nrhs)); }
390 
391  static double hesv(double n, double nrhs)
392  { return sysv(n, nrhs); }
393 
394  static double hetrf(double n)
395  { return sytrf(n); }
396 
397  static double hetri(double n)
398  { return sytri(n); }
399 
400  static double hetrs(double n, double nrhs)
401  { return sytrs(n, nrhs); }
402 
403  // QR, QL, RQ, LQ
404  static double geqrf(double m, double n)
405  { return 1e-9 * (mul_ops*fmuls_geqrf(m, n) + add_ops*fadds_geqrf(m, n)); }
406 
407  static double geqrt(double m, double n)
408  { return 1e-9 * (mul_ops*fmuls_geqrt(m, n) + add_ops*fadds_geqrt(m, n)); }
409 
410  static double geqlf(double m, double n)
411  { return 1e-9 * (mul_ops*fmuls_geqlf(m, n) + add_ops*fadds_geqlf(m, n)); }
412 
413  static double gerqf(double m, double n)
414  { return 1e-9 * (mul_ops*fmuls_gerqf(m, n) + add_ops*fadds_gerqf(m, n)); }
415 
416  static double gelqf(double m, double n)
417  { return 1e-9 * (mul_ops*fmuls_gelqf(m, n) + add_ops*fadds_gelqf(m, n)); }
418 
419  // generate Q
420  static double ungqr(double m, double n, double k)
421  { return 1e-9 * (mul_ops*fmuls_ungqr(m, n, k) + add_ops*fadds_ungqr(m, n, k)); }
422 
423  static double orgqr(double m, double n, double k)
424  { return ungqr(m, n, k); }
425 
426  static double ungql(double m, double n, double k)
427  { return 1e-9 * (mul_ops*fmuls_ungql(m, n, k) + add_ops*fadds_ungql(m, n, k)); }
428 
429  static double orgql(double m, double n, double k)
430  { return ungql(m, n, k); }
431 
432  static double ungrq(double m, double n, double k)
433  { return 1e-9 * (mul_ops*fmuls_ungrq(m, n, k) + add_ops*fadds_ungrq(m, n, k)); }
434 
435  static double orgrq(double m, double n, double k)
436  { return ungrq(m, n, k); }
437 
438  static double unglq(double m, double n, double k)
439  { return 1e-9 * (mul_ops*fmuls_unglq(m, n, k) + add_ops*fadds_unglq(m, n, k)); }
440 
441  static double orglq(double m, double n, double k)
442  { return unglq(m, n, k); }
443 
444  // multiply by Q
445  static double unmqr(lapack::Side side, double m, double n, double k)
446  { return 1e-9 * (mul_ops*fmuls_unmqr(side, m, n, k) + add_ops*fadds_unmqr(side, m, n, k)); }
447 
448  static double ormqr(lapack::Side side, double m, double n, double k)
449  { return unmqr(side, m, n, k); }
450 
451  static double unmql(lapack::Side side, double m, double n, double k)
452  { return 1e-9 * (mul_ops*fmuls_unmql(side, m, n, k) + add_ops*fadds_unmql(side, m, n, k)); }
453 
454  static double ormql(lapack::Side side, double m, double n, double k)
455  { return unmql(side, m, n, k); }
456 
457  static double unmrq(lapack::Side side, double m, double n, double k)
458  { return 1e-9 * (mul_ops*fmuls_unmrq(side, m, n, k) + add_ops*fadds_unmrq(side, m, n, k)); }
459 
460  static double ormrq(lapack::Side side, double m, double n, double k)
461  { return unmrq(side, m, n, k); }
462 
463  static double unmlq(lapack::Side side, double m, double n, double k)
464  { return 1e-9 * (mul_ops*fmuls_unmlq(side, m, n, k) + add_ops*fadds_unmlq(side, m, n, k)); }
465 
466  static double ormlq(lapack::Side side, double m, double n, double k)
467  { return unmlq(side, m, n, k); }
468 
469  // least squares
470  static double gels(double m, double n, double nrhs)
471  {
472  blas::Side left = blas::Side::Left;
473  return (m >= n
474  ? geqrf(m, n) + unmqr(left, m, nrhs, n) + blas::Gflop<T>::trsm(left, n, nrhs)
475  : gelqf(m, n) + unmlq(left, n, nrhs, m) + blas::Gflop<T>::trsm(left, m, nrhs));
476  }
477 
478  // triangle inverse
479  static double trtri(double n)
480  { return 1e-9 * (mul_ops*fmuls_trtri(n) + add_ops*fadds_trtri(n)); }
481 
482  // Hessenberg reduction (non-symmetric eigenvalue)
483  static double gehrd(double n)
484  { return 1e-9 * (mul_ops*fmuls_gehrd(n) + add_ops*fadds_gehrd(n)); }
485 
486  // tridiagonal reduction (symmetric eigenvalue)
487  static double hetrd(double n)
488  { return 1e-9 * (mul_ops*fmuls_sytrd(n) + add_ops*fadds_sytrd(n)); }
489 
490  static double sytrd(double n)
491  { return hetrd(n); }
492 
493  // bidiagonal reduction (SVD)
494  static double gebrd(double m, double n)
495  { return 1e-9 * (mul_ops*fmuls_gebrd(m, n) + add_ops*fadds_gebrd(m, n)); }
496 
497  // Householder reflector generate
498  static double larfg(double n)
499  { return 1e-9 * (mul_ops*fmuls_larfg(n) + add_ops*fadds_larfg(n)); }
500 
501  // matrix add
502  static double geadd(double m, double n)
503  { return 1e-9 * (mul_ops*fmuls_geadd(m, n) + add_ops*fadds_geadd(m, n)); }
504 
505  // U^H*U or L*L^T
506  static double lauum(double n)
507  { return 1e-9 * (mul_ops*fmuls_lauum(n) + add_ops*fadds_lauum(n)); }
508 
509  // norm
510  static double lange(lapack::Norm norm, double m, double n)
511  { return 1e-9 * (mul_ops*fmuls_lange(norm, m, n) + add_ops*fadds_lange(norm, m, n)); }
512 
513  static double lanhe(lapack::Norm norm, double n)
514  { return 1e-9 * (mul_ops*fmuls_lanhe(norm, n) + add_ops*fadds_lanhe(norm, n)); }
515 
516  static double lansy(lapack::Norm norm, double n)
517  { return lanhe(norm, n); }
518 };
519 
520 } // namespace lapack
521 
522 #endif // LAPACK_FLOPS_HH