Anonymous avatar Anonymous committed 72fbd38

Added cgen classes for Class and Methods

Comments (0)

Files changed (2)

 class CodeGenIndentException(CodeGenException):
     pass
 
+class NotImplementedException(Exception):
+    pass
+
 class Module(object):
     '''Outermost Code block representing a Python Module. May include a main block'''
     def __init__(self, main = False):
         self.expression = expression
 
 class Statement(object):
-    pass
+    """Generic statement. To be overridden."""
+
+    def get(self):
+        raise NotImplementedException
+
+    def fix(self):
+        raise NotImplementedException
+
+class Class(object):
+    def __init__(self, name, super = None, content = None):
+        self.name = name
+        if super:
+            self.super = super
+        else:
+            self.super = ['object']
+
+        if content:
+            self.content = content
+        else:
+            self.content = []
+
+class Method(object):
+    def __init__(self, name, args = None, content = None):
+        self.name = name
+
+        self.args = []
+        if args != None:
+            self.args = args
+
+        if content != None:
+            self.content = content
+        else:
+            self.content = []
+
 
 
 class FixGenerator(object):
 
         return func
 
+    @visit.when(Class)
+    def visit(self, node):
+        c = Class(node.name)
+        c.super = self.visit_args(node.super)
+        c.content = self.visit_block(node.content)
+
+        return c
+
+    @visit.when(Method)
+    def visit(self, node):
+        m = Method(node.name)
+        m.args = self.visit_args(node.args)
+        m.content = self.visit_block(node.content)
+        return m
+
     @visit.when(IfStatement)
     def visit(self, node):
         stmt = IfStatement(node.clause)
         content += self.visit_block(depth, node.content)
         return content
 
+    @visit.when(Class)
+    def visit(self, depth, node):
+        superclasses = ", ".join(node.super)
+        c = "".join(['class ', node.name, '(', superclasses, '):'])
+        content = [self.code(depth, c)]
+        content += self.visit_block(depth, node.content)
+
+        return content
+
+    @visit.when(Method)
+    def visit(self, depth, node):
+        if node.args:
+            args = ", ".join(["self"] + node.args)
+        else:
+            args = "self"
+
+        fun = ''.join(['def ', node.name, '(', args, '):'])
+        content = [self.code(depth, fun)]
+        content += self.visit_block(depth, node.content)
+
+        return content
+
     @visit.when(CallStatement)
     def visit(self, depth, node):
         args = ", ".join(self.visit_args(node.args))
         if(len(node.expression) > 0):
             code += map(lambda x: x.strip(), self.visit_block(0, node.expression))
         return [self.code(depth, " ".join(code))]
+
         self.assert_("StringStatement" in code)
         self.assert_("CallableStatement" in code)
 
+
+class TestClass(unittest.TestCase):
+    def setUp(self):
+        self.gen = CodeGenerator()
+        self.fix = FixGenerator()
+
+    def testEmptyClass(self):
+        n = Class('Test')
+        n = self.fix.generate(n)
+
+        code = self.gen.generate(n)
+        self.assert_("class Test" in code)
+        self.assert_("(object)" in code)
+        self.assert_("pass" in code)
+
+    def testClass(self):
+        n = Class('Test')
+        n.content.append('x = 5')
+        n = self.fix.generate(n)
+
+        code = self.gen.generate(n)
+        self.assert_("class Test" in code)
+        self.assert_("(object)" in code)
+        self.assert_("pass" not in code)
+
+    def testEmptyMethod(self):
+        n = Method('test')
+        n = self.fix.generate(n)
+
+        code = self.gen.generate(n)
+        self.assert_("def test" in code)
+        self.assert_("(self)" in code)
+        self.assert_("pass" in code)
+
+    def testEmptyMethod(self):
+        n = Method('test')
+        n.content.append('x = 5')
+        n = self.fix.generate(n)
+
+        code = self.gen.generate(n)
+        self.assert_("def test" in code)
+        self.assert_("(self)" in code)
+        self.assert_("pass" not in code)
+
+
 class TestCallStatement(unittest.TestCase):
     def setUp(self):
         self.gen = CodeGenerator()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.