Commits

yerokhin committed a39c5e6

view introspection new scenario: pk or fk columns in the select column list

Comments (0)

Files changed (1)

src/htsql_pgsql/rulesparser.py

             return self.find_rtable(rtable.subquery)
         return None
 
-#    def find_candidates(self, query):
-#        tableref = query.jointree.fromlist[0]
-#        rtindex = int(tableref.rtindex) - 1
-#        rtable = query.rtable[rtindex]
-#        view_columns = {}
-#        for target in query.targetList:
-#            if target.expr.__class__.__name__ == 'VAR':
-#                attindex = int(target.expr.varattno) - 1
-#                rcolname = rtable.eref.colnames[attindex].strip('"')
-#                vcolname = target.resname
-#                view_columns[rcolname] = vcolname
-#
-#        if rtable.rtekind == RTEKind.RTE_RELATION:
-#            return view_columns
-#        if rtable.rtekind == RTEKind.RTE_SUBQUERY:
-#            sub_columns = self.find_candidates(rtable.subquery)
-#            result = {}
-#            for rcolname in sub_columns:
-#                if sub_columns[rcolname] in view_columns:
-#                    result[rcolname] = view_columns[sub_columns[rcolname]]
-#            return result
-
     def find_keys(self, rule_tree, view, tablemap):
         query = rule_tree[0]
         rtable = self.find_rtable(query)
         view_columns = {}
         for target in query.targetList:
             if target.expr.__class__.__name__ == 'VAR' \
-                    and target.resorigtbl != '0':
+                    and target.resorigtbl != '0' \
+                    and target.resjunk != 'true':
                 assert rtable.relid == target.resorigtbl
                 attindex = int(target.resorigcol) - 1
                 rcolname = rtable.eref.colnames[attindex].strip('"')
 
         return [PrimaryKeyEntity(view.schema_name, view.name, v_colnames)]
 
+class SelectFKScenario(Scenario):
 
-scenario_list = [SingleTableIdScenario(), GroupByScenario()]
+    def accepts(self, rule_tree):
+        if len(rule_tree) > 1:
+            return False
+
+        query = rule_tree[0]
+        assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
+        if query.commandType != '1':
+            return False
+
+        return True
+
+    def find_rtables(self, query):
+        """
+        Find all relations in the query and return in a map by oid.
+        """
+        result = {}
+        for rtable in query.rtable:
+            if int(rtable.rtekind) == RTEKind.RTE_RELATION:
+                result[rtable.relid] = rtable
+            elif int(rtable.rtekind) == RTEKind.RTE_SUBQUERY:
+                result.update(self.find_rtables(rtable.subquery))
+        return result
+
+    def get_key_column(self, table_entity, column_name):
+        """
+        Returns tuple of (schema-name, table-name, column-name) of the primary key column
+        referenced by specified column.
+        """
+        if table_entity.primary_key is not None:
+            if column_name in table_entity.primary_key.origin_column_names:
+                return table_entity.schema_name, \
+                       table_entity.name, \
+                       column_name, \
+                       table_entity.primary_key
+        for fkey in table_entity.foreign_keys:
+            if column_name in fkey.origin_column_names:
+                index = fkey.origin_column_names.index(column_name)
+                return fkey.target_schema_name, \
+                       fkey.target_name, \
+                       fkey.target_column_names[index], \
+                       fkey
+        return None
+
+    def find_target_keys(self, query, rtablemap, tablemap):
+        result = []
+        for target in query.targetList:
+            item = None
+            if target.expr.__class__.__name__ == 'VAR' \
+                    and target.resorigtbl != '0' \
+                    and target.resjunk != 'true':
+                colname = rtablemap[target.resorigtbl].eref.colnames[int(target.resorigcol) - 1].strip('"')
+                table_entity = tablemap[int(target.resorigtbl)]
+                ref_column = self.get_key_column(table_entity, colname)
+                item = ref_column
+
+            result.append(item)
+        return result
+
+    def find_setoparg_keys(self, query, setoparg, rtablemap, tablemap):
+        if setoparg.__class__.__name__ == 'RANGETBLREF':
+            rel1 = query.rtable[int(setoparg.rtindex) - 1]
+            assert int(rel1.rtekind) == RTEKind.RTE_SUBQUERY
+            return self.find_query_keys(rel1.subquery, rtablemap, tablemap)
+        elif setoparg.__class__.__name__ == 'SETOPERATIONSTMT':
+            return self.find_setop_keys(query, setoparg, rtablemap, tablemap)
+
+    def find_setop_keys(self, query, setop, rtablemap, tablemap):
+        assert query.setOperations.__class__.__name__ == 'SETOPERATIONSTMT'
+        candidates1 = self.find_setoparg_keys(query, setop.larg, rtablemap, tablemap)
+        candidates2 = self.find_setoparg_keys(query, setop.rarg, rtablemap, tablemap)
+        assert len(candidates1) == len(candidates2)
+        result = []
+        for c1, c2 in zip(candidates1, candidates2):
+            item = None
+            if c1 is not None and c2 is not None:
+                (schema1, table1, column1, key1) = c1
+                (schema2, table2, column2, key2) = c2
+                if schema1 == schema2 and table1 == table2 and column1 == column2:
+                    # which key - doesn't matter
+                    item = c1
+
+            result.append(item)
+        return result
+
+    def find_query_keys(self, query, rtablemap, tablemap):
+        if query.setOperations == '<>':
+            return self.find_target_keys(query, rtablemap, tablemap)
+        else:
+            assert query.setOperations.__class__.__name__ == 'SETOPERATIONSTMT'
+            return self.find_setop_keys(query, query.setOperations, rtablemap, tablemap)
+
+    def find_keys(self, rule_tree, view, tablemap):
+        query = rule_tree[0]
+        candidates = self.find_query_keys(query, self.find_rtables(query), tablemap)
+        assert len(candidates) == len(query.targetList)
+        keys = []
+        schemas = {}
+        for (c, view_column) in zip(candidates, query.targetList):
+            if c is not None:
+                (schema, table, column, key) = c
+                if schema not in schemas:
+                    schemas[schema] = {}
+                tables = schemas[schema]
+                if table not in tables:
+                    tables[table] = {}
+                columns = tables[table]
+                if column not in columns:
+                    columns[column] = view_column.resname
+                if key not in keys:
+                    keys.append(key)
+
+        result = []
+        for key in keys:
+            if isinstance(key, PrimaryKeyEntity):
+                ref_schema = key.origin_schema_name
+                ref_table = key.origin_name
+                ref_columns = key.origin_column_names
+            elif isinstance(key, ForeignKeyEntity):
+                ref_schema = key.target_schema_name
+                ref_table = key.target_name
+                ref_columns = key.target_column_names
+            else:
+                continue
+            fkey_columns = []
+            selected_columns = schemas[ref_schema][ref_table]
+            for pkey_column in ref_columns:
+                if pkey_column in selected_columns:
+                    fkey_columns.append(selected_columns[pkey_column])
+            if len(ref_columns) == len(fkey_columns):
+                result.append(ForeignKeyEntity(view.schema_name, view.name, fkey_columns,
+                                               ref_schema, ref_table, ref_columns))
+        return result
+
+
+scenario_list = [SingleTableIdScenario(), GroupByScenario(), SelectFKScenario()]