Commits

jason kirtland  committed 4b93c55

Added 'unformat_identifiers', produces a list of unquoted identifiers from an identifier or a fully qualified identifier string.

  • Participants
  • Parent commits 923fbf5

Comments (0)

Files changed (3)

File lib/sqlalchemy/ansisql.py

 
         return value.replace('"', '""')
 
+    def _unescape_identifier(self, value):
+        """Canonicalize an escaped identifier.
+
+        Subclasses should override this to provide database-dependent
+        unescaping behavior that reverses _escape_identifier.
+        """
+
+        return value.replace('""', '"')
+
     def quote_identifier(self, value):
         """Quote an identifier.
 
         else:
             return (self.format_table(table, use_schema=False), )
 
+    def unformat_identifiers(self, identifiers):
+        """Unpack 'schema.table.column'-like strings into components."""
+
+        try:
+            r = self._r_identifiers
+        except AttributeError:
+            initial, final, escaped_final = \
+                     [re.escape(s) for s in
+                      (self.initial_quote, self.final_quote,
+                       self._escape_identifier(self.final_quote))]
+            r = re.compile(
+                r'(?:'
+                r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
+                r'|([^\.]+))(?=\.|$))+' %
+                { 'initial': initial,
+                  'final': final,
+                  'escaped': escaped_final })
+            self._r_identifiers = r
+        
+        return [self._unescape_identifier(i)
+                for i in [a or b for a, b in r.findall(identifiers)]]
+
+
 dialect = ANSIDialect

File lib/sqlalchemy/databases/mysql.py

     def _escape_identifier(self, value):
         return value.replace('`', '``')
 
+    def _unescape_identifier(self, value):
+        return value.replace('``', '`')
+
     def _fold_identifier_case(self, value):
         # TODO: determine MySQL's case folding rules
         #

File test/sql/quote.py

             x = lc_table1.select(distinct=True).alias("lala").select().scalar()
         finally:
             meta.drop_all()
+
+class PreparerTest(PersistTest):
+    """Test the db-agnostic quoting services of ANSIIdentifierPreparer."""
+
+    def test_unformat(self):
+        prep = ansisql.ANSIIdentifierPreparer(None)
+        unformat = prep.unformat_identifiers
+
+        def a_eq(have, want):
+            if have != want:
+                print "Wanted %s" % want
+                print "Received %s" % have
+            self.assert_(have == want)
+
+        a_eq(unformat('foo'), ['foo'])
+        a_eq(unformat('"foo"'), ['foo'])
+        a_eq(unformat("'foo'"), ["'foo'"])
+        a_eq(unformat('foo.bar'), ['foo', 'bar'])
+        a_eq(unformat('"foo"."bar"'), ['foo', 'bar'])
+        a_eq(unformat('foo."bar"'), ['foo', 'bar'])
+        a_eq(unformat('"foo".bar'), ['foo', 'bar'])
+        a_eq(unformat('"foo"."b""a""r"."baz"'), ['foo', 'b"a"r', 'baz'])
+
+    def test_unformat_custom(self):
+        class Custom(ansisql.ANSIIdentifierPreparer):
+            def __init__(self, dialect):
+                super(Custom, self).__init__(dialect, initial_quote='`',
+                                             final_quote='`')
+            def _escape_identifier(self, value):
+                return value.replace('`', '``')
+            def _unescape_identifier(self, value):
+                return value.replace('``', '`')
+
+        prep = Custom(None)
+        unformat = prep.unformat_identifiers
+
+        def a_eq(have, want):
+            if have != want:
+                print "Wanted %s" % want
+                print "Received %s" % have
+            self.assert_(have == want)
+
+        a_eq(unformat('foo'), ['foo'])
+        a_eq(unformat('`foo`'), ['foo'])
+        a_eq(unformat(`'foo'`), ["'foo'"])
+        a_eq(unformat('foo.bar'), ['foo', 'bar'])
+        a_eq(unformat('`foo`.`bar`'), ['foo', 'bar'])
+        a_eq(unformat('foo.`bar`'), ['foo', 'bar'])
+        a_eq(unformat('`foo`.bar'), ['foo', 'bar'])
+        a_eq(unformat('`foo`.`b``a``r`.`baz`'), ['foo', 'b`a`r', 'baz'])
         
 if __name__ == "__main__":
     testbase.main()