Commits

Kirill Simonov committed 0348b04

Updated generation of table identity.

Comments (0)

Files changed (3)

src/htsql/core/classify.py

     adapt(TableNode)
 
     def __call__(self):
-        label_by_column = {}
-        label_by_join = {}
+        arcs = set()
         for label in classify(self.node):
-            if (isinstance(label.arc, ColumnArc) and
-                    label.arc.column not in label_by_column):
-                label_by_column[label.arc.column] = label
-            if (isinstance(label.arc, ChainArc) and
-                    len(label.arc.joins) == 1 and
-                    label.arc.joins[0] not in label_by_join):
-                label_by_join[label.arc.joins[0]] = label
+            arc = label.arc
+            if isinstance(arc, ColumnArc):
+                arcs.add(arc)
+                if arc.link is not None:
+                    if isinstance(arc.link, ChainArc):
+                        arcs.add(arc.link)
+                    arc = arc.clone(link=None)
+                    arcs.add(arc)
+            elif isinstance(arc, ChainArc):
+                arcs.add(arc)
         table = self.node.table
         for key in [table.primary_key]+table.unique_keys:
             if key.is_partial:
                     width = len(foreign_key.origin_columns)
                     if foreign_key.origin_columns == columns[:width]:
                         join = DirectJoin(foreign_key)
-                        if join not in label_by_join:
+                        arc = ChainArc(table, [join])
+                        if arc not in arcs:
                             continue
-                        label = label_by_join[join]
-                        if localize(label.target) is None:
+                        if localize(arc.target) is None:
                             continue
-                        identity.append(label)
+                        identity.append(arc)
                         columns = columns[width:]
                         break
                 else:
                     column = columns[0]
-                    if column not in label_by_column:
+                    arc = ColumnArc(table, column)
+                    if arc not in arcs:
                         break
+                    identity.append(arc)
                     columns.pop(0)
-                    identity.append(label_by_column[column])
             if not columns:
                 return identity
 

src/htsql/core/tr/lookup.py

 
     def __call__(self):
         def chain(node):
-            labels = localize(node)
-            if labels is None:
+            arcs = localize(node)
+            if arcs is None:
                 return None
             recipes = []
-            for label in labels:
-                recipe = prescribe(label.arc, self.binding)
-                target_chain = chain(label.target)
+            for arc in arcs:
+                recipe = prescribe(arc, self.binding)
+                target_chain = chain(arc.target)
                 if target_chain is not None:
                     recipe = ChainRecipe([recipe, target_chain])
                 recipes.append(recipe)

src/htsql/core/tr/stitch.py

             # to the flow tree.
             def chain(flow):
                 node = TableNode(flow.family.table)
-                labels = localize(node)
-                if labels is None:
+                arcs = localize(node)
+                if arcs is None:
                     return None
                 units = []
-                for label in labels:
-                    if isinstance(label.arc, ColumnArc):
-                        identifier = IdentifierSyntax(label.name, flow.mark)
-                        binding = self.flow.binding.clone(syntax=identifier)
-                        code = ColumnUnit(label.arc.column, flow, binding)
+                for arc in arcs:
+                    if isinstance(arc, ColumnArc):
+                        code = ColumnUnit(arc.column, flow, flow.binding)
                         units.append(code)
-                    elif isinstance(label.arc, ChainArc):
-                        identifier = IdentifierSyntax(label.name, flow.mark)
-                        binding = self.flow.binding.clone(syntax=identifier)
+                    elif isinstance(arc, ChainArc):
                         subflow = flow
-                        for join in label.arc.joins:
+                        for join in arc.joins:
                             subflow = FiberTableFlow(subflow, join,
-                                                     binding)
+                                                     flow.binding)
                         subunits = chain(subflow)
                         assert subunits is not None
                         units.extend(subunits)
                     else:
-                        assert False, label.arc
+                        assert False, arc
                 return units
             if not self.flow.is_contracting:
                 flow = self.flow.inflate()