Commits

Anonymous committed 297b5e5

New Strassen-like sequence for multiplication, and squaring.

  • Participants
  • Parent commits 338a4f1

Comments (0)

Files changed (2)

 *
 *    Copyright (C) 2008 Martin Albrecht <M.R.Albrecht@rhul.ac.uk>
 *    Copyright (C) 2008 Clement Pernet <pernet@math.washington.edu>
+*    Copyright (C) 2008 Marco Bodrato <bodrato@mail.dm.unipi.it>
 *
 *  Distributed under the terms of the GNU General Public License (GPL)
 *  version 2 or higher.
   return C;
 }
 
+/**
+ * Compute C = A * B with a Strassen-like recursion, using the
+ * operation sequence referenced in the note inside the body.
+ *
+ * The leading 2*mmm x 2*kkk part of A and the leading 2*kkk x 2*nnn
+ * part of B are split into 2x2 quadrant windows and combined with
+ * seven block products (six recursive calls plus one mzd_mul call);
+ * rows and columns beyond that part are handled afterwards by
+ * _mzd_mul_m4rm and mzd_addmul_m4rm.
+ *
+ * \param C Result matrix; written in place through windows and
+ * returned. It is used as if it were A->nrows x B->ncols (not checked
+ * here).
+ * \param A Left operand.
+ * \param B Right operand; its row count is used as if it equalled
+ * A->ncols (not checked here).
+ * \param cutoff Threshold handed to CLOSER() to decide when to stop
+ * splitting; also used to scale the splitting granularity.
+ *
+ * \return C
+ */
+packedmatrix *_mzd_mul_evenb(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff) {
+  size_t m,k,n;
+  size_t mmm, kkk, nnn; /* quadrant sizes per dimension; doubled later to the covered extent */
+  
+  m = A->nrows;
+  k = A->ncols;
+  n = B->ncols;
+
+  if(C->nrows == 0 || C->ncols == 0) /* empty result: nothing to compute */
+    return C;
+
+  /* handle the case first where the input matrices are too small
+   * already: if CLOSER() holds for any of the three dimensions, a
+   * single M4RM product is done instead of splitting.
+   * NOTE(review): CLOSER is defined elsewhere; presumably it tests
+   * whether a dimension is nearer to cutoff than its half is --
+   * confirm against the macro. */
+  if (CLOSER(m, m/2, cutoff) || CLOSER(k, k/2, cutoff) || CLOSER(n, n/2, cutoff)) {
+    /* we copy the matrix first since it is only constant memory
+       overhead and improves data locality, if you remove it make sure
+       there are no speed regressions */
+    /* C = _mzd_mul_m4rm(C, A, B, 0, TRUE); */
+    packedmatrix *Cbar = mzd_init(m, n);
+    _mzd_mul_m4rm(Cbar, A, B, 0, FALSE);
+    mzd_copy(C, Cbar);
+    mzd_free(Cbar);
+    return C;
+  }
+
+  /* adjust cutting numbers to work on words: mult starts at RADIX and
+   * doubles each time width is halved, until width <= cutoff; every
+   * dimension is rounded down to a multiple of mult and then halved in
+   * units of RADIX, so mmm, kkk and nnn are multiples of RADIX. */
+  {
+    unsigned long mult = RADIX;
+    unsigned long width = MIN(MIN(m,n),k)/2;
+    while (width > cutoff) { /* NOTE(review): unsigned long vs int comparison; a negative cutoff converts to a large unsigned value */
+      width>>=1;
+      mult<<=1;
+    }
+
+    mmm = (((m - m%mult)/RADIX) >> 1) * RADIX;
+    kkk = (((k - k%mult)/RADIX) >> 1) * RADIX;
+    nnn = (((n - n%mult)/RADIX) >> 1) * RADIX;
+  }
+  /*         |A |   |B |   |C |
+   * Compute |  | x |  | = |  |
+   *
+   * i.e. the product of the leading 2*mmm x 2*kkk part of A with the
+   * leading 2*kkk x 2*nnn part of B, stored in the leading
+   * 2*mmm x 2*nnn part of C, via 2x2 quadrant windows.
+   * NOTE(review): the window arguments look like (matrix, first row,
+   * first column, end row, end column) judging from the call sites
+   * -- confirm against mzd_init_window. */
+  {
+    packedmatrix *A11 = mzd_init_window(A,   0,   0,   mmm,   kkk);
+    packedmatrix *A12 = mzd_init_window(A,   0, kkk,   mmm, 2*kkk);
+    packedmatrix *A21 = mzd_init_window(A, mmm,   0, 2*mmm,   kkk);
+    packedmatrix *A22 = mzd_init_window(A, mmm, kkk, 2*mmm, 2*kkk);
+
+    packedmatrix *B11 = mzd_init_window(B,   0,   0,   kkk,   nnn);
+    packedmatrix *B12 = mzd_init_window(B,   0, nnn,   kkk, 2*nnn);
+    packedmatrix *B21 = mzd_init_window(B, kkk,   0, 2*kkk,   nnn);
+    packedmatrix *B22 = mzd_init_window(B, kkk, nnn, 2*kkk, 2*nnn);
+
+    packedmatrix *C11 = mzd_init_window(C,   0,   0,   mmm,   nnn);
+    packedmatrix *C12 = mzd_init_window(C,   0, nnn,   mmm, 2*nnn);
+    packedmatrix *C21 = mzd_init_window(C, mmm,   0, 2*mmm,   nnn);
+    packedmatrix *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn);
+  
+    /**
+     * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
+     * Suited for Squaring and Highest Power Computation";
+     * http://bodrato.it/papres/#CIVV2008 for reference on the used
+     * sequence of operations.
+     *
+     * The quadrants of C double as intermediate storage: each one is
+     * written several times below before it holds its final value, so
+     * the order of the statements matters.
+     *
+     * The per-line comments sometimes write '-' where the code calls
+     * _mzd_add. NOTE(review): presumably addition and subtraction
+     * coincide for these matrices -- confirm.
+     */
+
+    /* change this to mzd_init(mmm, MAX(nnn,kkk)) to fix the todo below */
+    packedmatrix *Wmk = mzd_init(mmm, kkk); /* scratch for sums of A quadrants */
+    packedmatrix *Wkn = mzd_init(kkk, nnn); /* scratch for sums of B quadrants */
+
+    _mzd_add(Wkn, B22, B12);		 /* Wkn = B22 + B12 */
+    _mzd_add(Wmk, A22, A12);		 /* Wmk = A22 + A12 */
+    _mzd_mul_evenb(C21, Wmk, Wkn, cutoff);/* C21 = Wmk * Wkn */
+
+    _mzd_add(Wmk, A22, A21);		 /* Wmk = A22 - A21 */
+    _mzd_add(Wkn, B22, B21);		 /* Wkn = B22 - B21 */
+    _mzd_mul_evenb(C22, Wmk, Wkn, cutoff);/* C22 = Wmk * Wkn */
+
+    _mzd_add(Wkn, Wkn, B12);		 /* Wkn = Wkn + B12 */
+    _mzd_add(Wmk, Wmk, A12);		 /* Wmk = Wmk + A12 */
+    _mzd_mul_evenb(C11, Wmk, Wkn, cutoff);/* C11 = Wmk * Wkn */
+
+    _mzd_add(Wmk, Wmk, A11);		 /* Wmk = Wmk - A11 */
+    _mzd_mul_evenb(C12, Wmk, B12, cutoff);/* C12 = Wmk * B12 */
+    _mzd_add(C12, C12, C22);		 /* C12 = C12 + C22 */
+
+    /**
+     * \todo ideally we would use the same Wmk throughout the function
+     * but some called function doesn't like that and we end up with a
+     * wrong result if we use virtual Wmk matrices. Ideally, this should
+     * be fixed not worked around. To check whether the bug has been
+     * fixed, use only one Wmk and check if mzd_mul(4096, 3528,
+     * 4096, 2124) still returns the correct answer.
+     */
+
+    mzd_free(Wmk); /* the A-side scratch is released; Wmk is re-pointed at the product below */
+    Wmk = mzd_mul(NULL, A12, B21, cutoff);/*Wmk = A12 * B21 */
+
+    _mzd_add(C11, C11, Wmk);		  /* C11 = C11 + Wmk */
+    _mzd_add(C12, C11, C12);		  /* C12 = C11 - C12 */
+    _mzd_add(C11, C21, C11);		  /* C11 = C21 - C11 */
+    _mzd_add(Wkn, Wkn, B11);		  /* Wkn = Wkn - B11 */
+    _mzd_mul_evenb(C21, A21, Wkn, cutoff);/* C21 = A21 * Wkn */
+    mzd_free(Wkn);
+
+    _mzd_add(C21, C11, C21);		  /* C21 = C11 - C21 */
+    _mzd_add(C22, C22, C11);		  /* C22 = C22 + C11 */
+    _mzd_mul_evenb(C11, A11, B11, cutoff);/* C11 = A11 * B11 */
+
+    _mzd_add(C11, C11, Wmk);		  /* C11 = C11 + Wmk */
+
+    /* clean up: release the twelve windows and the remaining Wmk
+     * (Wkn was already freed above) */
+    mzd_free_window(A11); mzd_free_window(A12);
+    mzd_free_window(A21); mzd_free_window(A22);
+
+    mzd_free_window(B11); mzd_free_window(B12);
+    mzd_free_window(B21); mzd_free_window(B22);
+
+    mzd_free_window(C11); mzd_free_window(C12);
+    mzd_free_window(C21); mzd_free_window(C22);
+
+    mzd_free(Wmk);
+  }
+  /* deal with rest: each quadrant size is doubled to the extent that
+   * was covered above; whatever lies beyond it in n, m and k is
+   * handled by the three blocks below. */
+  nnn*=2;
+  if (n > nnn) {
+    /*         |AA|   | B|   | C|
+     * Compute |AA| x | B| = | C|
+     *
+     * remaining columns of C: all of A times the remaining columns
+     * of B. */
+    packedmatrix *B_last_col = mzd_init_window(B, 0, nnn, k, n); 
+    packedmatrix *C_last_col = mzd_init_window(C, 0, nnn, m, n);
+    _mzd_mul_m4rm(C_last_col, A, B_last_col, 0, TRUE);
+    mzd_free_window(B_last_col);
+    mzd_free_window(C_last_col);
+  }
+  mmm*=2;
+  if (m > mmm) {
+    /*         |  |   |B |   |  |
+     * Compute |AA| x |B | = |C |
+     *
+     * remaining rows of C, first nnn columns only: remaining rows of
+     * A times the first nnn columns of B. */
+    packedmatrix *A_last_row = mzd_init_window(A, mmm, 0, m, k);
+    packedmatrix *B_first_col= mzd_init_window(B,   0, 0, k, nnn);
+    packedmatrix *C_last_row = mzd_init_window(C, mmm, 0, m, nnn);
+    _mzd_mul_m4rm(C_last_row, A_last_row, B_first_col, 0, TRUE);
+    mzd_free_window(A_last_row);
+    mzd_free_window(B_first_col);
+    mzd_free_window(C_last_row);
+  }
+  kkk*=2;
+  if (k > kkk) {
+    /* Add to  |  |   | B|   |C |
+     * result  |A | x |  | = |  |
+     *
+     * remaining part of the inner dimension: the product of the
+     * remaining columns of A (first mmm rows) with the remaining rows
+     * of B (first nnn columns) is added onto the bulk of C. */
+    packedmatrix *A_last_col = mzd_init_window(A,   0, kkk, mmm, k);
+    packedmatrix *B_last_row = mzd_init_window(B, kkk,   0,   k, nnn);
+    packedmatrix *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn);
+    mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
+    mzd_free_window(A_last_col);
+    mzd_free_window(B_last_row);
+    mzd_free_window(C_bulk);
+  }
+
+  return C;
+}
+/**
+ * Compute C = A * A with the same Strassen-like sequence that
+ * _mzd_mul_evenb uses, arranged so that block products whose two
+ * factors coincide become recursive squarings and a single scratch
+ * matrix serves both sides.
+ *
+ * \param C Result matrix; written in place through windows and
+ * returned. It is used as if it were m x m with m = A->nrows (not
+ * checked here).
+ * \param A Matrix to square; only A->nrows is read, so A is treated
+ * as square (not checked here).
+ * \param cutoff Threshold handed to CLOSER() to decide when to stop
+ * splitting; also used to scale the splitting granularity.
+ *
+ * \return C
+ *
+ * NOTE(review): unlike _mzd_mul_evenb there is no early return for an
+ * empty C here -- confirm that m == 0 is handled by the callees.
+ */
+packedmatrix *_mzd_sqr_evenb(packedmatrix *C, packedmatrix *A, int cutoff) {
+  size_t m;
+  size_t mmm; /* quadrant size; doubled later to the covered extent */
+  
+  m = A->nrows;
+  /* handle the case first where the input matrix is too small
+   * already: if CLOSER() holds, a single M4RM product A * A is done
+   * instead of splitting.
+   * NOTE(review): CLOSER is defined elsewhere; presumably it tests
+   * whether m is nearer to cutoff than m/2 is -- confirm. */
+  if (CLOSER(m, m/2, cutoff)) {
+    /* we copy the matrix first since it is only constant memory
+       overhead and improves data locality, if you remove it make sure
+       there are no speed regressions */
+    /* C = _mzd_mul_m4rm(C, A, B, 0, TRUE); */
+    packedmatrix *Cbar = mzd_init(m, m);
+    _mzd_mul_m4rm(Cbar, A, A, 0, FALSE);
+    mzd_copy(C, Cbar);
+    mzd_free(Cbar);
+    return C;
+  }
+
+  /* adjust cutting numbers to work on words: mult starts at RADIX and
+   * doubles each time width is halved, until width <= cutoff; m is
+   * rounded down to a multiple of mult and then halved in units of
+   * RADIX, so mmm is a multiple of RADIX. */
+  {
+    unsigned long mult = RADIX;
+    unsigned long width = m>>1;
+    while (width > cutoff) { /* NOTE(review): unsigned long vs int comparison; a negative cutoff converts to a large unsigned value */
+      width>>=1;
+      mult<<=1;
+    }
+
+    mmm = (((m - m%mult)/RADIX) >> 1) * RADIX;
+  }
+  /*         |A |   |A |   |C |
+   * Compute |  | x |  | = |  |
+   *
+   * i.e. the square of the leading 2*mmm x 2*mmm part of A, stored in
+   * the leading 2*mmm x 2*mmm part of C, via 2x2 quadrant windows. */
+  {
+    packedmatrix *A11 = mzd_init_window(A,   0,   0,   mmm,   mmm);
+    packedmatrix *A12 = mzd_init_window(A,   0, mmm,   mmm, 2*mmm);
+    packedmatrix *A21 = mzd_init_window(A, mmm,   0, 2*mmm,   mmm);
+    packedmatrix *A22 = mzd_init_window(A, mmm, mmm, 2*mmm, 2*mmm);
+
+    packedmatrix *C11 = mzd_init_window(C,   0,   0,   mmm,   mmm);
+    packedmatrix *C12 = mzd_init_window(C,   0, mmm,   mmm, 2*mmm);
+    packedmatrix *C21 = mzd_init_window(C, mmm,   0, 2*mmm,   mmm);
+    packedmatrix *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm);
+  
+    /**
+     * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
+     * Suited for Squaring and Highest Power Computation";
+     * http://bodrato.it/papres/#CIVV2008 for reference on the used
+     * sequence of operations.
+     *
+     * The quadrants of C double as intermediate storage: each one is
+     * written several times below before it holds its final value, so
+     * the order of the statements matters.
+     *
+     * The per-line comments sometimes write '-' where the code calls
+     * _mzd_add. NOTE(review): presumably addition and subtraction
+     * coincide for these matrices -- confirm.
+     */
+
+    packedmatrix *Wmk; /* assigned further down from mzd_mul(); holds A12 * A21 */
+    packedmatrix *Wkn = mzd_init(mmm, mmm); /* the only scratch for sums of A quadrants */
+
+    _mzd_add(Wkn, A22, A12);                 /* Wkn = A22 + A12 */
+    _mzd_sqr_evenb(C21, Wkn, cutoff);     /* C21 = Wkn^2 */
+
+    _mzd_add(Wkn, A22, A21);                 /* Wkn = A22 - A21 */
+    _mzd_sqr_evenb(C22, Wkn, cutoff);     /* C22 = Wkn^2 */
+
+    _mzd_add(Wkn, Wkn, A12);                 /* Wkn = Wkn + A12 */
+    _mzd_sqr_evenb(C11, Wkn, cutoff);     /* C11 = Wkn^2 */
+
+    _mzd_add(Wkn, Wkn, A11);                 /* Wkn = Wkn - A11 */
+    _mzd_mul_evenb(C12, Wkn, A12, cutoff);/* C12 = Wkn * A12 */
+    _mzd_add(C12, C12, C22);		  /* C12 = C12 + C22 */
+
+    Wmk = mzd_mul(NULL, A12, A21, cutoff);/*Wmk = A12 * A21 */
+
+    _mzd_add(C11, C11, Wmk);		  /* C11 = C11 + Wmk */
+    _mzd_add(C12, C11, C12);		  /* C12 = C11 - C12 */
+    _mzd_add(C11, C21, C11);		  /* C11 = C21 - C11 */
+    _mzd_mul_evenb(C21, A21, Wkn, cutoff);/* C21 = A21 * Wkn; Wkn is unchanged since it was multiplied into C12 */
+    mzd_free(Wkn);
+
+    _mzd_add(C21, C11, C21);		  /* C21 = C11 - C21 */
+    _mzd_add(C22, C22, C11);		  /* C22 = C22 + C11 */
+    _mzd_sqr_evenb(C11, A11, cutoff);     /* C11 = A11^2 */
+
+    _mzd_add(C11, C11, Wmk);		  /* C11 = C11 + Wmk */
+
+    /* clean up: release the eight windows and Wmk (Wkn was already
+     * freed above) */
+    mzd_free_window(A11); mzd_free_window(A12);
+    mzd_free_window(A21); mzd_free_window(A22);
+
+    mzd_free_window(C11); mzd_free_window(C12);
+    mzd_free_window(C21); mzd_free_window(C22);
+
+    mzd_free(Wmk);
+  }
+  /* deal with rest: mmm is doubled to the extent that was covered
+   * above; since all dimensions are equal here, one test guards all
+   * three leftover pieces. */
+  mmm*=2;
+  if (m > mmm) {
+    /*         |AA|   | A|   | C|
+     * Compute |AA| x | A| = | C|
+     *
+     * remaining columns of C: all of A times the remaining columns
+     * of A. */
+    {
+      packedmatrix *A_last_col = mzd_init_window(A, 0, mmm, m, m);
+      packedmatrix *C_last_col = mzd_init_window(C, 0, mmm, m, m);
+      _mzd_mul_m4rm(C_last_col, A, A_last_col, 0, TRUE);
+      mzd_free_window(A_last_col);
+      mzd_free_window(C_last_col);
+    }
+    /*         |  |   |A |   |  |
+     * Compute |AA| x |A | = |C |
+     *
+     * remaining rows of C, first mmm columns only: remaining rows of
+     * A times the first mmm columns of A. */
+    {
+      packedmatrix *A_last_row = mzd_init_window(A, mmm, 0, m, m);
+      packedmatrix *A_first_col= mzd_init_window(A,   0, 0, m, mmm);
+      packedmatrix *C_last_row = mzd_init_window(C, mmm, 0, m, mmm);
+      _mzd_mul_m4rm(C_last_row, A_last_row, A_first_col, 0, TRUE);
+      mzd_free_window(A_last_row);
+      mzd_free_window(A_first_col);
+      mzd_free_window(C_last_row);
+    }
+    /* Add to  |  |   | A|   |C |
+     * result  |A | x |  | = |  |
+     *
+     * remaining part of the inner dimension: the product of the
+     * remaining columns of A (first mmm rows) with the remaining rows
+     * of A (first mmm columns) is added onto the bulk of C. */
+    {
+      packedmatrix *A_last_col = mzd_init_window(A,   0, mmm, mmm, m);
+      packedmatrix *A_last_row = mzd_init_window(A, mmm,   0,   m, mmm);
+      packedmatrix *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm);
+      mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0);
+      mzd_free_window(A_last_col);
+      mzd_free_window(A_last_row);
+      mzd_free_window(C_bulk);
+    }
+  }
+
+  return C;
+}
+
+
 #ifdef HAVE_OPENMP
 packedmatrix *_mzd_mul_mp_even(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff) {
   /**
     C = _mzd_mul_even(C, A, B, cutoff);
   }
 #else
-  C = _mzd_mul_even(C, A, B, cutoff);
+  C = (A==B)?_mzd_sqr_evenb(C, A, cutoff):_mzd_mul_evenb(C, A, B, cutoff);
 #endif  
   return C;
 }

testsuite/test_multiplication.c

 
 }
 
+/**
+ * Check that the results of all implemented squaring algorithms match
+ * up.
+ *
+ * One random m x m matrix A is squared three times -- through
+ * mzd_mul, through mzd_mul_m4rm and through mzd_mul_naive -- and the
+ * three results are compared pairwise.
+ *
+ * \param m Number of rows and columns of A
+ * \param k Parameter k of M4RM algorithm, may be 0 for automatic choice.
+ * \param cutoff Cut off parameter at which dimension to switch from
+ * Strassen to M4RM
+ *
+ * \return 0 if all three results agree, otherwise minus the number of
+ * pairwise comparisons that failed.
+ */
+int sqr_test_equality(int m, int k, int cutoff) {
+  int ret  = 0;
+  packedmatrix *A, *C, *D, *E;
+  
+  /* label the line "sqr" so that squaring runs can be told apart from
+     the multiplication runs in the test output */
+  printf("   sqr: m: %4d, k: %2d, cutoff: %4d",m,k,cutoff);
+
+  /* we create one random matrix */
+  A = mzd_init(m, m);
+  mzd_randomize(A);
+
+  /* C = A*A via Strassen; the same matrix is passed as both operands */
+  C = mzd_mul(NULL, A, A, cutoff);
+
+  /* D = A*A via M4RM, temporary buffers are managed internally */
+  D = mzd_mul_m4rm(    NULL, A, A, k);
+
+  /* E = A*A via naive cubic multiplication */
+  E = mzd_mul_naive(    NULL, A, A);
+
+  /* the input is no longer needed once all three products exist */
+  mzd_free(A);
+
+  if (mzd_equal(C, D) != TRUE) {
+    printf(" Strassen != M4RM");
+    ret -=1;
+  }
+
+  if (mzd_equal(D, E) != TRUE) {
+    printf(" M4RM != Naiv");
+    ret -= 1;
+  }
+
+  if (mzd_equal(C, E) != TRUE) {
+    printf(" Strassen != Naiv");
+    ret -= 1;
+  }
+
+  mzd_free(C);
+  mzd_free(D);
+  mzd_free(E);
+
+  if(ret==0) {
+    printf(" ... passed\n");
+  } else {
+    printf(" ... FAILED\n");
+  }
+
+  return ret;
+}
+
 int addmul_test_equality(int m, int l, int n, int k, int cutoff) {
   int ret  = 0;
   packedmatrix *A, *B, *C, *D, *E, *F;
   status += addmul_test_equality(1290, 1290, 2000, 0, 64);
   status += addmul_test_equality(1000, 210, 200, 0, 64);
 
+  status += sqr_test_equality(1, 0, 1024);
+  status += sqr_test_equality(128, 0, 0);
+  status += sqr_test_equality(131, 0, 0);
+  status += sqr_test_equality(64,  0, 64);
+  status += sqr_test_equality(128, 0, 64);
+  status += sqr_test_equality(171, 0, 63); 
+  status += sqr_test_equality(171, 0, 131); 
+  status += sqr_test_equality(193, 10, 64);
+  status += sqr_test_equality(1025, 3, 256);
+  status += sqr_test_equality(2048, 0, 1024);
+  status += sqr_test_equality(3528, 0, 1024);
+  status += sqr_test_equality(1000, 0, 256);
+  status += sqr_test_equality(1000, 0, 64);
+  status += sqr_test_equality(1710, 0, 256);
+  status += sqr_test_equality(1290, 0, 64);
+  status += sqr_test_equality(2000, 0, 256);
+  status += sqr_test_equality(2000, 0, 64);
+  status += sqr_test_equality(210, 0, 64);
+
   if (status == 0) {
     printf("All tests passed.\n");
     return 0;