Commits

Martin Alnæs committed cc7e914

Improve ordering of arguments in factorized graph and subsequent loop invariant code motion.

Comments (0)

Files changed (2)

site-packages/uflacs/analysis/dependency_handler.py

 
 from ufl.classes import (Terminal, Grad, Indexed, FixedIndex,
                          Restricted, PositiveRestricted, NegativeRestricted,
+                         FacetAvg, CellAvg,
                          Coefficient, Argument)
 from ufl.sorting import sorted_expr
 
 from uflacs.utils.log import uflacs_assert, warning, error
 
-# TODO: Add FacetAvg and CellAvg to modifiers
+# TODO: Add FacetAvg and CellAvg to modifiers everywhere relevant and handle in table extraction
 # TODO: Make this more robust by looping like analyse_modified_terminal, currently assumes that transformations have been applied.
 def is_modified_terminal(v):
-    return (isinstance(v, (Terminal, Grad, Restricted))
-            or (isinstance(v, Indexed) and isinstance(v.operands()[0], (Terminal, Grad, Restricted))))
+    _accepted_types = (Terminal, Grad, Restricted, FacetAvg, CellAvg)
+    return (isinstance(v, _accepted_types)
+            or (isinstance(v, Indexed) and isinstance(v.operands()[0], _accepted_types)))
 
-terminal_modifier_types = (Grad, Restricted, Indexed)
+terminal_modifier_types = (Grad, Restricted, Indexed, FacetAvg, CellAvg)
 def analyse_modified_terminal(o, form_argument_mapping={}):
     """Analyse a so-called 'modified terminal' expression and return its properties in more compact form.
 
     component = None
     derivatives = []
     r = None
+    a = None
     while not isinstance(t, Terminal):
         if not isinstance(t, terminal_modifier_types):
             error("Unexpected type %s object %s." % (type(t), repr(t)))
             r = t._side
             t, = t.operands()
 
+        elif isinstance(t, CellAvg):
+            uflacs_assert(a is None, "Got twice averaged terminal!")
+            a = "cell"
+            t, = t.operands()
+
+        elif isinstance(t, FacetAvg):
+            uflacs_assert(a is None, "Got twice averaged terminal!")
+            a = "facet"
+            t, = t.operands()
+
     t = form_argument_mapping.get(t,t)
     component = tuple(component) if component else ()
     derivatives = tuple(sorted(derivatives))
     uflacs_assert(all(c >= 0 and c < d for c,d in zip(component, t.shape())),
                   "Component indices %s are outside terminal shape %s" % (component, t.shape()))
 
-    return (t, component, derivatives, r)
+    # FIXME: Return average state, update all callers
+    return (t, component, derivatives, r) #, a)
 
 
 class DependencyHandler(object):

site-packages/uflacs/analysis/factorization.py

 
 from ufl import as_ufl
-from ufl.classes import Terminal, Indexed, Grad, Restricted, Argument, Product, Sum, Division
+from ufl.classes import Terminal, Indexed, Grad, Restricted, FacetAvg, CellAvg, Argument, Product, Sum, Division
 
 from uflacs.utils.log import uflacs_assert
 
 from uflacs.analysis.graph_ssa import compute_dependencies
+from uflacs.analysis.dependency_handler import analyse_modified_terminal
 
 def strip_modified_terminal(v):
     "Extract core Terminal from a modified terminal or return None."
     while not isinstance(v, Terminal):
-        if isinstance(v, (Indexed, Grad, Restricted)):
+        if isinstance(v, (Indexed, Grad, Restricted, FacetAvg, CellAvg)):
             v = v.operands()[0]
         else:
             return None
         arg_combos = [comb + (j,) for comb in arg_combos for j in js]
     return arg_indices, arg_combos
 
+def build_argument_indices_from_arg_sets(arg_sets):
+    "Build ordered list of indices to modified arguments."
+    arg_indices = set()
+    for js in arg_sets.values():
+        arg_indices.update(js)
+    return sorted(arg_indices)
+
+def build_argument_indices(SV):
+    "Build ordered list of indices to modified arguments."
+
+    arg_sets = {}
+    for i,v in enumerate(SV):
+        a = strip_modified_terminal(v)
+        if not isinstance(a, Argument):
+            continue
+        c = a.count()
+        s = arg_sets.get(c)
+        if s is None:
+            s = {}
+            arg_sets[c] = s
+        s[i] = v
+
+    arg_indices = set()
+    for js in arg_sets.values():
+        arg_indices.update(js)
+
+    def arg_ordering_key(i):
+        a = None # TODO: Include averaging state
+        (t, c, d, r) = analyse_modified_terminal(arg_ordering_key.SV[i])
+        #, form_argument_mapping={}) # FIXME: Need this? Already mapped?
+        return (t.count(), a, d, r, a)
+    arg_ordering_key.SV = SV
+    ordered_arg_indices = sorted(arg_indices, key=arg_ordering_key)
+
+    return ordered_arg_indices
+
 def build_argument_dependencies(dependencies, arg_indices):
     "Preliminary algorithm: build list of argument vertex indices each vertex (indirectly) depends on."
     n = len(dependencies)
     argkeys = sorted(IM.keys())
     fs = []
     for argkey in argkeys:
-        # Add each subproduct of this monomial
+        # Start with coefficients
         f = FV[IM[argkey]]
+        ###f = 1
+
+        # Add binary products with each argument in order
         for argindex in argkey:
             f = f*AV[argindex]
             add_vertex(f)
+
+        # Add product with coefficients last
+        ###f = f*FV[IM[argkey]]
+        ###add_vertex(f)
+
         # f is now the full monomial, store it as a term for sum below
         fs.append(f)
 
     assert list(target_variables) == [len(SV)-1]
 
     arg_sets = build_argument_component_sets(SV)
-
-    arg_indices, valid_arg_combos = build_valid_argument_combinations(arg_sets)
+    #arg_indices, valid_arg_combos = build_valid_argument_combinations(arg_sets)
+    #arg_indices = build_argument_indices(arg_sets)
+    arg_indices = build_argument_indices(SV)
 
     A = build_argument_dependencies(dependencies, arg_indices)
 
         print '\n'.join(map(str,arg_sets.items()))
         print 'arg_indices:'
         print arg_indices
-        print 'valid_arg_combos:'
-        print valid_arg_combos
+        #print 'valid_arg_combos:'
+        #print valid_arg_combos
         print 'A:'
         print A
         print 'END DEBUGGING compute_argument_factorization'
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.