Commits

bod...@mail.dm.unipi.it  committed 71823bc

Added new functions for addmul and addsqr using new operation sequences.
Also added some trivial tests to test_multiplication.c.

  • Participants
  • Parent commits 1a00450

Comments (0)

Files changed (2)

File src/strassen.c

   size_t m,k,n;
   size_t mmm, kkk, nnn;
   
+  if(C->nrows == 0 || C->ncols == 0)
+    return C;
+
   m = A->nrows;
   k = A->ncols;
   n = B->ncols;
 
-  if(C->nrows == 0 || C->ncols == 0)
-    return C;
-
   /* handle case first, where the input matrices are too small already */
   if (CLOSER(m, m/2, cutoff) || CLOSER(k, k/2, cutoff) || CLOSER(n, n/2, cutoff)) {
     /* we copy the matrix first since it is only constant memory
   return C;
 }
 
+packedmatrix *_mzd_addmul_evenb(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff) {
+  /**
+   * Computes C += A*B with a recursive Strassen-like operation sequence
+   * and returns C.
+   *
+   * If C has no rows or no columns it is returned untouched. If any of
+   * the dimensions m, k, n satisfies the CLOSER(.., cutoff) test, the
+   * whole product is delegated to mzd_addmul_m4rm. Otherwise the
+   * top-left 2*mmm x 2*kkk part of A and 2*kkk x 2*nnn part of B are
+   * split into quadrants, combined recursively, and the leftover strips
+   * are added with mzd_addmul_m4rm.
+   *
+   * Every "-" in the step comments is carried out by _mzd_add.
+   * NOTE(review): presumably the entries are over GF(2), where addition
+   * and subtraction coincide -- confirm.
+   *
+   * \todo make sure not to overwrite crap after ncols and before width*RADIX
+   */
+
+  size_t m,k,n;
+  size_t mmm, kkk, nnn;
+  
+  if(C->nrows == 0 || C->ncols == 0)
+    return C;
+
+  m = A->nrows;
+  k = A->ncols;
+  n = B->ncols;
+
+  /* handle case first, where the input matrices are too small already */
+  if (CLOSER(m, m/2, cutoff) || CLOSER(k, k/2, cutoff) || CLOSER(n, n/2, cutoff)) {
+    /* we copy the matrix first since it is only constant memory
+       overhead and improves data locality, if you remove it make sure
+       there are no speed regressions */
+    packedmatrix *Cbar = mzd_copy(NULL, C);
+    mzd_addmul_m4rm(Cbar, A, B, 0);
+    mzd_copy(C, Cbar);
+    mzd_free(Cbar);
+    return C;
+  }
+
+  /* adjust cutting numbers to work on words: mult starts at RADIX and is
+   * doubled once for every halving needed to bring width down to cutoff;
+   * each dimension is then rounded down to a multiple of mult and halved
+   * in units of RADIX (rounding down) */
+  {
+    unsigned long mult = RADIX;
+    unsigned long width = MIN(MIN(m,n),k)/2;
+    while (width > cutoff) {
+      width>>=1;
+      mult<<=1;
+    }
+
+    mmm = (((m - m%mult)/RADIX) >> 1) * RADIX;
+    kkk = (((k - k%mult)/RADIX) >> 1) * RADIX;
+    nnn = (((n - n%mult)/RADIX) >> 1) * RADIX;
+  }
+
+  /*         |C |    |A |   |B | 
+   * Compute |  | += |  | x |  |  */
+  {
+    packedmatrix *A11 = mzd_init_window(A,   0,   0,   mmm,   kkk);
+    packedmatrix *A12 = mzd_init_window(A,   0, kkk,   mmm, 2*kkk);
+    packedmatrix *A21 = mzd_init_window(A, mmm,   0, 2*mmm,   kkk);
+    packedmatrix *A22 = mzd_init_window(A, mmm, kkk, 2*mmm, 2*kkk);
+
+    packedmatrix *B11 = mzd_init_window(B,   0,   0,   kkk,   nnn);
+    packedmatrix *B12 = mzd_init_window(B,   0, nnn,   kkk, 2*nnn);
+    packedmatrix *B21 = mzd_init_window(B, kkk,   0, 2*kkk,   nnn);
+    packedmatrix *B22 = mzd_init_window(B, kkk, nnn, 2*kkk, 2*nnn);
+
+    packedmatrix *C11 = mzd_init_window(C,   0,   0,   mmm,   nnn);
+    packedmatrix *C12 = mzd_init_window(C,   0, nnn,   mmm, 2*nnn);
+    packedmatrix *C21 = mzd_init_window(C, mmm,   0, 2*mmm,   nnn);
+    packedmatrix *C22 = mzd_init_window(C, mmm, nnn, 2*mmm, 2*nnn);
+  
+    /**
+     * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
+     * Suited for Squaring and Highest Power Computation";
+     * http://bodrato.it/papres/#CIVV2008 for reference on the used
+     * sequence of operations.
+     *
+     * The number at the start of each step comment below is presumably
+     * the step number in that reference; the steps are executed in the
+     * order written here, which is not numeric order.
+     */
+
+    packedmatrix *S = mzd_init(mmm, kkk);    /* scratch: sums of A blocks */
+    packedmatrix *T = mzd_init(kkk, nnn);    /* scratch: sums of B blocks */
+    packedmatrix *U = mzd_init(mmm, nnn);    /* scratch: block products   */
+
+    _mzd_add(S, A22, A21);                   /* 1  S = A22 - A21       */
+    _mzd_add(T, B22, B21);                   /* 2  T = B22 - B21       */
+    _mzd_mul_evenb(U, S, T, cutoff);         /* 3  U = S*T             */
+    _mzd_add(C22, U, C22);                   /* 4  C22 = U + C22       */
+    _mzd_add(C12, U, C12);                   /* 5  C12 = U + C12       */
+
+    _mzd_mul_evenb(U, A12, B21, cutoff);     /* 8  U = A12*B21         */
+    _mzd_add(C11, U, C11);                   /* 9  C11 = U + C11       */
+
+    _mzd_addmul_evenb(C11, A11, B11, cutoff);/* 11 C11 = A11*B11 + C11 */
+
+    _mzd_add(S, S, A12);                     /* 6  S = S - A12         */
+    _mzd_add(T, T, B12);                     /* 7  T = T - B12         */
+    _mzd_addmul_evenb(U, S, T, cutoff);      /* 10 U = S*T + U         */
+    _mzd_add(C12, C12, U);                   /* 15 C12 = U + C12       */
+
+    _mzd_add(S, A11, S);                     /* 12 S = A11 - S         */
+    _mzd_addmul_evenb(C12, S, B12, cutoff);  /* 14 C12 = S*B12 + C12   */
+
+    _mzd_add(T, B11, T);                     /* 13 T = B11 - T         */
+    _mzd_addmul_evenb(C21, A21, T, cutoff);  /* 16 C21 = A21*T + C21   */
+
+    _mzd_add(S, A22, A12);                   /* 17 S = A22 + A12       */
+    _mzd_add(T, B22, B12);                   /* 18 T = B22 + B12       */
+    _mzd_addmul_evenb(U, S, T, cutoff);      /* 19 U = U - S*T         */
+    _mzd_add(C21, C21, U);                   /* 20 C21 = C21 - U       */
+    _mzd_add(C22, C22, U);                   /* 21 C22 = C22 - U       */
+
+    /* clean up */
+    mzd_free_window(A11); mzd_free_window(A12);
+    mzd_free_window(A21); mzd_free_window(A22);
+
+    mzd_free_window(B11); mzd_free_window(B12);
+    mzd_free_window(B21); mzd_free_window(B22);
+
+    mzd_free_window(C11); mzd_free_window(C12);
+    mzd_free_window(C21); mzd_free_window(C22);
+
+    mzd_free(S);
+    mzd_free(T);
+    mzd_free(U);
+  }
+  /* deal with rest: the strips of C outside the 2*mmm x 2*nnn block and
+   * the part of the inner dimension beyond 2*kkk */
+  nnn*=2;
+  if (n > nnn) {
+    /*         | C|    |AA|   | B|
+     * Compute | C| += |AA| x | B| */
+    packedmatrix *B_last_col = mzd_init_window(B, 0, nnn, k, n); 
+    packedmatrix *C_last_col = mzd_init_window(C, 0, nnn, m, n);
+    mzd_addmul_m4rm(C_last_col, A, B_last_col, 0);
+    mzd_free_window(B_last_col);
+    mzd_free_window(C_last_col);
+  }
+  mmm*=2;
+  if (m > mmm) {
+    /*         |  |    |  |   |B |
+     * Compute |C | += |AA| x |B | */
+    packedmatrix *A_last_row = mzd_init_window(A, mmm, 0, m, k);
+    packedmatrix *B_first_col= mzd_init_window(B,   0, 0, k, nnn);
+    packedmatrix *C_last_row = mzd_init_window(C, mmm, 0, m, nnn);
+    mzd_addmul_m4rm(C_last_row, A_last_row, B_first_col, 0);
+    mzd_free_window(A_last_row);
+    mzd_free_window(B_first_col);
+    mzd_free_window(C_last_row);
+  }
+  kkk*=2;
+  if (k > kkk) {
+    /* Add to  |  |   | B|   |C |
+     * result  |A | x |  | = |  | */
+    packedmatrix *A_last_col = mzd_init_window(A,   0, kkk, mmm, k);
+    packedmatrix *B_last_row = mzd_init_window(B, kkk,   0,   k, nnn);
+    packedmatrix *C_bulk = mzd_init_window(C, 0, 0, mmm, nnn);
+    mzd_addmul_m4rm(C_bulk, A_last_col, B_last_row, 0);
+    mzd_free_window(A_last_col);
+    mzd_free_window(B_last_row);
+    mzd_free_window(C_bulk);
+  }
+
+  return C;
+}
+
+packedmatrix *_mzd_addsqr_evenb(packedmatrix *C, packedmatrix *A, int cutoff) {
+  /**
+   * Computes C += A*A with a recursive Strassen-like operation sequence
+   * and returns C. Only A->nrows is read for the size, so A is treated
+   * as an m x m matrix.
+   *
+   * If C has no rows it is returned untouched. If m satisfies the
+   * CLOSER(.., cutoff) test, the whole product is delegated to
+   * mzd_addmul_m4rm. Otherwise the top-left 2*mmm x 2*mmm part is split
+   * into quadrants, combined recursively, and the leftover strips are
+   * added with mzd_addmul_m4rm.
+   *
+   * Every "-" in the step comments is carried out by _mzd_add.
+   * NOTE(review): presumably the entries are over GF(2), where addition
+   * and subtraction coincide -- confirm.
+   *
+   * \todo make sure not to overwrite crap after ncols and before width*RADIX
+   */
+
+  size_t m;
+  size_t mmm;
+  
+  if(C->nrows == 0)
+    return C;
+
+  m = A->nrows;
+
+  /* handle case first, where the input matrices are too small already */
+  if (CLOSER(m, m/2, cutoff)) {
+    /* we copy the matrix first since it is only constant memory
+       overhead and improves data locality, if you remove it make sure
+       there are no speed regressions */
+    packedmatrix *Cbar = mzd_copy(NULL, C);
+    mzd_addmul_m4rm(Cbar, A, A, 0);
+    mzd_copy(C, Cbar);
+    mzd_free(Cbar);
+    return C;
+  }
+
+  /* adjust cutting numbers to work on words: mult starts at RADIX and is
+   * doubled once for every halving needed to bring width down to cutoff;
+   * m is then rounded down to a multiple of mult and halved in units of
+   * RADIX (rounding down) */
+  {
+    unsigned long mult = RADIX;
+    unsigned long width = m>>1;
+    while (width > cutoff) {
+      width>>=1;
+      mult<<=1;
+    }
+
+    mmm = (((m - m%mult)/RADIX) >> 1) * RADIX;
+  }
+
+  /*         |C |    |A |   |A | 
+   * Compute |  | += |  | x |  |  */
+  {
+    packedmatrix *A11 = mzd_init_window(A,   0,   0,   mmm,   mmm);
+    packedmatrix *A12 = mzd_init_window(A,   0, mmm,   mmm, 2*mmm);
+    packedmatrix *A21 = mzd_init_window(A, mmm,   0, 2*mmm,   mmm);
+    packedmatrix *A22 = mzd_init_window(A, mmm, mmm, 2*mmm, 2*mmm);
+
+    packedmatrix *C11 = mzd_init_window(C,   0,   0,   mmm,   mmm);
+    packedmatrix *C12 = mzd_init_window(C,   0, mmm,   mmm, 2*mmm);
+    packedmatrix *C21 = mzd_init_window(C, mmm,   0, 2*mmm,   mmm);
+    packedmatrix *C22 = mzd_init_window(C, mmm, mmm, 2*mmm, 2*mmm);
+  
+    /**
+     * \note See Marco Bodrato; "A Strassen-like Matrix Multiplication
+     * Suited for Squaring and Highest Power Computation"; on-line v.
+     * http://bodrato.it/papres/#CIVV2008 for reference on the used
+     * sequence of operations.
+     *
+     * The number at the start of each step comment below is presumably
+     * the step number in that reference; the steps are executed in the
+     * order written here, which is not numeric order. Since both
+     * factors are A, a single scratch matrix S is used for the block
+     * sums on either side of each product.
+     */
+
+    packedmatrix *S = mzd_init(mmm, mmm);    /* scratch: sums of A blocks */
+    packedmatrix *U = mzd_init(mmm, mmm);    /* scratch: block products   */
+
+    _mzd_add(S, A22, A21);                   /* 1  S = A22 - A21       */
+    _mzd_sqr_evenb(U, S, cutoff);            /* 3  U = S^2             */
+    _mzd_add(C22, U, C22);                   /* 4  C22 = U + C22       */
+    _mzd_add(C12, U, C12);                   /* 5  C12 = U + C12       */
+
+    _mzd_mul_evenb(U, A12, A21, cutoff);     /* 8  U = A12*A21         */
+    _mzd_add(C11, U, C11);                   /* 9  C11 = U + C11       */
+
+    _mzd_addsqr_evenb(C11, A11, cutoff);     /* 11 C11 = A11^2 + C11   */
+
+    _mzd_add(S, S, A12);                     /* 6  S = S + A12         */
+    _mzd_addsqr_evenb(U, S, cutoff);         /* 10 U = S^2 + U         */
+    _mzd_add(C12, C12, U);                   /* 15 C12 = U + C12       */
+
+    _mzd_add(S, A11, S);                     /* 12 S = A11 - S         */
+    _mzd_addmul_evenb(C12, S, A12, cutoff);  /* 14 C12 = S*A12 + C12   */
+
+    _mzd_addmul_evenb(C21, A21, S, cutoff);  /* 16 C21 = A21*S + C21   */
+
+    _mzd_add(S, A22, A12);                   /* 17 S = A22 + A12       */
+    _mzd_addsqr_evenb(U, S, cutoff);         /* 19 U = U - S^2         */
+    _mzd_add(C21, C21, U);                   /* 20 C21 = C21 - U       */
+    _mzd_add(C22, C22, U);                   /* 21 C22 = C22 - U       */
+
+    /* clean up */
+    mzd_free_window(A11); mzd_free_window(A12);
+    mzd_free_window(A21); mzd_free_window(A22);
+
+    mzd_free_window(C11); mzd_free_window(C12);
+    mzd_free_window(C21); mzd_free_window(C22);
+
+    mzd_free(S);
+    mzd_free(U);
+  }
+  /* deal with rest: the strips of C outside the 2*mmm x 2*mmm block and
+   * the part of the inner dimension beyond 2*mmm */
+  mmm*=2;
+  if (m > mmm) {
+    /*         | C|    |AA|   | A|
+     * Compute | C| += |AA| x | A| */
+    {
+      packedmatrix *A_last_col = mzd_init_window(A, 0, mmm, m, m); 
+      packedmatrix *C_last_col = mzd_init_window(C, 0, mmm, m, m);
+      mzd_addmul_m4rm(C_last_col, A, A_last_col, 0);
+      mzd_free_window(A_last_col);
+      mzd_free_window(C_last_col);
+    }
+    /*         |  |    |  |   |A |
+     * Compute |C | += |AA| x |A | */
+    {
+      packedmatrix *A_last_row = mzd_init_window(A, mmm, 0, m, m);
+      packedmatrix *A_first_col= mzd_init_window(A,   0, 0, m, mmm);
+      packedmatrix *C_last_row = mzd_init_window(C, mmm, 0, m, mmm);
+      mzd_addmul_m4rm(C_last_row, A_last_row, A_first_col, 0);
+      mzd_free_window(A_last_row);
+      mzd_free_window(A_first_col);
+      mzd_free_window(C_last_row);
+    }
+    /* Add to  |  |   | A|   |C |
+     * result  |A | x |  | = |  | */
+    {
+      packedmatrix *A_last_col = mzd_init_window(A,   0, mmm, mmm, m);
+      packedmatrix *A_last_row = mzd_init_window(A, mmm,   0,   m, mmm);
+      packedmatrix *C_bulk = mzd_init_window(C, 0, 0, mmm, mmm);
+      mzd_addmul_m4rm(C_bulk, A_last_col, A_last_row, 0);
+      mzd_free_window(A_last_col);
+      mzd_free_window(A_last_row);
+      mzd_free_window(C_bulk);
+    }
+  }
+
+  return C;
+}
+
 packedmatrix *_mzd_addmul(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff){
   /**
    * Assumes that B and C are aligned in the same manner (as in a Schur complement)
   
   if (!A->offset){
     if (!B->offset) /* A even, B even */
-      return _mzd_addmul_even (C, A, B, cutoff);
+      return (A==B)?_mzd_addsqr_evenb(C, A, cutoff):_mzd_addmul_evenb (C, A, B, cutoff);
     else {  /* A even, B weird */
       size_t bnc = RADIX - B->offset;
       if (B->ncols <= bnc){
 	packedmatrix * B1 = mzd_init_window (B, 0, bnc, B->nrows, B->ncols);
 	packedmatrix * C1 = mzd_init_window (C, 0, bnc, C->nrows, C->ncols);
 	_mzd_addmul_even_weird  (C0,  A, B0, cutoff);
-	_mzd_addmul_even (C1, A, B1, cutoff);
+	_mzd_addmul_evenb (C1, A, B1, cutoff);
 	mzd_free_window (B0); mzd_free_window (B1);
 	mzd_free_window (C0); mzd_free_window (C1);
       }
       _mzd_addmul_weird_weird (C0, A0, B00, cutoff);
       _mzd_addmul_even_weird  (C0,  A1, B10, cutoff);
       _mzd_addmul_weird_even  (C1,  A0, B01, cutoff);
-      _mzd_addmul_even  (C1,  A1, B11, cutoff);
+      _mzd_addmul_evenb  (C1,  A1, B11, cutoff);
 
       mzd_free_window (A0);  mzd_free_window (A1);
       mzd_free_window (C0);  mzd_free_window (C1);
       packedmatrix * B0  = mzd_init_window (B, 0, 0, anc, B->ncols);
       packedmatrix * B1  = mzd_init_window (B, anc, 0, B->nrows, B->ncols);
       _mzd_addmul_weird_even (C, A0, B0, cutoff);
-      _mzd_addmul_even  (C, A1, B1, cutoff);
+      _mzd_addmul_evenb  (C, A1, B1, cutoff);
       mzd_free_window (A0); mzd_free_window (A1);
       mzd_free_window (B0); mzd_free_window (B1);
     }
   for (size_t i=0; i < A->nrows; ++i){
     tmp->values [tmp->rowswap[i]] = (A->values [A->rowswap [i]] << A->offset);
   }
-  _mzd_addmul_even (C, tmp, B, cutoff);
+  _mzd_addmul_evenb (C, tmp, B, cutoff);
   mzd_free(tmp);
   return C;
 }
    word mask = ((ONE << B->ncols) - 1) << (RADIX-B->offset - B->ncols);
    for (size_t i=0; i < B->nrows; ++i)
      tmp->values [tmp->rowswap[i]] = B->values [B->rowswap [i]] & mask;
