Commits

yerokhin committed 7245207

View key detection changed.

  • Participants
  • Parent commits be94ccc

Comments (0)

Files changed (2)

src/htsql_pgsql/introspect.py

                     or rule.ev_attr >= 0 \
                     or not rule.is_instead \
                     or rule.ev_qual != '<>':
+                # not a view
                 continue
 
             if not rule.ev_class in self.views_by_oid:
+                # not introspected view
                 continue
 
             view = self.views_by_oid[rule.ev_class]
             ruletree = rulesparser.RuleTreeParser().parse(rule.ev_action)
             for scenario in rulesparser.scenario_list:
                 if scenario.accepts(ruletree):
-                    for key in scenario.find_keys(ruletree, view, self.table_by_oid):
+                    keyset = scenario.find_keys(ruletree, view, self.table_by_oid)
+                    for key in keyset:
                         if isinstance(key, PrimaryKeyEntity):
                             view.unique_keys.append(key)
                             view.primary_key = key
                         if isinstance(key, ForeignKeyEntity):
                             view.foreign_keys.append(key)
-                    break
+                    if len(keyset) > 0:
+                        break
 
 
     def introspect_tables(self, schema_oid):

src/htsql_pgsql/rulesparser.py

 import re
 from htsql.entity import (PrimaryKeyEntity, ForeignKeyEntity)
 
+class RTEKind(object):
+    """
+    Range table enumeration
+    """
+    RTE_RELATION = 0			# ordinary relation reference
+    RTE_SUBQUERY = 1			# subquery in FROM
+    RTE_JOIN = 2				# join
+    RTE_SPECIAL = 3				# special rule relation (NEW or OLD)
+    RTE_FUNCTION = 4			# function in FROM
+    RTE_VALUES = 5				# VALUES (<exprlist>), (<exprlist>), ...
+    RTE_CTE = 6					# common table expr (WITH list element)
+
+
 class RuleTreeParser(object):
 
     def __init__(self):
         return m.group(1), instr[m.end(1):]
 
     def parse_token(self, instr):
-        m = re.match('[^ )}]+', instr)
+        m = re.match('(\\\\.|[^ )}])+', instr)
         return m.group(0), instr[m.end(0):]
 
     def parse_value(self, instr):
         pass
 
 
-class SingleTableScenario(Scenario):
+class SingleTableIdScenario(Scenario):
 
     def accepts(self, rule_tree):
         if len(rule_tree) > 1:
             return False
 
         query = rule_tree[0]
-        assert query.__class__.__name__ == 'QUERY'
+        assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
         if query.commandType != '1' \
                 or query.groupClause != '<>' \
                 or query.distinctClause != '<>' \
                 or query.setOperations != '<>':
             return False
 
+        return self.find_rtable(query) is not None
+
+    def find_rtable(self, query):
         if len(query.jointree.fromlist) != 1:
-            return False
+            return None
+        tableref = query.jointree.fromlist[0]
+        if tableref.__class__.__name__ != 'RANGETBLREF':
+            return None
+        rtindex = int(tableref.rtindex) - 1
+        rtable = query.rtable[rtindex]
 
-        return True
+        if int(rtable.rtekind) == RTEKind.RTE_RELATION:
+            return rtable
+        if int(rtable.rtekind) == RTEKind.RTE_SUBQUERY:
+            return self.find_rtable(rtable.subquery)
+        return None
+
+#    def find_candidates(self, query):
+#        tableref = query.jointree.fromlist[0]
+#        rtindex = int(tableref.rtindex) - 1
+#        rtable = query.rtable[rtindex]
+#        view_columns = {}
+#        for target in query.targetList:
+#            if target.expr.__class__.__name__ == 'VAR':
+#                attindex = int(target.expr.varattno) - 1
+#                rcolname = rtable.eref.colnames[attindex].strip('"')
+#                vcolname = target.resname
+#                view_columns[rcolname] = vcolname
+#
+#        if rtable.rtekind == RTEKind.RTE_RELATION:
+#            return view_columns
+#        if rtable.rtekind == RTEKind.RTE_SUBQUERY:
+#            sub_columns = self.find_candidates(rtable.subquery)
+#            result = {}
+#            for rcolname in sub_columns:
+#                if sub_columns[rcolname] in view_columns:
+#                    result[rcolname] = view_columns[sub_columns[rcolname]]
+#            return result
 
     def find_keys(self, rule_tree, view, tablemap):
         query = rule_tree[0]
-        rtindex = int(query.jointree.fromlist[0].rtindex) - 1
-        rtable = query.rtable[rtindex]
+        rtable = self.find_rtable(query)
         o_table = tablemap[int(rtable.relid)]
         o_pkey = None
         for ukey in o_table.unique_keys:
 
         view_columns = {}
         for target in query.targetList:
-            if target.expr.__class__.__name__ == 'VAR':
-                attindex = int(target.expr.varattno) - 1
+            if target.expr.__class__.__name__ == 'VAR' \
+                    and target.resorigtbl != '0':
+                assert rtable.relid == target.resorigtbl
+                attindex = int(target.resorigcol) - 1
                 rcolname = rtable.eref.colnames[attindex].strip('"')
                 vcolname = target.resname
                 view_columns[rcolname] = vcolname
                                   o_table.schema_name, o_table.name, o_pkey.origin_column_names)
         return [v_pkey, v_fkey]
 
+class GroupByScenario(Scenario):
 
-scenario_list = [SingleTableScenario()]
+    def accepts(self, rule_tree):
+        if len(rule_tree) > 1:
+            return False
+
+        query = rule_tree[0]
+        assert query.__class__.__name__ == 'QUERY', query.__class__.__name__
+        if query.commandType != '1' \
+                or query.distinctClause != '<>' \
+                or query.setOperations != '<>':
+            return False
+
+        return query.hasAggs == 'true'
+
+    def find_keys(self, rule_tree, view, tablemap):
+        query = rule_tree[0]
+        v_colnames = []
+        for target in query.targetList:
+            if target.expr.__class__.__name__ == 'VAR' \
+                    and target.ressortgroupref != '0' \
+                    and target.resjunk != 'true':
+                v_colnames.append(target.resname)
+
+        if len(v_colnames) != len(query.groupClause):
+            return []
+
+        return [PrimaryKeyEntity(view.schema_name, view.name, v_colnames)]
+
+
+scenario_list = [SingleTableIdScenario(), GroupByScenario()]