Commits

Stian Andreassen committed 9bc690b

Revert _tc_mul to the best version and remove check_nonneg (it didn't clear when compiling in jit mode)

Comments (0)

Files changed (2)

pypy/rlib/rbigint.py

 from pypy.rlib.rarithmetic import ovfcheck, r_longlong, widen, is_valid_int
 from pypy.rlib.rarithmetic import most_neg_value_of_same_type
 from pypy.rlib.rfloat import isfinite
-from pypy.rlib.debug import make_sure_not_resized, check_regular_int, check_nonneg
+from pypy.rlib.debug import make_sure_not_resized, check_regular_int
 from pypy.rlib.objectmodel import we_are_translated, specialize
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
             _check_digits(digits)
         make_sure_not_resized(digits)
         self._digits = digits
+        assert size >= 0
         self.size = size or len(digits)
         self.sign = sign
 
     setdigit._always_inline_ = True
 
     def numdigits(self):
-        return check_nonneg(self.size)
+        return self.size
     numdigits._always_inline_ = True
     
     @staticmethod
     return z
 
 
-def _tcmul_split(n, size):
+def _tcmul_split(n):
     """
     A helper for Karatsuba multiplication (k_mul).
     Takes a bigint "n" and an integer "size" representing the place to
     the return values are >= 0.
     """
     size_n = n.numdigits() // 3
-    size_lo = min(size_n, size)
-    lo = rbigint(n._digits[:size_lo], 1)
-    mid = rbigint(n._digits[size_lo:size_lo * 2], 1)
-    hi = rbigint(n._digits[size_lo *2:], 1)
+    lo = rbigint(n._digits[:size_n], 1)
+    mid = rbigint(n._digits[size_n:size_n * 2], 1)
+    hi = rbigint(n._digits[size_n *2:], 1)
     lo._normalize()
     mid._normalize()
     hi._normalize()
     return hi, mid, lo
 
+THREERBIGINT = rbigint.fromint(3)
 def _tc_mul(a, b):
     """
     Toom Cook
     bsize = b.numdigits()
 
     # Split a & b into hi, mid and lo pieces.
-    shift = (2+bsize) // 3
-    ah, am, al = _tcmul_split(a, shift)
+    shift = bsize // 3
+    ah, am, al = _tcmul_split(a)
     assert ah.sign == 1    # the split isn't degenerate
 
     if a is b:
         bm = am
         bl = al
     else:
-        bh, bm, bl = _tcmul_split(b, shift)
+        bh, bm, bl = _tcmul_split(b)
+        
     # 2. ahl, bhl
-    ahl = _x_add(al, ah)
-    bhl = _x_add(bl, bh)
-    
+    ahl = al.add(ah)
+    bhl = bl.add(bh)
+
     # Points
     v0 = al.mul(bl)
-    vn1 = ahl.sub(am).mul(bhl.sub(bm))
-    
-    ahml = _x_add(ahl, am)
-    bhml = _x_add(bhl, bm)
-    
-    v1 = ahml.mul(bhml)
-    v2 = _x_add(ahml, ah).lshift(1).sub(al).mul(_x_add(bhml, bh).lshift(1).sub(bl))
+    v1 = ahl.add(bm).mul(bhl.add(bm))
+
+    vn1 = ahl.sub(bm).mul(bhl.sub(bm))
+    v2 = al.add(am.lqshift(1)).add(ah.lshift(2)).mul(bl.add(bm.lqshift(1)).add(bh.lqshift(2)))
+
     vinf = ah.mul(bh)
-    
-    t2 = _x_sub(v2, vn1)
-    _inplace_divrem1(t2, t2, 3)
-    tn1 = v1.sub(vn1)
-    _v_rshift(tn1, tn1, tn1.numdigits(), 1)
-    t1 = v1
-    _v_isub(t1, 0, t1.numdigits(), v0, v0.numdigits())
-    _v_isub(t2, 0, t2.numdigits(), t1, t1.numdigits())
+
+    # Construct
+    t1 = v0.mul(THREERBIGINT).add(vn1.lqshift(1)).add(v2)
+    _inplace_divrem1(t1, t1, 6)
+    t1 = t1.sub(vinf.lqshift(1))
+    t2 = v1
+    _v_iadd(t2, 0, t2.numdigits(), vn1, vn1.numdigits())
     _v_rshift(t2, t2, t2.numdigits(), 1)
-    _v_isub(t1, 0, t1.numdigits(), tn1, tn1.numdigits())
-    _v_isub(t1, 0, t1.numdigits(), vinf, vinf.numdigits())
-    
-    t2 = t2.sub(vinf.lshift(1))
-    _v_isub(tn1, 0, tn1.numdigits(), t2, t2.numdigits())
-    
+
+    r1 = v1.sub(t1)
+    r2 = t2
+    _v_isub(r2, 0, r2.numdigits(), v0, v0.numdigits())
+    r2 = r2.sub(vinf)
+    r3 = t1
+    _v_isub(r3, 0, r3.numdigits(), t2, t2.numdigits())
+
     # Now we fit t+ t2 + t4 into the new string.
     # Now we got to add the r1 and r3 in the mid shift.
     # Allocate result space.
     ret = rbigint([NULLDIGIT] * (4 * shift + vinf.numdigits() + 1), 1)  # This is because of the size of vinf
     
     ret._digits[:v0.numdigits()] = v0._digits
-    #print ret.numdigits(), r2.numdigits(), vinf.numdigits(), shift, shift * 5, asize, bsize
-    #print r2.sign >= 0
     assert t2.sign >= 0
-    #print 2*shift + r2.numdigits() < ret.numdigits()
     assert 2*shift + t2.numdigits() < ret.numdigits()
-    ret._digits[shift * 2:shift * 2+t2.numdigits()] = t2._digits
-    #print vinf.sign >= 0
+    ret._digits[shift * 2:shift * 2+r2.numdigits()] = r2._digits
     assert vinf.sign >= 0
-    #print 4*shift + vinf.numdigits() <= ret.numdigits()
     assert 4*shift + vinf.numdigits() <= ret.numdigits()
     ret._digits[shift*4:shift*4+vinf.numdigits()] = vinf._digits
 
 
     i = ret.numdigits() - shift
-    _v_iadd(ret, shift, i, tn1, tn1.numdigits())
-    _v_iadd(ret, shift * 3, i, t1, t1.numdigits())
+    _v_iadd(ret, shift * 3, i, r3, r3.numdigits())
+    _v_iadd(ret, shift, i, r1, r1.numdigits())
+    
 
     ret._normalize()
     return ret

pypy/translator/goal/targetbigintbenchmark.py

     sumTime = 0.0
     
     
-    """ t = time()
+    """t = time()
     by = rbigint.fromint(2**62).lshift(1030000)
     for n in xrange(5000):
         by2 = by.lshift(63)