Commits

Amaury Forgeot d'Arc  committed b81b759

Add 4 basic operations.

  • Participants
  • Parent commits e6fb259
  • Branches decimal-libmpdec

Comments (0)

Files changed (4)

File pypy/module/_decimal/interp_context.py

         if to_trap:
             raise interp_signals.flags_as_exception(space, to_trap)
 
+    def catch_status(self, space):
+        return ContextStatus(space, self)
+
     def copy_w(self, space):
         w_copy = W_Context(space)
         rffi.structcopy(w_copy.ctx, self.ctx)
         context = getcontext(space)
     return context
 
+class ContextStatus:
+    def __init__(self, space, context):
+        self.space = space
+        self.context = context
+
+    def __enter__(self):
+        self.status_ptr = lltype.malloc(rffi.CArrayPtr(rffi.UINT).TO, 1,
+                                        flavor='raw', zero=True)
+        return self.context.ctx, self.status_ptr
+        
+    def __exit__(self, *args):
+        status = rffi.cast(lltype.Signed, self.status_ptr[0])
+        lltype.free(self.status_ptr, flavor='raw')
+        # May raise a DecimalException
+        self.context.addstatus(self.space, status)
+
+
 class ConvContext:
     def __init__(self, space, mpd, context, exact):
         self.space = space

File pypy/module/_decimal/interp_decimal.py

 from rpython.rlib import rmpdec, rarithmetic, rbigint, rfloat
+from rpython.rlib.objectmodel import specialize
 from rpython.rlib.rstring import StringBuilder
 from rpython.rtyper.lltypesystem import rffi, lltype
 from pypy.interpreter.baseobjspace import W_Root
     def apply(self, space, context, w_subtype=None):
         # Apply the context to the input operand. Return a new W_Decimal.
         w_result = W_Decimal.allocate(space, w_subtype)
-        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1) as status_ptr:
+        with context.catch_status(space) as (ctx, status_ptr):
             rmpdec.mpd_qcopy(w_result.mpd, self.mpd, status_ptr)
-            context.addstatus(space, rffi.cast(lltype.Signed, status_ptr[0]))
             rmpdec.mpd_qfinalize(w_result.mpd, context.ctx, status_ptr)
-            context.addstatus(space, rffi.cast(lltype.Signed, status_ptr[0]))
         return w_result
 
     def descr_str(self, space):
     def descr_eq(self, space, w_other):
         return self.compare(space, w_other, 'eq')
 
+    # Operations
+    @staticmethod
+    def convert_op(space, w_value, context):
+        if isinstance(w_value, W_Decimal):
+            return None, w_value
+        elif space.isinstance_w(w_value, space.w_int):
+            value = space.bigint_w(w_value)
+            return None, decimal_from_bigint(space, None, value, context,
+                                             exact=True)
+        return space.w_NotImplemented, None
+
+    def convert_binop(self, space, w_other, context):
+        w_err, w_a = W_Decimal.convert_op(space, self, context)
+        if w_err:
+            return w_err, None, None
+        w_err, w_b = W_Decimal.convert_op(space, w_other, context)
+        if w_err:
+            return w_err, None, None
+        return None, w_a, w_b
+
+    def binary_number_method(self, space, mpd_func, w_other):
+        context = interp_context.getcontext(space)
+
+        w_err, w_a, w_b = self.convert_binop(space, w_other, context)
+        if w_err:
+            return w_err
+        w_result = W_Decimal.allocate(space)
+        with context.catch_status(space) as (ctx, status_ptr):
+            mpd_func(w_result.mpd, w_a.mpd, w_b.mpd, ctx, status_ptr)
+        return w_result
+
+    def descr_add(self, space, w_other):
+        return self.binary_number_method(space, rmpdec.mpd_qadd, w_other)
+    def descr_sub(self, space, w_other):
+        return self.binary_number_method(space, rmpdec.mpd_qsub, w_other)
+    def descr_mul(self, space, w_other):
+        return self.binary_number_method(space, rmpdec.mpd_qmul, w_other)
+    def descr_truediv(self, space, w_other):
+        return self.binary_number_method(space, rmpdec.mpd_qdiv, w_other)
+
     # Boolean functions
     def is_qnan_w(self, space):
         return space.wrap(bool(rmpdec.mpd_isqnan(self.mpd)))
     w_result.mpd.c_exp = - k
 
     if not exact:
