Kirill Simonov avatar Kirill Simonov committed 8d32caf Merge

Merged.

Comments (0)

Files changed (1)

src/htsql_pgsql/rulesparser.py

         query = rule_tree[0]
         assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
         if query.commandType != '1' \
+                or query.rtable == '<>' \
                 or query.groupClause != '<>' \
                 or query.distinctClause != '<>' \
                 or query.setOperations != '<>':
         tableref = query.jointree.fromlist[0]
         if tableref.__class__.__name__ != 'RANGETBLREF':
             return None
+        if query.rtable == '<>':
+            return None
         rtindex = int(tableref.rtindex) - 1
         rtable = query.rtable[rtindex]
 
         query = rule_tree[0]
         assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
         if query.commandType != '1' \
+                or query.rtable == '<>' \
                 or query.distinctClause != '<>' \
                 or query.setOperations != '<>':
             return False
 
         return [PrimaryKeyEntity(view.schema_name, view.name, v_colnames)]
 
+class FKColumnRef(object):
+
+    def __init__(self, schema_name, table_name, column_name, key):
+        self.schema_name = schema_name
+        self.table_name = table_name
+        self.column_name = column_name
+        self.key = key
+        self.alias = None
+
 class SelectFKScenario(Scenario):
 
     def accepts(self, rule_tree):
 
         query = rule_tree[0]
         assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
-        if query.commandType != '1':
+        if query.commandType != '1' \
+                or query.rtable == '<>':
             return False
 
         return True
         Find all relations in the query and return in a map by oid.
         """
         result = {}
+        if query.rtable == '<>':
+            return result
         for rtable in query.rtable:
             if int(rtable.rtekind) == RTEKind.RTE_RELATION:
                 result[rtable.relid] = rtable
 
     def get_key_column(self, table_entity, column_name):
         """
-        Returns tuple of (schema-name, table-name, column-name) of the primary key column
+        Returns tuple of (schema-name, table-name, column-name, key) of the primary key column
         referenced by specified column.
         """
         if table_entity.primary_key is not None:
             if column_name in table_entity.primary_key.origin_column_names:
-                return table_entity.schema_name, \
-                       table_entity.name, \
-                       column_name, \
-                       table_entity.primary_key
+                return FKColumnRef(table_entity.schema_name, \
+                           table_entity.name, \
+                           column_name, \
+                           table_entity.primary_key)
         for fkey in table_entity.foreign_keys:
             if column_name in fkey.origin_column_names:
                 index = fkey.origin_column_names.index(column_name)
-                return fkey.target_schema_name, \
-                       fkey.target_name, \
-                       fkey.target_column_names[index], \
-                       fkey
+                return FKColumnRef(fkey.target_schema_name, \
+                           fkey.target_name, \
+                           fkey.target_column_names[index], \
+                           fkey)
         return None
 
     def find_target_keys(self, query, rtablemap, tablemap):
         result = []
         for target in query.targetList:
-            item = None
-            if target.expr.__class__.__name__ == 'VAR' \
-                    and target.resorigtbl != '0' \
-                    and target.resjunk != 'true':
-                colname = rtablemap[target.resorigtbl].eref.colnames[int(target.resorigcol) - 1].strip('"')
-                if int(target.resorigtbl) not in tablemap:
-                    # not introspected table
-                    continue
-                table_entity = tablemap[int(target.resorigtbl)]
-                ref_column = self.get_key_column(table_entity, colname)
-                item = ref_column
-
-            result.append(item)
+            if target.resjunk != 'true':
+                item = None
+                if target.expr.__class__.__name__ == 'VAR' \
+                        and target.resorigtbl != '0' \
+                        and int(target.resorigtbl) in tablemap :
+                    colname = rtablemap[target.resorigtbl].eref.colnames[int(target.resorigcol) - 1].strip('"')
+                    table_entity = tablemap[int(target.resorigtbl)]
+                    item = self.get_key_column(table_entity, colname)
+                    if item is not None:
+                        item.alias = target.resname
+                result.append(item)
         return result
 
     def find_setoparg_keys(self, query, setoparg, rtablemap, tablemap):
         for c1, c2 in zip(candidates1, candidates2):
             item = None
             if c1 is not None and c2 is not None:
-                (schema1, table1, column1, key1) = c1
-                (schema2, table2, column2, key2) = c2
-                if schema1 == schema2 and table1 == table2 and column1 == column2:
+                if c1.schema_name == c2.schema_name \
+                        and c1.table_name == c2.table_name \
+                        and c1.column_name == c2.column_name:
                     # which key - doesn't matter
                     item = c1
 
     def find_keys(self, rule_tree, view, tablemap):
         query = rule_tree[0]
         candidates = self.find_query_keys(query, self.find_rtables(query), tablemap)
-        assert len(candidates) == len(query.targetList)
         keys = []
         schemas = {}
-        for (c, view_column) in zip(candidates, query.targetList):
+        for c in candidates:
             if c is not None:
-                (schema, table, column, key) = c
-                if schema not in schemas:
-                    schemas[schema] = {}
-                tables = schemas[schema]
-                if table not in tables:
-                    tables[table] = {}
-                columns = tables[table]
-                if column not in columns:
-                    columns[column] = view_column.resname
-                if key not in keys:
-                    keys.append(key)
+                if c.schema_name not in schemas:
+                    schemas[c.schema_name] = {}
+                tables = schemas[c.schema_name]
+                if c.table_name not in tables:
+                    tables[c.table_name] = {}
+                columns = tables[c.table_name]
+                if c.column_name not in columns:
+                    columns[c.column_name] = c.alias
+                if c.key not in keys:
+                    keys.append(c.key)
 
         result = []
         for key in keys:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.