Commits

Armin Rigo committed 9439564

Patch by qbproger: write square roots with the x86 SSE2
instruction SQRTSD. Also cleans up a bit the error checking
around math.sqrt().

Comments (0)

Files changed (11)

pypy/jit/backend/x86/assembler.py

         effectinfo = op.getdescr().get_extra_info()
         oopspecindex = effectinfo.oopspecindex
         genop_llong_list[oopspecindex](self, op, arglocs, resloc)
+        
+    def regalloc_perform_math(self, op, arglocs, resloc):
+        effectinfo = op.getdescr().get_extra_info()
+        oopspecindex = effectinfo.oopspecindex
+        genop_math_list[oopspecindex](self, op, arglocs, resloc)
 
     def regalloc_perform_with_guard(self, op, guard_op, faillocs,
                                     arglocs, resloc, current_depths):
     genop_guard_float_eq = _cmpop_guard_float("E", "E", "NE","NE")
     genop_guard_float_gt = _cmpop_guard_float("A", "B", "BE","AE")
     genop_guard_float_ge = _cmpop_guard_float("AE","BE", "B", "A")
+    
+    def genop_math_sqrt(self, op, arglocs, resloc):
+        self.mc.SQRTSD(arglocs[0], resloc)
 
     def genop_guard_float_ne(self, op, guard_op, guard_token, arglocs, result_loc):
         guard_opnum = guard_op.getopnum()
 genop_discard_list = [Assembler386.not_implemented_op_discard] * rop._LAST
 genop_list = [Assembler386.not_implemented_op] * rop._LAST
 genop_llong_list = {}
+genop_math_list = {}
 genop_guard_list = [Assembler386.not_implemented_op_guard] * rop._LAST
 
 for name, value in Assembler386.__dict__.iteritems():
         opname = name[len('genop_llong_'):]
         num = getattr(EffectInfo, 'OS_LLONG_' + opname.upper())
         genop_llong_list[num] = value
+    elif name.startswith('genop_math_'):
+        opname = name[len('genop_math_'):]
+        num = getattr(EffectInfo, 'OS_MATH_' + opname.upper())
+        genop_math_list[num] = value
     elif name.startswith('genop_'):
         opname = name[len('genop_'):]
         num = getattr(rop, opname.upper())

pypy/jit/backend/x86/regalloc.py

         if not we_are_translated():
             self.assembler.dump('%s <- %s(%s)' % (result_loc, op, arglocs))
         self.assembler.regalloc_perform_llong(op, arglocs, result_loc)
+        
+    def PerformMath(self, op, arglocs, result_loc):
+        if not we_are_translated():
+            self.assembler.dump('%s <- %s(%s)' % (result_loc, op, arglocs))
+        self.assembler.regalloc_perform_math(op, arglocs, result_loc)
 
     def locs_for_fail(self, guard_op):
         return [self.loc(v) for v in guard_op.getfailargs()]
     consider_float_gt = _consider_float_cmp
     consider_float_ge = _consider_float_cmp
 
-    def consider_float_neg(self, op):
+    def _consider_float_unary_op(self, op):
         loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(0))
         self.Perform(op, [loc0], loc0)
         self.xrm.possibly_free_var(op.getarg(0))
-
-    def consider_float_abs(self, op):
-        loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(0))
-        self.Perform(op, [loc0], loc0)
-        self.xrm.possibly_free_var(op.getarg(0))
+        
+    consider_float_neg = _consider_float_unary_op
+    consider_float_abs = _consider_float_unary_op
 
     def consider_cast_float_to_int(self, op):
         loc0 = self.xrm.make_sure_var_in_reg(op.getarg(0))
         loc1 = self.rm.make_sure_var_in_reg(op.getarg(1))
         self.PerformLLong(op, [loc1], loc0)
         self.rm.possibly_free_vars_for_op(op)
+        
+    def _consider_math_sqrt(self, op):
+        loc0 = self.xrm.force_result_in_reg(op.result, op.getarg(1))
+        self.PerformMath(op, [loc0], loc0)
+        self.xrm.possibly_free_var(op.getarg(1))
 
     def _call(self, op, arglocs, force_store=[], guard_not_forced_op=None):
         save_all_regs = guard_not_forced_op is not None
                    guard_not_forced_op=guard_not_forced_op)
 
     def consider_call(self, op):