-        with lltype.scoped_alloc(rffi.CArrayPtr(rffi.UINT).TO, 1) as status_ptr:
-            rmpdec.mpd_qfinalize(w_result.mpd, context.ctx, status_ptr)
-            context.addstatus(space, rffi.cast(lltype.Signed, status_ptr[0]))
+        with context.catch_status(space) as (ctx, status_ptr):
+            rmpdec.mpd_qfinalize(w_result.mpd, ctx, status_ptr)
     return w_result
 
 def decimal_from_object(space, w_subtype, w_value, context, exact=True):
     __bool__ = interp2app(W_Decimal.descr_bool),
     __float__ = interp2app(W_Decimal.descr_float),
     __eq__ = interp2app(W_Decimal.descr_eq),
+    #
+    __add__ = interp2app(W_Decimal.descr_add),
+    __sub__ = interp2app(W_Decimal.descr_sub),
+    __mul__ = interp2app(W_Decimal.descr_mul),
+    __truediv__ = interp2app(W_Decimal.descr_truediv),
+    #
     is_qnan = interp2app(W_Decimal.is_qnan_w),
     is_infinite = interp2app(W_Decimal.is_infinite_w),
     )

File pypy/module/_decimal/test/test_decimal.py

         InvalidOperation = self.decimal.InvalidOperation
         localcontext = self.decimal.localcontext
 
+        self.decimal.getcontext().traps[InvalidOperation] = False
+
         #empty
         assert str(Decimal('')) == 'NaN'
 
 
         nc = self.decimal.Context()
         r = nc.create_decimal(0.1)
-        assert assertEqual(type(r)) is Decimal
+        assert type(r) is Decimal
         assert str(r) == '0.1000000000000000055511151231'
         assert nc.create_decimal(float('nan')).is_qnan()
         assert nc.create_decimal(float('inf')).is_infinite()
             x = self.random_float()
             assert x == float(nc.create_decimal(x))  # roundtrip
 
+    def test_operations(self):
+        Decimal = self.decimal.Decimal
+
+        assert Decimal(4) + Decimal(3) == Decimal(7)
+        assert Decimal(4) - Decimal(3) == Decimal(1)
+        assert Decimal(4) * Decimal(3) == Decimal(12)
+        assert Decimal(6) / Decimal(3) == Decimal(2)

File rpython/rlib/rmpdec.py

         "mpd_iszero", "mpd_isnegative", "mpd_isinfinite", "mpd_isspecial",
         "mpd_isnan", "mpd_issnan", "mpd_isqnan",
         "mpd_qcmp",
-        "mpd_qpow", "mpd_qmul",
+        "mpd_qpow", "mpd_qadd", "mpd_qsub", "mpd_qmul", "mpd_qdiv",
         "mpd_qround_to_int",
         ],
     compile_extra=compile_extra,
     'mpd_qpow',
     [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
     lltype.Void)
+mpd_qadd = external(
+    'mpd_qadd',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
+mpd_qsub = external(
+    'mpd_qsub',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
 mpd_qmul = external(
     'mpd_qmul',
     [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
     lltype.Void)
+mpd_qdiv = external(
+    'mpd_qdiv',
+    [MPD_PTR, MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],
+    lltype.Void)
 
 mpd_qround_to_int = external(
     'mpd_qround_to_int', [MPD_PTR, MPD_PTR, MPD_CONTEXT_PTR, rffi.UINTP],