Commits

Kristian Ølgaard  committed c10c577

Fix issue 8 regarding nondeterministic code generation for quadrature optimisations.

  • Participants
  • Parent commits 8de7612

Comments (0)

Files changed (3)

File ffc/quadrature/optimisedquadraturetransformer.py

             return {(): create_product([val]*expo.value())}
         elif isinstance(expo, FloatValue):
             exp = format["floating point"](expo.value())
-#            sym = create_symbol(format["std power"](str(val), exp), val.t)
-#            sym.base_expr = val
-#            sym.base_op = 1 # Add one operation for the pow() function.
-            sym = create_symbol(format["std power"], val.t, val, 1, exp)
+            sym = create_symbol(format["std power"](str(val), exp), val.t, val, 1)
             return {(): sym}
         elif isinstance(expo, (Coefficient, Operator)):
             exp = self.visit(expo)[()]
-#            sym = create_symbol(format["std power"](str(val), exp[()]), val.t)
-#            sym.base_expr = val
-#            sym.base_op = 1 # Add one operation for the pow() function.
-            sym = create_symbol(format["std power"], val.t, val, 1, exp)
+#            print "pow exp: ", exp
+#            print "pow val: ", val
+            sym = create_symbol(format["std power"](str(val), exp), val.t, val, 1)
             return {(): sym}
         else:
             error("power does not support this exponent: " + repr(expo))
 
         # Take absolute value of operand.
         val = operands[0][()]
-#        new_val = create_symbol(format["absolute value"](str(val)), val.t)
-#        new_val.base_expr = val
-#        new_val.base_op = 1 # Add one operation for taking the absolute value.
-        new_val = create_symbol(format["absolute value"], val.t, val, 1)
+        new_val = create_symbol(format["absolute value"](str(val)), val.t, val, 1)
         return {():new_val}
 
     # -------------------------------------------------------------------------
         c, = operands
         ffc_assert(len(c) == 1 and c.keys()[0] == (),\
             "Condition for NotCondition should only be one function: " + repr(c))
-        sym = create_symbol("", c[()].t, cond=(c[()], format["not"]))
+        sym = create_symbol(format["not"](str(c[()])), c[()].t, base_op=c[()].ops()+1)
         return {(): sym}
 
     def binary_condition(self, o, *operands):
 
         # Get the minimum type
         t = min(lhs[()].t, rhs[()].t)
-        sym = create_symbol("", t, cond=(lhs[()], format[name_map[o._name]], rhs[()]))
+        ops = lhs[()].ops() + rhs[()].ops() + 1
+        cond = str(lhs[()])+format[name_map[o._name]]+str(rhs[()])
+        sym = create_symbol(format["grouping"](cond), t, base_op=ops)
         return {(): sym}
 
     def conditional(self, o, *operands):
         # Use format function on value of operand.
         operand = operands[0]
         for key, val in operand.items():
-#            new_val = create_symbol(format_function(str(val)), val.t)
-#            new_val.base_expr = val
-#            new_val.base_op = 1 # Add one operation for the math function.
-            new_val = create_symbol(format_function, val.t, val, 1)
+            new_val = create_symbol(format_function(str(val)), val.t, val, 1)
             operand[key] = new_val
         return operand
 
         if x is None:
             x = format["floating point"](0.0)
 
-        sym = create_symbol(format_function, x.t, x, 1, nu)
+        sym = create_symbol(format_function(x,nu), x.t, x, 1)
         return {():sym}
 
     # -------------------------------------------------------------------------

File ffc/quadrature/symbol.py

 
 class Symbol(Expr):
     __slots__ = ("v", "base_expr", "base_op", "exp", "cond")
-    def __init__(self, variable, symbol_type, base_expr=None, base_op=0, expo=None, cond=()):
+    def __init__(self, variable, symbol_type, base_expr=None, base_op=0):
         """Initialise a Symbols object, it derives from Expr and contains
         the additional variables:
 
         # ops = base_expr.ops() + base_ops = 2 + 1 = 3
         self.base_expr = base_expr
         self.base_op = base_op
-        self.exp = expo
-        self.cond = cond
 
         # If type of the base_expr is lower than the given symbol_type change type.
         # TODO: Should we raise an error here? Or simply require that one
         # Compute the representation now, such that we can use it directly
         # in the __eq__ and __ne__ methods (improves performance a bit, but
         # only when objects are cached).
-        if self.base_expr and self.exp is None:
+        if self.base_expr:# and self.exp is None:
             self._repr = "Symbol('%s', %s, %s, %d)" % (self.v, type_to_string[self.t],\
                          self.base_expr._repr, self.base_op)
-        elif self.base_expr:
-            self._repr = "Symbol('%s', %s, %s, %d, %s)" % (self.v, type_to_string[self.t],\
-                         self.base_expr._repr, self.base_op, self.exp)
-        elif self.cond:
-            self._repr = "Symbol('%s', %s, %s, %d, %s, %s)" % (self.v, type_to_string[self.t],\
-                          self.base_expr, self.base_op, self.exp, self.cond)
         else:
             self._repr = "Symbol('%s', %s)" % (self.v, type_to_string[self.t])
 
     # Print functions.
     def __str__(self):
         "Simple string representation which will appear in the generated code."
-        if self.base_expr is None:
-            if self.cond == ():
-                return self.v
-            else:
-                if len(self.cond) == 2:
-                    return self.cond[1](str(self.cond[0]))
-                return format["grouping"]("".join([str(c) for c in self.cond]))
-        elif self.exp is None:
-            return self.v(str(self.base_expr))
-        return self.v(str(self.base_expr), self.exp)
+#        print "sym str: ", self.v
+        return self.v
 
     # Binary operators.
     def __add__(self, other):
         # for the base (sin(2*x + 1)) --> 2 + 1.
         if self.base_expr:
             return self.base_op + self.base_expr.ops()
-        elif self.cond:
-            if len(self.cond) == 2:
-                return self.base_op + self.cond[0].ops() + 1
-            return self.base_op + self.cond[0].ops() + self.cond[2].ops() + 1
         return self.base_op
 
 from floatvalue import FloatValue

File ffc/quadrature/symbolics.py

     return float_val
 
 _symbol_cache = {}
-def create_symbol(variable, symbol_type, base_expr=None, base_op=0, expo=None, cond=()):
-    key = (variable, symbol_type, base_expr, base_op, expo, cond)
+def create_symbol(variable, symbol_type, base_expr=None, base_op=0):
+    key = (variable, symbol_type, base_expr, base_op)
     if key in _symbol_cache:
 #        print "found %s in cache" %variable
         return _symbol_cache[key]
-    symbol = Symbol(variable, symbol_type, base_expr, base_op, expo, cond)
+    symbol = Symbol(variable, symbol_type, base_expr, base_op)
     _symbol_cache[key] = symbol
     return symbol