Commits

Shashank Bharadwaj committed 0413840 Draft

adding the optimized for loop example too

  • Participants
  • Parent commits bd82dcd
  • Branches cc-change

Comments (0)

Files changed (6)

File src/org/python/antlr/ast/OptimizedFor.java

+/* Generated file, do not modify.  See jython/src/templates/typed/README.txt */
+
+package org.python.antlr.ast;
+
+import org.antlr.runtime.Token;
+import org.python.antlr.AST;
+import org.python.antlr.PythonTree;
+import org.python.antlr.base.expr;
+import org.python.antlr.base.stmt;
+import org.python.core.PyObject;
+import org.python.core.PyType;
+import org.python.expose.ExposedGet;
+import org.python.expose.ExposedType;
+
+@ExposedType(name = "_ast.OptimizedFor", base = AST.class)
+public class OptimizedFor extends For {
+
+    public static final PyType TYPE = PyType.fromClass(OptimizedFor.class);
+
+    public OptimizedFor(PyType subType) {
+        super(subType);
+    }
+
+    public OptimizedFor() {
+        super();
+    }
+
+    public OptimizedFor(PyObject target, PyObject iter, PyObject body, PyObject orelse) {
+        super(target, iter, body, orelse);
+    }
+
+    public OptimizedFor(Token token, expr target, expr iter, java.util.List<stmt> body, java.util.List<stmt>
+    orelse) {
+        super(token, target, iter, body, orelse);
+    }
+
+    public OptimizedFor(Integer ttype, Token token, expr target, expr iter, java.util.List<stmt> body,
+    java.util.List<stmt> orelse) {
+        super(ttype, token, target, iter, body, orelse);
+    }
+
+    public OptimizedFor(PythonTree tree, expr target, expr iter, java.util.List<stmt> body,
+    java.util.List<stmt> orelse) {
+        super(tree, target, iter, body, orelse);
+    }
+
+    @ExposedGet(name = "repr")
+    public String toString() {
+        return "OptimizedFor";
+    }
+
+    public <R> R accept(VisitorIF<R> visitor) throws Exception {
+        return visitor.visitOptimizedFor(this);
+    }
+
+}

File src/org/python/antlr/ast/VisitorBase.java

         return ret;
     }
 
+    public R visitOptimizedFor(OptimizedFor node) throws Exception {
+        R ret = unhandled_node(node);
+        traverse(node);
+        return ret;
+    }
+
     abstract protected R unhandled_node(PythonTree node) throws Exception;
     abstract public void traverse(PythonTree node) throws Exception;
 }

File src/org/python/antlr/ast/VisitorIF.java

     public R visitYieldRestoreLocals(YieldRestoreLocals node) throws Exception;
     public R visitYieldInput(YieldInput node) throws Exception;
 	public R visitFinallyGoto(FinallyGoto node) throws Exception;
+    public R visitOptimizedFor(OptimizedFor node) throws Exception;
 }

File src/org/python/compiler/OptimizedCodeCompiler.java

 import java.util.ArrayList;
 
 import org.objectweb.asm.Handle;
+import org.objectweb.asm.Label;
 import org.objectweb.asm.Opcodes;
 import org.python.antlr.ast.Call;
+import org.python.antlr.ast.For;
 import org.python.antlr.ast.Name;
+import org.python.antlr.ast.OptimizedFor;
 import org.python.antlr.ast.expr_contextType;
 import org.python.antlr.base.expr;
+import org.python.core.Py;
+import org.python.core.PyException;
+import org.python.core.PyInteger;
+import org.python.core.PyObject;
 import org.python.core.ThreadState;
+import org.python.core.opt.IndyOptimizerException;
 import org.python.core.opt.MethodHandleHelper;
