Commits

Kirill Simonov  committed 723069e

Fixed handling id() values in ETL commands.

  • Participants
  • Parent commits af051b5

Comments (0)

Files changed (4)

File src/htsql/tweak/etl/cmd/delete.py

             extract_node = BuildExtractNode.__invoke__(product.meta,
                     with_id=True, with_fields=False)
             resolve_key = BuildResolveKey.__invoke__(
-                    extract_node.node)
+                    extract_node.node, extract_node.arcs)
             execute_delete = BuildExecuteDelete.__invoke__(
                     extract_node.node.table)
             meta = decorate(VoidBinding())

File src/htsql/tweak/etl/cmd/insert.py

             return (lambda v: v)
         if self.origin_domain.width != self.domain.width:
             return None
+        group = list(enumerate(self.origin_domain.labels))
+        return self.align(group, self.domain)
+
+    def align(self, group, domain):
         converts = []
-        for origin_field, field in zip(self.origin_domain.labels,
-                                       self.domain.labels):
-            convert = Clarify.__invoke__(origin_field, field)
-            if convert is None:
-                return None
-            converts.append(convert)
-        id_class = ID.make(self.domain.dump)
+        for label in domain.labels:
+            if isinstance(label, IdentityDomain):
+                subgroup = []
+                subwidth = 0
+                while subwidth < label.width:
+                    idx, entry = group.pop(0)
+                    subgroup.append((idx, entry))
+                    if isinstance(entry, IdentityDomain):
+                        subwidth += entry.width
+                    else:
+                        subwidth += 1
+                if subwidth > label.width:
+                    return None
+                if (len(subgroup) == 1 and
+                        isinstance(subgroup[0][1], IdentityDomain)):
+                    idx, entry = subgroup[0]
+                    subgroup = list(enumerate(entry.labels))
+                    convert = self.align(subgroup, label)
+                    if convert is None:
+                        return None
+                    converts.append(lambda v, i=idx, c=convert: c(v[i]))
+                else:
+                    convert = self.align(subgroup, label)
+                    if convert is None:
+                        return None
+                    converts.append(convert)
+            else:
+                idx, entry = group.pop(0)
+                convert = Clarify.__invoke__(entry, label)
+                if convert is None:
+                    return None
+                converts.append(lambda v, i=idx, c=convert: c(v[i]))
+        id_class = ID.make(domain.dump)
         return (lambda v, id_class=id_class, cs=converts:
-                        id_class(c(i) for i, c in zip(v, cs))
-                                      if v is not None else None)
+                        id_class(c(v) for c in cs) if v is not None else None)
 
 
 class ExtractValuePipe(object):

File src/htsql/tweak/etl/cmd/merge.py

 from ....core.context import context
 from ....core.error import Error, PermissionError
 from ....core.entity import TableEntity, ColumnEntity
-from ....core.model import TableArc
+from ....core.model import TableArc, ColumnArc, ChainArc
 from ....core.classify import localize, relabel
 from ....core.connect import transaction, scramble, unscramble
 from ....core.domain import IdentityDomain, RecordDomain, ListDomain, Product
 from ....core.tr.signature import IsEqualSig, AndSig, PlaceholderSig
 from ....core.tr.decorate import decorate
 from ....core.tr.coerce import coerce
-from ....core.tr.lookup import identify
+from ....core.tr.lookup import prescribe
 from .command import MergeCmd
 from .insert import (BuildExtractNode, BuildExtractTable, BuildExecuteInsert,
         BuildResolveIdentity, BuildResolveChain)
         identity_arcs = localize(self.node)
         if identity_arcs is None:
             raise Error("Expected a table with identity")
-        index_by_arc = dict((arc, index) for index, arc in enumerate(self.arcs))
+        index_by_arc = {}
+        for index, arc in enumerate(self.arcs):
+            index_by_arc[arc] = index
+            if isinstance(arc, ColumnArc) and arc.link is not None:
+                index_by_arc[arc.link] = index
         id_indices = []
         for arc in identity_arcs:
             if arc not in index_by_arc:
         other_indices = []
         arcs = []
         for idx, arc in enumerate(self.arcs):