-        if IS_X86_32:
-            # support for some of the llong operations,
-            # which only exist on x86-32
-            effectinfo = op.getdescr().get_extra_info()
-            if effectinfo is not None:
-                oopspecindex = effectinfo.oopspecindex
+        effectinfo = op.getdescr().get_extra_info()
+        if effectinfo is not None:
+            oopspecindex = effectinfo.oopspecindex
+            if IS_X86_32:
+                # support for some of the llong operations,
+                # which only exist on x86-32
                 if oopspecindex in (EffectInfo.OS_LLONG_ADD,
                                     EffectInfo.OS_LLONG_SUB,
                                     EffectInfo.OS_LLONG_AND,
                 if oopspecindex == EffectInfo.OS_LLONG_LT:
                     if self._maybe_consider_llong_lt(op):
                         return
-        #
+            if oopspecindex == EffectInfo.OS_MATH_SQRT:
+                return self._consider_math_sqrt(op)
         self._consider_call(op)
 
     def consider_call_may_force(self, op, guard_op):

pypy/jit/backend/x86/regloc.py

     UCOMISD = _binaryop('UCOMISD')
     CVTSI2SD = _binaryop('CVTSI2SD')
     CVTTSD2SI = _binaryop('CVTTSD2SI')
+    
+    SQRTSD = _binaryop('SQRTSD')
 
     ANDPD = _binaryop('ANDPD')
     XORPD = _binaryop('XORPD')

pypy/jit/backend/x86/rx86.py

 define_modrm_modes('MOVSD_x*', ['\xF2', rex_nw, '\x0F\x10', register(1,8)], regtype='XMM')
 define_modrm_modes('MOVSD_*x', ['\xF2', rex_nw, '\x0F\x11', register(2,8)], regtype='XMM')
 
+define_modrm_modes('SQRTSD_x*', ['\xF2', rex_nw, '\x0F\x51', register(1,8)], regtype='XMM')
+
 #define_modrm_modes('XCHG_r*', [rex_w, '\x87', register(1, 8)])
 
 define_modrm_modes('ADDSD_x*', ['\xF2', rex_nw, '\x0F\x58', register(1, 8)], regtype='XMM')

pypy/jit/codewriter/effectinfo.py

     OS_LLONG_UGE                = 91
     OS_LLONG_URSHIFT            = 92
     OS_LLONG_FROM_UINT          = 93
+    #
+    OS_MATH_SQRT                = 100
 
     def __new__(cls, readonly_descrs_fields,
                 write_descrs_fields, write_descrs_arrays,

pypy/jit/codewriter/jtransform.py

             prepare = self._handle_jit_call
         elif oopspec_name.startswith('libffi_'):
             prepare = self._handle_libffi_call
+        elif oopspec_name.startswith('math.sqrt'):
+            prepare = self._handle_math_sqrt_call
         else:
             prepare = self.prepare_builtin_call
         try:
         assert vinfo is not None
         self.vable_flags[op.args[0]] = op.args[2].value
         return []
+        
+    # ---------
+    # ll_math.sqrt_nonneg()
+    
+    def _handle_math_sqrt_call(self, op, oopspec_name, args):
+        return self._handle_oopspec_call(op, args, EffectInfo.OS_MATH_SQRT,
+                                         EffectInfo.EF_PURE)
 
 # ____________________________________________________________
 

pypy/jit/codewriter/support.py

 from pypy.rpython import rlist
 from pypy.rpython.lltypesystem import rstr as ll_rstr, rdict as ll_rdict
 from pypy.rpython.lltypesystem import rlist as lltypesystem_rlist
+from pypy.rpython.lltypesystem.module import ll_math
 from pypy.rpython.lltypesystem.lloperation import llop
 from pypy.rpython.ootypesystem import rdict as oo_rdict
 from pypy.rpython.llinterp import LLInterpreter
         return -x
     else:
         return x
+        
+# math support
+# ------------
+
+_ll_1_ll_math_ll_math_sqrt = ll_math.ll_math_sqrt
 
 
 # long long support
     ('int_mod_zer',          [lltype.Signed, lltype.Signed], lltype.Signed),
     ('int_lshift_ovf',       [lltype.Signed, lltype.Signed], lltype.Signed),
     ('int_abs',              [lltype.Signed],                lltype.Signed),
+    ('ll_math.ll_math_sqrt', [lltype.Float],                 lltype.Float),
     ]
 
 

pypy/jit/codewriter/test/test_jtransform.py

 from pypy.jit.codewriter.jtransform import Transformer
 from pypy.jit.metainterp.history import getkind
 from pypy.rpython.lltypesystem import lltype, llmemory, rclass, rstr, rlist
+from pypy.rpython.lltypesystem.module import ll_math
 from pypy.translator.unsimplify import varoftype
 from pypy.jit.codewriter import heaptracker, effectinfo
 from pypy.jit.codewriter.flatten import ListOfKind
             PUNICODE = lltype.Ptr(rstr.UNICODE)
             INT = lltype.Signed
             UNICHAR = lltype.UniChar
+            FLOAT = lltype.Float
             argtypes = {
+             EI.OS_MATH_SQRT:  ([FLOAT], FLOAT),
              EI.OS_STR2UNICODE:([PSTR], PUNICODE),
              EI.OS_STR_CONCAT: ([PSTR, PSTR], PSTR),
              EI.OS_STR_SLICE:  ([PSTR, INT, INT], PSTR),
     assert op1.args[1] == 'calldescr-%d' % effectinfo.EffectInfo.OS_ARRAYCOPY
     assert op1.args[2] == ListOfKind('int', [v3, v4, v5])
     assert op1.args[3] == ListOfKind('ref', [v1, v2])
+
+def test_math_sqrt():
+    # test that the oopspec is present and correctly transformed
+    FLOAT = lltype.Float
+    FUNC = lltype.FuncType([FLOAT], FLOAT)
+    func = lltype.functionptr(FUNC, 'll_math',
+                              _callable=ll_math.sqrt_nonneg)
+    v1 = varoftype(FLOAT)
+    v2 = varoftype(FLOAT)
+    op = SpaceOperation('direct_call', [const(func), v1], v2)
+    tr = Transformer(FakeCPU(), FakeBuiltinCallControl())
+    op1 = tr.rewrite_operation(op)
+    assert op1.opname == 'residual_call_irf_f'
+    assert op1.args[0].value == func
+    assert op1.args[1] == 'calldescr-%d' % effectinfo.EffectInfo.OS_MATH_SQRT
+    assert op1.args[2] == ListOfKind("int", [])
+    assert op1.args[3] == ListOfKind("ref", [])
+    assert op1.args[4] == ListOfKind('float', [v1])
+    assert op1.result == v2

pypy/jit/metainterp/test/support.py

 from pypy.jit.metainterp.warmstate import set_future_value
 from pypy.jit.codewriter.policy import JitPolicy
 from pypy.jit.codewriter import longlong
+from pypy.rlib.rfloat import isinf, isnan
 
 def _get_jitcodes(testself, CPUClass, func, values, type_system,
                   supports_longlong=False, **kwds):
         result1 = _run_with_blackhole(self, args)
         # try to run it with pyjitpl.py
         result2 = _run_with_pyjitpl(self, args)
-        assert result1 == result2
+        assert result1 == result2 or isnan(result1) and isnan(result2)
         # try to run it by running the code compiled just before
         result3 = _run_with_machine_code(self, args)
-        assert result1 == result3 or result3 == NotImplemented
+        assert result1 == result3 or result3 == NotImplemented or isnan(result1) and isnan(result3)
         #
         if (longlong.supports_longlong and
             isinstance(result1, longlong.r_float_storage)):

pypy/rpython/extfuncregistry.py

 register_external(math.floor, [float], float,
                   export_name="ll_math.ll_math_floor", sandboxsafe=True,
                   llimpl=ll_math.ll_math_floor)
+register_external(math.sqrt, [float], float,
+                  export_name="ll_math.ll_math_sqrt", sandboxsafe=True,
+                  llimpl=ll_math.ll_math_sqrt)
 
 complex_math_functions = [
     ('frexp', [float],        (float, int)),

pypy/rpython/lltypesystem/module/ll_math.py

 from pypy.rlib import jit, rposix
 from pypy.translator.tool.cbuild import ExternalCompilationInfo
 from pypy.translator.platform import platform
-from pypy.rlib.rfloat import isinf, isnan, INFINITY, NAN
+from pypy.rlib.rfloat import isfinite, isinf, isnan, INFINITY, NAN
 
 if sys.platform == "win32":
     if platform.name == "msvc":
                         [rffi.DOUBLE, rffi.DOUBLE], rffi.DOUBLE)
 math_floor = llexternal('floor', [rffi.DOUBLE], rffi.DOUBLE, pure_function=True)
 
+math_sqrt = llexternal('sqrt', [rffi.DOUBLE], rffi.DOUBLE)
+
+@jit.purefunction
+def sqrt_nonneg(x):
+    return math_sqrt(x)
+sqrt_nonneg.oopspec = "math.sqrt_nonneg(x)"
+
 # ____________________________________________________________
 #
 # Error handling functions
         _likely_raise(errno, r)
     return r
 
+def ll_math_sqrt(x):
+    if x < 0.0:
+        raise ValueError, "math domain error"
+    
+    if isfinite(x):
+        return sqrt_nonneg(x)
+
+    return x   # +inf or nan
+
 # ____________________________________________________________
 #
 # Default implementations
 unary_math_functions = [
     'acos', 'asin', 'atan',
     'ceil', 'cos', 'cosh', 'exp', 'fabs',
-    'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'log', 'log10',
+    'sin', 'sinh', 'tan', 'tanh', 'log', 'log10',
     'acosh', 'asinh', 'atanh', 'log1p', 'expm1',
     ]
 unary_math_functions_can_overflow = [