+import org.python.core.opt.RangeOptimizer;
 import org.python.core.opt.SpecializeCallSite;
 
 /**
         return super.visitCall(node);
     }
     
-//    @Override
-//    public Object visitName(Name node) throws Exception {
-//        String name;
-//        if (fast_locals) {
-//            name = node.getInternalId();
-//        } else {
-//            name = getName(node.getInternalId());
-//        }
-//        SymInfo syminf = tbl.get(name);
-//        
-//        expr_contextType ctx = node.getInternalCtx();
-//
-//        if (ctx == expr_contextType.AugStore) {
-//            ctx = augmode;
-//        }
-//        
-//        switch(ctx){
-//            case Load: {
-//                if (my_scope != null && my_scope.ac != null && !my_scope.ac.arglist
-//                        && !my_scope.ac.keywordlist) {
-//                    int argcount = my_scope.ac.names.size();
-////                    if (((syminf.flags & ScopeInfo.BOUND) != 0) && ((syminf.flags & ScopeInfo.FROM_PARAM) != 0) &&
-//                    if ((syminf.flags == 13) &&
-//                            (syminf.locals_index < argcount) && (argcount <= MethodHandleHelper.MAX_ARGUMENT_ARITY)) {
-//                        int i = 0;
-//                        for (i = 0; i < argcount; i++) {
-//                            if (my_scope.ac.names.get(i).equals(name))
-//                                break;
-//                        }
-//                        if (i < argcount) {
-////                            Py.writeError(TYPE, "loading " + name + " from stack: " + syminf.locals_index);
-//                            code.aload(syminf.locals_index + VAR_OFFSET); 
-//                            return null;
-//                        }
-//                    }
-//                }
-//                break;
-//            }
-//            case Store: {
-//                if (my_scope != null && my_scope.ac != null && !my_scope.ac.arglist
-//                        && !my_scope.ac.keywordlist) {
-//                    int argcount = my_scope.ac.names.size();
-////                    if (((syminf.flags & ScopeInfo.BOUND) != 0) && ((syminf.flags & ScopeInfo.FROM_PARAM) != 0) && 
-//                    if ((syminf.flags == 13) &&
-//                            (syminf.locals_index < argcount) && (argcount <= MethodHandleHelper.MAX_ARGUMENT_ARITY)) {
-//                        int i = 0;
-//                        for (i = 0; i < argcount; i++) {
-//                            if (my_scope.ac.names.get(i).equals(name))
-//                                break;
-//                        }
-//                        if (i < argcount) {
-////                            Py.writeError(TYPE, "storing " + name + " to stack: " + syminf.locals_index);
-//                            code.aload(temporary);
-//                            code.astore(syminf.locals_index + VAR_OFFSET); 
-//                            return null;
-//                        }
-//                    }
-//                }
-//                break;
-//            }
-//        }
-//        return super.visitName(node);
-//    }
+    @Override
+    public Object visitName(Name node) throws Exception {
+        String name;
+        if (fast_locals) {
+            name = node.getInternalId();
+        } else {
+            name = getName(node.getInternalId());
+        }
+        SymInfo syminf = tbl.get(name);
+
+        expr_contextType ctx = node.getInternalCtx();
+
+        if (ctx == expr_contextType.AugStore) {
+            ctx = augmode;
+        }
+        if (my_scope != null && my_scope.ac != null) {
+            int argcount = my_scope.ac.names.size();
+            if (fast_locals && !my_scope.ac.arglist && !my_scope.ac.keywordlist
+                    && (syminf.flags == 13) && syminf.locals_index < argcount
+                    && (argcount < MethodHandleHelper.MAX_ARGUMENT_ARITY)) {
+                switch (ctx) {
+                case Load: {
+                    // if (((syminf.flags & ScopeInfo.BOUND) != 0) &&
+                    // ((syminf.flags & ScopeInfo.FROM_PARAM) != 0) &&
+                    int i = 0;
+                    for (i = 0; i < argcount; i++) {
+                        if (my_scope.ac.names.get(i).equals(name))
+                            break;
+                    }
+                    if (i < argcount) {
+                        // Py.writeError(TYPE, "loading " + name +
+                        // " from stack: " + syminf.locals_index);
+                        code.aload(syminf.locals_index + VAR_OFFSET);
+                        return null;
+                    }
+                    break;
+                }
+                case Store: {
+                    // if (((syminf.flags & ScopeInfo.BOUND) != 0) &&
+                    // ((syminf.flags & ScopeInfo.FROM_PARAM) != 0) &&
+                    int i = 0;
+                    for (i = 0; i < argcount; i++) {
+                        if (my_scope.ac.names.get(i).equals(name))
+                            break;
+                    }
+                    if (i < argcount) {
+                        // Py.writeError(TYPE, "storing " + name + " to stack: "
+                        // + syminf.locals_index);
+                        code.aload(temporary);
+                        code.astore(syminf.locals_index + VAR_OFFSET);
+                        return null;
+                    }
+                    break;
+                }
+                }
+            }
+        }
+        return super.visitName(node);
+    }
     
     
-//    @Override
-//    public Object visitOptimizedFor(OptimizedFor node) throws Exception {
-//        java.util.List<expr> values = null;
-//        expr internalFunc = null;
-//        values = ((Call)node.getInternalIter()).getInternalArgs();
-//        internalFunc = ((Call)node.getInternalIter()).getInternalFunc();
-//        int start_value = code.getLocal(p(int.class));
-//        int stop_value = code.getLocal(p(int.class));
-//        int step_value = code.getLocal(p(int.class));
-//        int down_count_flag = code.getLocal(p(int.class));
-//        switch(values.size()){
-//            case 1:
-//                visit(values.get(0));
-//                code.iconst_0();
-//                code.istore(start_value);
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(stop_value);
-//                code.iconst_1();
-//                code.istore(step_value);
-//                code.iconst_0();
-//                code.istore(down_count_flag);
-//                break;
-//            case 2:
-//                visit(values.get(0));
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(start_value);
-//                visit(values.get(1));
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(stop_value);
-//                code.iconst_1();
-//                code.istore(step_value);
-//                code.iconst_0();
-//                code.istore(down_count_flag);
-//                break;
-//            case 3:
-//                visit(values.get(0));
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(start_value);
-//                visit(values.get(1));
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(stop_value);
-//                visit(values.get(2));
-//                // stackProduce();
-//                // stackConsume();
-//                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
-//                code.istore(step_value);
-//                // Since a step is provided, we have to check for down count
-//                // Also we need to throw an error in case step size is 0
-//                Label stepNotZero = new Label();
-//                Label greaterThanZero = new Label();
-//                Label step_end = new Label();
-//                code.iconst_0();
-//                code.iload(step_value);
-//                code.if_icmpne(stepNotZero);
-//                // step == 0: throw error
-//                code.ldc("[x]range() step argument must not be zero");
-//                code.invokestatic(p(Py.class), "ValueError", sig(PyException.class, String.class));
-//                code.athrow();
-//                // step != 0
-//                code.label(stepNotZero);
-//                code.iload(step_value);
-//                code.iconst_0();
-//                code.if_icmpgt(greaterThanZero);
-//                // step < 0, start down count
-//                code.iconst_1();
-//                code.istore(down_count_flag);
-//                code.goto_(step_end);
-//                // step > 0, normal up count
-//                code.label(greaterThanZero);
-//                code.iconst_0();
-//                code.istore(down_count_flag);
-//                code.label(step_end);
-//                break;
-//            default:
-//                break;
-//        }
-//        // Try-catch block for guarding
-//        Label start = new Label();
-//        Label end = new Label();
-//        Label handler_start = new Label();
-//        Label handler_end = new Label();
-//        code.trycatch(start, end, handler_start, p(IndyOptimizerException.class));
-//        code.label(start);
-//        // Do the try-stuff
-//        visit(internalFunc);
-//        // stackProduce();
-//        // stackConsume();
-//        code.invokedynamic("xrangeTarget", sig(Void.TYPE, PyObject.class), 
-//                           RangeOptimizer.className, "xrangeBsm", Opcodes.H_INVOKESTATIC);
-//        
-//        optimizedForLoop(node, start_value, stop_value, step_value, down_count_flag);
-//        code.label(end);
-//        code.goto_(handler_end);
-//        code.label(handler_start);
-//        // Catch block
-//        // TODO: call the pbcvm here
-//        code.new_(p(IndyOptimizerException.class));
-//        code.ldc("Range function Exception!");
-//        code.invokespecial(p(IndyOptimizerException.class), "<init>", sig(Void.TYPE, String.class));
-//        code.athrow();
-//        code.label(handler_end);
-//        code.freeLocal(start_value);
-//        code.freeLocal(stop_value);
-//        code.freeLocal(step_value);
-//        return null;
-//    }
-//
-//    private void optimizedForLoop(For forNode,
-//                                  int start_value,
-//                                  int stop_value,
-//                                  int step_value,
-//                                  int down_count_flag) throws Exception {
-//        // Now the optimized code path for the for loop begins
-//        int savebcf = beginLoop();
-//        Label continue_loop = continueLabels.peek();
-//        Label break_loop = breakLabels.peek();
-//        Label start_loop = new Label();
-//        Label next_loop = new Label();
-//        Label down_count_label = new Label();
-//        Label loop_end = new Label();
-//        setline(forNode);
-//        // reuse the start_value variable as the iter_tmp variable
-//        int iter_tmp = start_value;
-//        int expr_tmp = code.getLocal(p(PyObject.class));
-//        // set up the loop iterator
-//        code.iload(start_value);
-//        code.istore(iter_tmp);
-//        // do check at end of loop. Saves one opcode ;-)
-//        code.goto_(next_loop);
-//        code.label(start_loop);
-//        // set iter variable to current entry in list
-//        set(forNode.getInternalTarget(), expr_tmp);
-//        // evaluate for body
-//        suite(forNode.getInternalBody());
-//        code.label(continue_loop);
-//        setline(forNode);
-//        code.iload(iter_tmp);
-//        code.iload(step_value);
-//        code.iadd();
-//        code.istore(iter_tmp);
-//        code.label(next_loop);
-//        // make the element available in python
-//        code.iload(iter_tmp);
-//        code.invokestatic(p(Py.class), "newInteger", sig(PyInteger.class, int.class));
-//        code.astore(expr_tmp);
-//        // down counting?
-//        code.iload(down_count_flag);
-//        code.iconst_1();
-//        code.if_icmpeq(down_count_label);
-//        // now check if we should go back into the loop
-//        code.iload(iter_tmp);
-//        code.iload(stop_value);
-//        code.if_icmplt(start_loop);
-//        code.goto_(loop_end);
-//        // We are down counting
-//        code.label(down_count_label);
-//        code.iload(iter_tmp);
-//        code.iload(stop_value);
-//        code.if_icmpgt(start_loop);
-//        code.label(loop_end);
-//        finishLoop(savebcf);
-//        if (forNode.getInternalOrelse() != null) {
-//            // Do else clause if provided
-//            suite(forNode.getInternalOrelse());
-//        }
-//        code.label(break_loop);
-//        code.freeLocal(expr_tmp);
-//    }
+    @Override
+    public Object visitOptimizedFor(OptimizedFor node) throws Exception {
+        java.util.List<expr> values = null;
+        expr internalFunc = null;
+        values = ((Call)node.getInternalIter()).getInternalArgs();
+        internalFunc = ((Call)node.getInternalIter()).getInternalFunc();
+        int start_value = code.getLocal(p(int.class));
+        int stop_value = code.getLocal(p(int.class));
+        int step_value = code.getLocal(p(int.class));
+        int down_count_flag = code.getLocal(p(int.class));
+        switch(values.size()){
+            case 1:
+                visit(values.get(0));
+                code.iconst_0();
+                code.istore(start_value);
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(stop_value);
+                code.iconst_1();
+                code.istore(step_value);
+                code.iconst_0();
+                code.istore(down_count_flag);
+                break;
+            case 2:
+                visit(values.get(0));
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(start_value);
+                visit(values.get(1));
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(stop_value);
+                code.iconst_1();
+                code.istore(step_value);
+                code.iconst_0();
+                code.istore(down_count_flag);
+                break;
+            case 3:
+                visit(values.get(0));
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(start_value);
+                visit(values.get(1));
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(stop_value);
+                visit(values.get(2));
+                // stackProduce();
+                // stackConsume();
+                code.invokevirtual(p(PyObject.class), "asInt", sig(int.class));
+                code.istore(step_value);
+                // Since a step is provided, we have to check for down count
+                // Also we need to throw an error in case step size is 0
+                Label stepNotZero = new Label();
+                Label greaterThanZero = new Label();
+                Label step_end = new Label();
+                code.iconst_0();
+                code.iload(step_value);
+                code.if_icmpne(stepNotZero);
+                // step == 0: throw error
+                code.ldc("[x]range() step argument must not be zero");
+                code.invokestatic(p(Py.class), "ValueError", sig(PyException.class, String.class));
+                code.athrow();
+                // step != 0
+                code.label(stepNotZero);
+                code.iload(step_value);
+                code.iconst_0();
+                code.if_icmpgt(greaterThanZero);
+                // step < 0, start down count
+                code.iconst_1();
+                code.istore(down_count_flag);
+                code.goto_(step_end);
+                // step > 0, normal up count
+                code.label(greaterThanZero);
+                code.iconst_0();
+                code.istore(down_count_flag);
+                code.label(step_end);
+                break;
+            default:
+                break;
+        }
+        // Try-catch block for guarding
+        Label start = new Label();
+        Label end = new Label();
+        Label handler_start = new Label();
+        Label handler_end = new Label();
+        code.trycatch(start, end, handler_start, p(IndyOptimizerException.class));
+        code.label(start);
+        // Do the try-stuff
+        visit(internalFunc);
+        // stackProduce();
+        // stackConsume();
+        code.invokedynamic("xrangeTarget", sig(Void.TYPE, PyObject.class), 
+                           RangeOptimizer.className, "xrangeBsm", Opcodes.H_INVOKESTATIC);
+        
+        optimizedForLoop(node, start_value, stop_value, step_value, down_count_flag);
+        code.label(end);
+        code.goto_(handler_end);
+        code.label(handler_start);
+        // Catch block
+        // TODO: call the pbcvm here
+        code.new_(p(IndyOptimizerException.class));
+        code.ldc("Range function Exception!");
+        code.invokespecial(p(IndyOptimizerException.class), "<init>", sig(Void.TYPE, String.class));
+        code.athrow();
+        code.label(handler_end);
+        code.freeLocal(start_value);
+        code.freeLocal(stop_value);
+        code.freeLocal(step_value);
+        return null;
+    }
+
+    private void optimizedForLoop(For forNode,
+                                  int start_value,
+                                  int stop_value,
+                                  int step_value,
+                                  int down_count_flag) throws Exception {
+        // Now the optimized code path for the for loop begins
+        int savebcf = beginLoop();
+        Label continue_loop = continueLabels.peek();
+        Label break_loop = breakLabels.peek();
+        Label start_loop = new Label();
+        Label next_loop = new Label();
+        Label down_count_label = new Label();
+        Label loop_end = new Label();
+        setline(forNode);
+        // reuse the start_value variable as the iter_tmp variable
+        int iter_tmp = start_value;
+        int expr_tmp = code.getLocal(p(PyObject.class));
+        // set up the loop iterator
+        code.iload(start_value);
+        code.istore(iter_tmp);
+        // do check at end of loop. Saves one opcode ;-)
+        code.goto_(next_loop);
+        code.label(start_loop);
+        // set iter variable to current entry in list
+        set(forNode.getInternalTarget(), expr_tmp);
+        // evaluate for body
+        suite(forNode.getInternalBody());
+        code.label(continue_loop);
+        setline(forNode);
+        code.iload(iter_tmp);
+        code.iload(step_value);
+        code.iadd();
+        code.istore(iter_tmp);
+        code.label(next_loop);
+        // make the element available in python
+        code.iload(iter_tmp);
+        code.invokestatic(p(Py.class), "newInteger", sig(PyInteger.class, int.class));
+        code.astore(expr_tmp);
+        // down counting?
+        code.iload(down_count_flag);
+        code.iconst_1();
+        code.if_icmpeq(down_count_label);
+        // now check if we should go back into the loop
+        code.iload(iter_tmp);
+        code.iload(stop_value);
+        code.if_icmplt(start_loop);
+        code.goto_(loop_end);
+        // We are down counting
+        code.label(down_count_label);
+        code.iload(iter_tmp);
+        code.iload(stop_value);
+        code.if_icmpgt(start_loop);
+        code.label(loop_end);
+        finishLoop(savebcf);
+        if (forNode.getInternalOrelse() != null) {
+            // Do else clause if provided
+            suite(forNode.getInternalOrelse());
+        }
+        code.label(break_loop);
+        code.freeLocal(expr_tmp);
+    }
 }

File src/org/python/compiler/cfg/IRBuilder.java

 import org.python.antlr.ast.HandlerStart;
 import org.python.antlr.ast.If;
 import org.python.antlr.ast.Name;
+import org.python.antlr.ast.OptimizedFor;
 import org.python.antlr.ast.Pass;
 import org.python.antlr.ast.Return;
 import org.python.antlr.ast.Suite;
 
     private int uniqueWith = 0;
 
+    private boolean hasYield = false;
+
     private void beginBlock(HandlerStart handler) {
         data = new CFGBuilderData();
         if (handler != null) {
         expr target = (expr) visitNotNull(node.getInternalTarget());
 
         int savebcf = beginLoop();
+        this.hasYield = false;
         List<stmt> body = suite(node.getInternalBody());
         finishLoop(savebcf);
 
         List<stmt> orelse = suite(node.getInternalOrelse());
+        if (node.getInternalIter() instanceof Call) {
+            if (((Call)node.getInternalIter()).getInternalFunc() instanceof Name) {
+                Name name = (Name)((Call)node.getInternalIter()).getInternalFunc();
+                if (name.getInternalId().equals("range") || name.getInternalId().equals("xrange")) {
+                    if (!hasYield) {
+                        return new OptimizedFor(node.getToken(), target, iter, body, orelse);
+                    }
+                }
+            }
+        }
+
         return new For(node.getToken(), target, iter, body, orelse);
     }
 
 
     @Override
     public Object visitYield(Yield node) throws Exception {
+        this.hasYield = true;
         data.stmts.add(new YieldReturn(node));
         boolean endCurrentBlock = data.handler != null;
         endAllExceptions(endCurrentBlock, true);

File src/org/python/compiler/cfg/TransitiveVisitor.java

 import org.python.antlr.ast.Module;
 import org.python.antlr.ast.Name;
 import org.python.antlr.ast.Num;
+import org.python.antlr.ast.OptimizedFor;
 import org.python.antlr.ast.Pass;
 import org.python.antlr.ast.Print;
 import org.python.antlr.ast.Raise;
     }
 
     @Override
+    public Object visitOptimizedFor(OptimizedFor node) throws Exception {
+        List<stmt> body = suite(node.getInternalBody());
+        expr iter = (expr) visitNotNull(node.getInternalIter());
+        expr target = (expr) visitNotNull(node.getInternalTarget());
+        List<stmt> orelse = suite(node.getInternalOrelse());
+        OptimizedFor new_node = new OptimizedFor(node.getToken(), target, iter, body, orelse);
+
+        return new_node;
+    }
+
+    @Override
     public Object visitFunctionDef(FunctionDef node) throws Exception {
         arguments args = node.getInternalArgs();
         List<stmt> body = suite(node.getInternalBody());