-            if arc in identity_arcs:
+            if idx in id_indices:
                 continue
             other_indices.append(idx)
             arcs.append(arc)
 
 class BuildResolveKey(Utility):
 
-    def __init__(self, node, with_error=True):
+    def __init__(self, node, arcs, with_error=True):
         self.node = node
+        self.arcs = arcs
         self.table = node.table
         self.with_error = with_error
 
         scope = RootBinding(syntax)
         state = BindingState(scope)
         seed = state.use(FreeTableRecipe(self.table), syntax)
-        recipe = identify(seed)
-        if recipe is None:
-            raise Error("Cannot determine identity of a link")
-        identity = state.use(recipe, syntax, scope=seed)
+        column_by_link = {}
+        for arc in self.arcs:
+            if isinstance(arc, ColumnArc) and arc.link is not None:
+                column_by_link[arc.link] = arc
+        identity_arcs = localize(self.node)
+        if identity_arcs is None:
+            raise Error("Expected a table with identity")
         count = itertools.count()
-        def make_images(identity):
+        def chain_arc(arc, scope):
             images = []
-            for field in identity.elements:
-                if isinstance(field.domain, IdentityDomain):
-                    images.extend(make_images(field))
-                else:
-                    item = FormulaBinding(scope,
-                                          PlaceholderSig(next(count)),
-                                          field.domain,
-                                          syntax)
-                    images.append((item, field))
-            return images
-        images = make_images(identity)
+            recipe = prescribe(arc, scope)
+            binding = state.use(recipe, syntax, scope=scope)
+            identity_arcs = localize(arc.target)
+            if identity_arcs:
+                fields = []
+                for identity_arc in identity_arcs:
+                    arc_images, arc_field = chain_arc(identity_arc, binding)
+                    images.extend(arc_images)
+                    fields.append(arc_field)
+                field = IdentityDomain(fields)
+            else:
+                item = FormulaBinding(scope,
+                                      PlaceholderSig(next(count)),
+                                      binding.domain,
+                                      syntax)
+                images.append((item, binding))
+                field = binding.domain
+            return images, field
+        images = []
+        fields = []
+        for arc in identity_arcs:
+            if arc in column_by_link:
+                arc = column_by_link[arc]
+            arc_images, arc_field = chain_arc(arc, seed)
+            images.extend(arc_images)
+            fields.append(arc_field)
+        identity_domain = IdentityDomain(fields)
         scope = LocateBinding(scope, seed, images, None, syntax)
         state.push_scope(scope)
         columns = []
         domain = ListDomain(binding.domain)
         binding = CollectBinding(state.root, binding, domain, syntax)
         pipe =  translate(binding)
-        domain = identity.domain
-        return ResolveKeyPipe(name, columns, domain, pipe, self.with_error)
+        return ResolveKeyPipe(name, columns, identity_domain, pipe,
+                              self.with_error)
 
 
 class ExecuteUpdatePipe(object):
             extract_identity = BuildExtractIdentity.__invoke__(
                     extract_node.node, extract_node.arcs)
             resolve_key = BuildResolveKey.__invoke__(
-                    extract_node.node, False)
+                    extract_node.node, extract_node.arcs, False)
             extract_table_for_update = BuildExtractTable.__invoke__(
                     extract_identity.node, extract_identity.arcs)
             execute_insert = BuildExecuteInsert.__invoke__(

File src/htsql/tweak/etl/cmd/update.py

             extract_node = BuildExtractNode.__invoke__(product.meta,
                     with_id=True, with_fields=True)
             resolve_key = BuildResolveKey.__invoke__(
-                    extract_node.node)
+                    extract_node.node, extract_node.arcs)
             extract_table = BuildExtractTable.__invoke__(
                     extract_node.node, extract_node.arcs)
             execute_update = BuildExecuteUpdate.__invoke__(