1. CarloWood
  2. M4RI

Commits

Martin Albrecht  committed 2e37b6e

fix a SIGSEGV and sometimes wrong results for matrix multiplication

  • Participants
  • Parent commits 48253ac
  • Branches default

Comments (0)

Files changed (3)

File Makefile.am

View file
 include_HEADERS = src/m4ri.h src/brilliantrussian.h src/misc.h src/packedmatrix.h src/grayflex.h src/watch.h src/strassen.h src/parity.h src/permutation.h src/config.h src/trsm.h src/lqup.h
 
 #libm4ri_la_LDFLAGS = -version-info 0:0:0
-libm4ri_la_LDFLAGS = -release 0.0.20080817
+libm4ri_la_LDFLAGS = -release 0.0.20080821

File src/strassen.c

View file
 
   /* adjust cutting numbers to work on words */
   unsigned long mult = 1;
-  long width = a;
+  long width = MIN(MIN(a,b),c);
   while (width > 2*cutoff) {
     width/=2;
     mult*=2;
    * but some called function doesn't like that and we end up with a
    * wrong result if we use virtual X0 matrices. Ideally, this should
    * be fixed not worked around. The check whether the bug has been
-   * fixed, use only one X0 and check if mzd_mul_strassen(4096, 3528,
-   * 4096, 1024) still returns the correct answer.
+   * fixed, use only one X0 and check if mzd_mul(4096, 3528, 4096,
+   * 1024) still returns the correct answer.
    */
 
   mzd_free(X0);
 
 packedmatrix *mzd_mul(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff) {
   if(A->ncols != B->nrows)
-    m4ri_die("mzd_mul_strassen: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
+    m4ri_die("mzd_mul: A ncols (%d) need to match B nrows (%d).\n", A->ncols, B->nrows);
   
   if (cutoff < 0)
-    m4ri_die("mzd_mul_strassen: cutoff must be > 0.\n");
+    m4ri_die("mzd_mul: cutoff must be >= 0.\n");
 
   if(cutoff == 0) {
     cutoff = STRASSEN_MUL_CUTOFF;
   }
 
   cutoff = cutoff/RADIX * RADIX;
-  if (cutoff == 0) {
+  if (cutoff < RADIX) {
     cutoff = RADIX;
   };
 
   if (C == NULL) {
     C = mzd_init(A->nrows, B->ncols);
   } else if (C->nrows != A->nrows || C->ncols != B->ncols){
-    m4ri_die("mzd_mul_strassen: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
+    m4ri_die("mzd_mul: C (%d x %d) has wrong dimensions, expected (%d x %d)\n",
 	     C->nrows, C->ncols, A->nrows, B->ncols);
   }
 #ifdef HAVE_OPENMP
        overhead and improves data locality, if you remove it make sure
        there are no speed regressions */
     packedmatrix *Cbar = mzd_copy (NULL, C);
-
     mzd_addmul_m4rm (Cbar, A, B, 0);
     mzd_copy(C, Cbar);
     mzd_free(Cbar);
 
   /* adjust cutting numbers to work on words */
   unsigned long mult = 1;
-  long width = a;
+  long width = MIN(MIN(a,b),c);
   while (width > 2*cutoff) {
     width/=2;
     mult*=2;
   return C;
 }
 
-packedmatrix *_mzd_addmul (packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff){
+packedmatrix *_mzd_addmul(packedmatrix *C, packedmatrix *A, packedmatrix *B, int cutoff){
   /**
    * Assumes that B and C are aligned in the same manner (as in a Schur complement)
    */
   }
   
   cutoff = cutoff/RADIX * RADIX;
-  if (cutoff == 0) {
+  if (cutoff < RADIX) {
     cutoff = RADIX;
   };
 

File testsuite/test_multiplication.c

View file
   status += mul_test_equality(2048, 2048, 4096, 0, 1024);
   status += mul_test_equality(4096, 3528, 4096, 0, 1024);
   status += mul_test_equality(1024, 1025, 1, 0, 1024);
+  status += mul_test_equality(1000,1000,1000, 0, 256);
+  status += mul_test_equality(1000,10,20, 0, 64);
+  status += mul_test_equality(1710,1290,1000, 0, 256);
+  status += mul_test_equality(1290, 1710, 200, 0, 64);
+  status += mul_test_equality(1290, 1710, 2000, 0, 256);
+  status += mul_test_equality(1290, 1290, 2000, 0, 64);
+  status += mul_test_equality(1000, 210, 200, 0, 64);
 
   status += addmul_test_equality(21, 171, 31, 0, 63);
   status += addmul_test_equality(21, 171, 31, 0, 131);
   status += addmul_test_equality(193, 65, 65, 10, 64);
   status += addmul_test_equality(1025, 1025, 1025, 3, 256);
   status += addmul_test_equality(4096, 4096, 4096, 0, 2048);
+  status += addmul_test_equality(1000,1000,1000, 0, 256);
+  status += addmul_test_equality(1000,10,20, 0, 64);
+  status += addmul_test_equality(1710,1290,1000, 0, 256);
+  status += addmul_test_equality(1290, 1710, 200, 0, 64);
+  status += addmul_test_equality(1290, 1710, 2000, 0, 256);
+  status += addmul_test_equality(1290, 1290, 2000, 0, 64);
+  status += addmul_test_equality(1000, 210, 200, 0, 64);
 
   if (status == 0) {
     printf("All tests passed.\n");