Commits

Anteru  committed 481580b

Add a new debug printer for types.

  • Participants
  • Parent commits 4dc37bb

Comments (0)

Files changed (5)

File nsl/Compiler.py

 from nsl.parser import NslParser
+from nsl.passes import ComputeTypes, ValidateSwizzle, ValidateFlowStatements, DebugAst, DebugTypes, PrettyPrint
 
 class Compiler:
     def __init__(self):
-        import nsl.passes.ComputeTypes, nsl.passes.ValidateFlowStatements, nsl.passes.DebugPrint, nsl.passes.ValidateSwizzle, nsl.passes.PrettyPrint
         self.parser = NslParser ()
 
-        self.passes = [nsl.passes.ComputeTypes.GetPass(),
-                  nsl.passes.ValidateFlowStatements.GetPass (),
-                  nsl.passes.ValidateSwizzle.GetPass (),
-                  nsl.passes.DebugPrint.GetPass (),
-                  nsl.passes.PrettyPrint.GetPass ()]
+        self.passes = [ComputeTypes.GetPass(),
+                  ValidateFlowStatements.GetPass (),
+                  ValidateSwizzle.GetPass (),
+                  DebugAst.GetPass (),
+                  DebugTypes.GetPass (),
+                  PrettyPrint.GetPass ()]
 
     def Compile (self, source, debugParsing = False):
         ast = self.parser.Parse (source, debug = debugParsing)

File nsl/passes/ComputeTypes.py

                                              expr.GetLeft ().type,
                                              expr.GetRight ().type)
 
+        return expr.type
+
     def v_IfStatement(self, stmt, ctx):
         self._ProcessExpression(stmt.GetCondition(), ctx[-1])
         self.v_Visit (stmt.GetTruePath(), ctx)
                                         scope)
 
     def v_ExpressionStatement(self, stmt, ctx):
-        self._ProcessExpression(stmt.GetExpression(), ctx[-1])
+        self.type = self._ProcessExpression(stmt.GetExpression(), ctx[-1])
 
     def v_ReturnStatement(self, stmt, ctx):
-        self._ProcessExpression(stmt.GetExpression(), ctx[-1])
+        self.type = self._ProcessExpression(stmt.GetExpression(), ctx[-1])
 
     def v_Function(self, func, ctx):
         '''Computes the function type and processes all statements.'''

File nsl/passes/DebugPrint.py

-from nsl import ast
-
-class DebugVisitor(ast.Visitor):
-    def GetContext (self):
-        return 0
-
-    def v_Generic(self, obj, ctx=None):
-        ast.Visitor.v_Generic (self, obj, ctx)
-        if hasattr (obj, 'Traverse'):
-            obj.Traverse (self, ctx + 1)
-
-    def v_Default(self, obj, ctx):
-        print (' '*ctx*2, obj.__class__.__name__)
-        print (' '*(ctx*2 + 4), str (obj))
-
-def GetPass():
-    import nsl.Pass
-    return nsl.Pass.MakePassFromVisitor (DebugVisitor (), 'debug-print')

File nsl/passes/DebugTypes.py

+from nsl import ast
+
+class DebugTypeVisitor(ast.DefaultVisitor):
+	def GetContext (self):
+		return 0
+
+	def _p(self, ctx, s, **args):
+		print (' ' * (ctx * 4), end = '')
+		print (s, **args)
+
+	def v_StructureDefinition(self, decl, ctx):
+		self._p (ctx, 'struct ' + decl.GetName ())
+		for t in decl.GetElements ():
+			# Resolve here allows for nested types
+			self._p (ctx + 1, t.GetName () + ':' + str(t.GetType ()))
+		print()
+
+	def _ProcessExpression(self, expr, ctx):
+		self._p (ctx, str(expr) + ':' + str(expr.type))
+		for e in expr:
+			self._ProcessExpression (e, ctx + 1)
+
+	def v_CompoundStatement(self, stmt, ctx):
+		for s in stmt:
+			self.v_Visit (s, ctx + 1)
+
+	def v_PrimaryExpression(self, expr, ctx):
+		self._p(ctx, str(expr) + ':' + str(expr.type))
+
+	def v_ExpressionStatement(self, stmt, ctx):
+		self._ProcessExpression(stmt.GetExpression(), ctx)
+
+	def v_Function(self, func, ctx):
+		'''Computes the function type and processes all statements.'''
+		self._p(ctx, str(func.GetType ()))
+		self._p (ctx, 'Arguments')
+		for (name, type) in func.GetType ().GetArguments().items ():
+			self._p (ctx + 1, name + ':' + str(type))
+
+		print ()
+		self.v_Visit (func.GetBody(), ctx)
+		print ()
+
+	def v_Shader(self, shd, ctx=None):
+		self.v_Function(shd, ctx)
+
+	def v_Program(self, prog, ctx):
+		# Must visit types first
+		for type in prog.GetTypes ():
+			self.v_Visit (type, ctx)
+		for decl in prog.GetDeclarations ():
+			self.v_Visit (decl, ctx)
+		for func in prog.GetFunctions ():
+			self.v_Visit (func, ctx)
+
+	def v_Generic(self, node, ctx):
+		ast.Visitor.v_Generic (self, node, ctx)
+
+def GetPass():
+	import nsl.Pass
+	return nsl.Pass.MakePassFromVisitor (DebugTypeVisitor (), 'debug-print-types')

File nsl/types.py

         self.name = name
         for (name, type) in declarations.items ():
             self.members.RegisterVariable(name, type)
+        self.declarations = declarations
 
     def __str__(self):
         return 'struct {}'.format(self.name)
         return '@{}->{}`{}'.format (self.name, str(self.returnType),
                                     ','.join ([str(arg) for arg in self.arguments.values ()]))
 
+    def __str__(self):
+        return 'function {0} ({1}) -> {2}'.format (self.name,
+            ', '.join(self.arguments.keys ()), self.returnType)
+
     def __repr__(self):
         return 'Function (\'{}\', {}, [{}])'.format (self.name, repr(self.returnType),
                                                  ', '.join ([repr(arg) for arg in self.arguments.keys ()]))