ebo avatar ebo committed c38e7fd

Added tail recursion elimination for simple cases

Comments (0)

Files changed (3)

         return set()
 
 
+class TailRecursionAnalysis(object):
+    """Marks simple tail recursions for later optimizations"""
+
+    def __init__(self):
+        pass
+
+    def analyse(self, node):
+        self.visit(node, None, True)
+
+    @qndispatch.on("node")
+    def visit(self, node, name, tail):
+        pass
+
+    @visit.when(vmast.UnaryOp)
+    def visit(self, node, name, tail):
+        self.visit(node.expr, name, False)
+
+    @visit.when(vmast.BinaryOp)
+    def visit(self, node, name, tail):
+        self.visit(node.e1, name, False)
+        self.visit(node.e2, name, False)
+
+    @visit.when(vmast.IfOp)
+    def visit(self, node, name, tail):
+        self.visit(node.test, name, False)
+        self.visit(node.true, name, tail)
+        self.visit(node.false, name, tail)
+
+    @visit.when(vmast.LetExpression)
+    def visit(self, node, name, tail):
+        self.visit(node.value, None, False)
+        self.visit(node.target, name, tail)
+
+    @visit.when(vmast.LetRecExpression)
+    def visit(self, node, name, tail):
+        for var, value in node.vars:
+            self.visit(value, var, False)
+        self.visit(node.target, name, tail)
+
+    @visit.when(vmast.FuncExpression)
+    def visit(self, node, name, tail):
+        self.visit(node.body, name, True)
+
+    @visit.when(vmast.CallExpression)
+    def visit(self, node, name, tail):
+        self.visit(node.name, name, False)
+        for arg in node.arguments:
+            self.visit(arg, name, False)
+
+        if (tail and
+            isinstance(node.name, vmast.Identifier) and
+            node.name.identifier == name):
+            node.tailrecursion = True
+
+    @visit.when(vmast.ListConstruction)
+    def visit(self, node, name, tail):
+        self.visit(node.item, name, False)
+        self.visit(node.rest, name, False)
+
+    @visit.when(vmast.ListMatch)
+    def visit(self, node, name, tail):
+        self.visit(node.empty, name, tail)
+        self.visit(node.filled, name, tail)
+
+    @visit.when(vmast.Token)
+    def visit(self, node, name, tail):
+        pass
+
+
 class TOPLetRecReorderingAnalysis(object):
 
     def getnodeps(self, defs):
         self.assert_('y' in l)
 
 
+class TestTOPOptimization(unittest.TestCase):
+    def test_tailrecursion(self):
+        func = topcompiler.compile("""
+let rec fac = fun i acc -> if i > 1 then
+        fac (i-1) (acc * i)
+    else
+        acc
+and fac2 = fun i -> fac i 1 in
+fac2""")
+        fac = func()
+        self.assertEqual(fac(4), 24)
+
+
 class TestTOPCompiler(unittest.TestCase):
     def test_compile(self):
         func = topcompiler.compile("0")
 from opcode import opmap, cmp_op
 import types
 
-from codeanalysis import CellVariablesAnalysis,\
-    FreeAnalysis, TOPLetRecReorderingAnalysis, TOPJumpCalculations
+from codeanalysis import (CellVariablesAnalysis, FreeAnalysis,
+                          TOPLetRecReorderingAnalysis, TOPJumpCalculations,
+                          TailRecursionAnalysis)
 
 PY_VERSION = None
 if sys.version_info[0] == 2:
 
     @code.when(vmast.CallExpression)
     def code(self, node, rho, kp):
+        if hasattr(node, "tailrecursion") and node.tailrecursion:
+            for i, arg in enumerate(node.arguments):
+                self.code(arg, rho, kp)
+            for i in reversed(xrange(len(node.arguments))):
+                self.dest.append(struct.pack("=BH", opmap["STORE_FAST"], i))
+            self.dest.append(struct.pack("=BH", opmap["JUMP_ABSOLUTE"], 0))
+            self.maxstack = max(self.maxstack, kp + 1)
+            return
         self.code(node.name, rho, kp)
         for i, arg in enumerate(node.arguments):
             self.code(arg, rho, kp + i + 1)
         self.dest.append(("LABEL", l2))
 
 
-def compile_code(ast, args=[], rho={}):
+def compile_code(ast, args, rho=None):
     """Creates a new CodeObject from ast"""
 
+    if rho == None:
+        rho = {}
+
     compiler = TOPCompiler()
 
     argcount = len(args)
         tuple(compiler.constants),
         tuple(compiler.names),  # names
         tuple(compiler.locals),  # varnames
-        "",  # filename
+        "<string>",  # filename
         "",  # name
         0,  # firstlineno
         "",  # lnotab
         args = code.func_code.co_varnames[:code.func_code.co_argcount]
 
     ast = plygram.parser.parse(doc)
+    TailRecursionAnalysis().analyse(ast)
+
     codeobj = compile_code(ast, args)
 
     func = types.FunctionType(codeobj, globals())
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.