-   _mzd_addmul_even (C, A, tmp, cutoff);
+   _mzd_addmul_evenb (C, A, tmp, cutoff);
    C->offset=offset;
    C->ncols = cncols;
    mzd_free (tmp);

File testsuite/test_multiplication.c

   int ret  = 0;
   packedmatrix *A, *C, *D, *E;
   
-  printf("   mul: m: %4d, k: %2d, cutoff: %4d",m,k,cutoff);
+  printf("   sqr: m: %4d, k: %2d, cutoff: %4d",m,k,cutoff);
 
   /* we create one random matrix */
   A = mzd_init(m, m);
   return ret;
 }
 
+/**
+ * Checks that three ways of computing C + A*A agree, for an m x m matrix
+ * A and an m x m matrix C that are both filled by mzd_randomize.
+ *
+ * \param m      number of rows and columns of A and C
+ * \param k      last argument handed to mzd_addmul_m4rm and mzd_mul_m4rm
+ * \param cutoff last argument handed to mzd_addmul
+ *
+ * \return 0 if all three results are pairwise equal, otherwise minus the
+ *         number of failed comparisons
+ */
+int addsqr_test_equality(int m, int k, int cutoff) {
+  int ret  = 0;
+  packedmatrix *A, *C, *D, *E, *F;
+  
+  printf("addsqr: m: %4d, k: %2d, cutoff: %4d",m,k,cutoff);
+
+  /* we create two random matrices */
+  A = mzd_init(m, m);
+  C = mzd_init(m, m);
+  mzd_randomize(A);
+  mzd_randomize(C);
+
+  /* D = C + A*A in one step via mzd_addmul_m4rm */
+  D = mzd_copy(NULL, C);
+  D = mzd_addmul_m4rm(D, A, A, k);
+
+  /* E = C + A*A in two steps: multiply via mzd_mul_m4rm, then add C */
+  E = mzd_mul_m4rm(NULL, A, A, k);
+  mzd_add(E, E, C);
+
+  /* F = C + A*A via mzd_addmul; the same matrix A is passed for both
+     factors */
+  F = mzd_copy(NULL, C);
+  F = mzd_addmul(F, A, A, cutoff);
+
+  mzd_free(A);
+  mzd_free(C);
+
+  if (mzd_equal(D, E) != TRUE) {
+    printf(" M4RM != add,mul");
+    ret -=1;
+  }
+  /* this message used to read "add,mul = addmul" although it is only
+     printed when the two results differ */
+  if (mzd_equal(E, F) != TRUE) {
+    printf(" add,mul != addmul");
+    ret -=1;
+  }
+  if (mzd_equal(F, D) != TRUE) {
+    printf(" M4RM != addmul");
+    ret -=1;
+  }
+
+  if (ret==0)
+    printf(" ... passed\n");
+  else
+    printf(" ... FAILED\n");
+
+  mzd_free(D);
+  mzd_free(E);
+  mzd_free(F);
+  return ret;
+}
+
 int main(int argc, char **argv) {
   int status = 0;
   
   status += sqr_test_equality(2000, 0, 64);
   status += sqr_test_equality(210, 0, 64);
 
+  status += addsqr_test_equality(1, 0, 0);
+  status += addsqr_test_equality(131, 0, 0);
+  status += addsqr_test_equality(64,  0, 64);
+  status += addsqr_test_equality(128, 0, 64);
+  status += addsqr_test_equality(171, 0, 63);
+  status += addsqr_test_equality(171, 0, 131);
+  status += addsqr_test_equality(193, 10, 64);
+  status += addsqr_test_equality(1025, 3, 256);
+  status += addsqr_test_equality(4096, 0, 2048);
+  status += addsqr_test_equality(1000, 0, 256);
+  status += addsqr_test_equality(1000, 0, 64);
+  status += addsqr_test_equality(1710, 0, 256);
+  status += addsqr_test_equality(1290, 0, 64);
+  status += addsqr_test_equality(2000, 0, 256);
+  status += addsqr_test_equality(2000, 0, 64);
+  status += addsqr_test_equality(210, 0, 64);
+
   if (status == 0) {
     printf("All tests passed.\n");
     return 0;