Commits

Anonymous committed 71849b7 Draft

(price, arigato) Apply signature to return type

  • Participants
  • Parent commits 7ad4dbc
  • Branches signatures

Comments (0)

Files changed (3)

pypy/annotation/description.py

 from __future__ import absolute_import
 import types, py
-from pypy.annotation.signature import enforce_signature_args
+from pypy.annotation.signature import enforce_signature_args, enforce_signature_return
 from pypy.objspace.flow.model import Constant, FunctionGraph
 from pypy.objspace.flow.bytecode import cpython_code_signature
 from pypy.objspace.flow.argument import rawshape, ArgErr
             new_args = args.unmatch_signature(self.signature, inputcells)
             inputcells = self.parse_arguments(new_args, graph)
             result = schedule(graph, inputcells)
+            signature = getattr(self.pyobj, '_signature_', None)
+            if signature:
+                result = enforce_signature_return(self, signature[1], result)
+                self.bookkeeper.annotator.addpendingblock(graph, graph.returnblock, [result])
         # Some specializations may break the invariant of returning
         # annotations that are always more general than the previous time.
         # We restore it here:

pypy/annotation/signature.py

         inputcells[:] = args_s
 
 
-def enforce_signature_args(funcdesc, argtypes, inputcells):
-    assert len(argtypes) == len(inputcells)
-    args_s = []
-    for i, argtype in enumerate(argtypes):
-        args_s.append(annotation(argtype, bookkeeper=funcdesc.bookkeeper))
-    for i, (s_arg, s_input) in enumerate(zip(args_s, inputcells)):
-        if not s_arg.contains(s_input):
+def enforce_signature_args(funcdesc, paramtypes, actualtypes):
+    assert len(paramtypes) == len(actualtypes)
+    params_s = []
+    for i, paramtype in enumerate(paramtypes):
+        params_s.append(annotation(paramtype, bookkeeper=funcdesc.bookkeeper))
+    for i, (s_param, s_actual) in enumerate(zip(params_s, actualtypes)):
+        if not s_param.contains(s_actual):
             raise Exception("%r argument %d:\n"
                             "expected %s,\n"
-                            "     got %s" % (funcdesc, i+1,
-                                         s_arg,
-                                         s_input))
-    inputcells[:] = args_s
+                            "     got %s" % (funcdesc, i+1, s_param, s_actual))
+    actualtypes[:] = params_s
+
+
+def enforce_signature_return(funcdesc, sigtype, inferredtype):
+    annsigtype = annotation(sigtype, bookkeeper=funcdesc.bookkeeper)
+    return annsigtype

pypy/rlib/test/test_objectmodel.py

         return a + len(b)
     assert getsig(f) == [model.SomeInteger(), model.SomeString(), model.SomeInteger()]
 
+def test_signature_return():
+    @signature(returns=types.str())
+    def f():
+        return 'a'
+    assert getsig(f) == [model.SomeString()]
+
+    @signature(types.str(), returns=types.str())
+    def f(x):
+        return x
+    def g():
+        return f('a')
+    t = TranslationContext()
+    a = t.buildannotator()
+    a.annotate_helper(g, [])
+    assert a.bindings[graphof(t, f).startblock.inputargs[0]] == model.SomeString()
+    assert a.bindings[graphof(t, f).getreturnvar()] == model.SomeString()
+
+
 def test_signature_errors():
     @signature(types.int(), types.str(), returns=types.int())
     def f(a, b):