Martin Alnæs committed cc7e914

Improve ordering of arguments in factorized graph and subsequent loop invariant code motion.

Comments (0)

Files changed (2)


 from ufl.classes import (Terminal, Grad, Indexed, FixedIndex,
                          Restricted, PositiveRestricted, NegativeRestricted,
+                         FacetAvg, CellAvg,
                          Coefficient, Argument)
 from ufl.sorting import sorted_expr
 from uflacs.utils.log import uflacs_assert, warning, error
-# TODO: Add FacetAvg and CellAvg to modifiers
+# TODO: Add FacetAvg and CellAvg to modifiers everywhere relevant and handle in table extraction
 # TODO: Make this more robust by looping like analyse_modified_terminal, currently assumes that transformations have been applied.
 def is_modified_terminal(v):
-    return (isinstance(v, (Terminal, Grad, Restricted))
-            or (isinstance(v, Indexed) and isinstance(v.operands()[0], (Terminal, Grad, Restricted))))
+    _accepted_types = (Terminal, Grad, Restricted, FacetAvg, CellAvg)
+    return (isinstance(v, _accepted_types)
+            or (isinstance(v, Indexed) and isinstance(v.operands()[0], _accepted_types)))
-terminal_modifier_types = (Grad, Restricted, Indexed)
+terminal_modifier_types = (Grad, Restricted, Indexed, FacetAvg, CellAvg)
 def analyse_modified_terminal(o, form_argument_mapping={}):
     """Analyse a so-called 'modified terminal' expression and return its properties in more compact form.
     component = None
     derivatives = []
     r = None
+    a = None
     while not isinstance(t, Terminal):
         if not isinstance(t, terminal_modifier_types):
             error("Unexpected type %s object %s." % (type(t), repr(t)))
             r = t._side
             t, = t.operands()
+        elif isinstance(t, CellAvg):
+            uflacs_assert(a is None, "Got twice averaged terminal!")
+            a = "cell"
+            t, = t.operands()
+        elif isinstance(t, FacetAvg):
+            uflacs_assert(a is None, "Got twice averaged terminal!")
+            a = "facet"
+            t, = t.operands()
     t = form_argument_mapping.get(t,t)
     component = tuple(component) if component else ()
     derivatives = tuple(sorted(derivatives))
     uflacs_assert(all(c >= 0 and c < d for c,d in zip(component, t.shape())),
                   "Component indices %s are outside terminal shape %s" % (component, t.shape()))
-    return (t, component, derivatives, r)
+    # FIXME: Return average state, update all callers
+    return (t, component, derivatives, r) #, a)
 class DependencyHandler(object):


 from ufl import as_ufl
-from ufl.classes import Terminal, Indexed, Grad, Restricted, Argument, Product, Sum, Division
+from ufl.classes import Terminal, Indexed, Grad, Restricted, FacetAvg, CellAvg, Argument, Product, Sum, Division
 from uflacs.utils.log import uflacs_assert
 from uflacs.analysis.graph_ssa import compute_dependencies
+from uflacs.analysis.dependency_handler import analyse_modified_terminal
 def strip_modified_terminal(v):
     "Extract core Terminal from a modified terminal or return None."
     while not isinstance(v, Terminal):
-        if isinstance(v, (Indexed, Grad, Restricted)):
+        if isinstance(v, (Indexed, Grad, Restricted, FacetAvg, CellAvg)):
             v = v.operands()[0]
             return None
         arg_combos = [comb + (j,) for comb in arg_combos for j in js]
     return arg_indices, arg_combos
+def build_argument_indices_from_arg_sets(arg_sets):
+    "Build ordered list of indices to modified arguments."
+    arg_indices = set()
+    for js in arg_sets.values():
+        arg_indices.update(js)
+    return sorted(arg_indices)
+def build_argument_indices(SV):
+    "Build ordered list of indices to modified arguments."
+    arg_sets = {}
+    for i,v in enumerate(SV):
+        a = strip_modified_terminal(v)
+        if not isinstance(a, Argument):
+            continue
+        c = a.count()
+        s = arg_sets.get(c)
+        if s is None:
+            s = {}
+            arg_sets[c] = s
+        s[i] = v
+    arg_indices = set()
+    for js in arg_sets.values():
+        arg_indices.update(js)
+    def arg_ordering_key(i):
+        a = None # TODO: Include averaging state
+        (t, c, d, r) = analyse_modified_terminal(arg_ordering_key.SV[i])
+        #, form_argument_mapping={}) # FIXME: Need this? Already mapped?
+        return (t.count(), a, d, r, a)
+    arg_ordering_key.SV = SV
+    ordered_arg_indices = sorted(arg_indices, key=arg_ordering_key)
+    return ordered_arg_indices
 def build_argument_dependencies(dependencies, arg_indices):
     "Preliminary algorithm: build list of argument vertex indices each vertex (indirectly) depends on."
     n = len(dependencies)
     argkeys = sorted(IM.keys())
     fs = []
     for argkey in argkeys:
-        # Add each subproduct of this monomial
+        # Start with coefficients
         f = FV[IM[argkey]]
+        ###f = 1
+        # Add binary products with each argument in order
         for argindex in argkey:
             f = f*AV[argindex]
+        # Add product with coefficients last
+        ###f = f*FV[IM[argkey]]
+        ###add_vertex(f)
         # f is now the full monomial, store it as a term for sum below
     assert list(target_variables) == [len(SV)-1]
     arg_sets = build_argument_component_sets(SV)
-    arg_indices, valid_arg_combos = build_valid_argument_combinations(arg_sets)
+    #arg_indices, valid_arg_combos = build_valid_argument_combinations(arg_sets)
+    #arg_indices = build_argument_indices(arg_sets)
+    arg_indices = build_argument_indices(SV)
     A = build_argument_dependencies(dependencies, arg_indices)
         print '\n'.join(map(str,arg_sets.items()))
         print 'arg_indices:'
         print arg_indices
-        print 'valid_arg_combos:'
-        print valid_arg_combos
+        #print 'valid_arg_combos:'
+        #print valid_arg_combos
         print 'A:'
         print A
         print 'END DEBUGGING compute_argument_factorization'
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.