Commits

Éric Lemoine  committed b1a258e Draft

breaking up PGDialect.get_columns, and add PostGIS column reflection tests

  • Participants
  • Parent commits 98d457c

Comments (0)

Files changed (2)

File lib/sqlalchemy/dialects/postgresql/base.py

         # format columns
         columns = []
         for name, format_type, default, notnull, attnum, table_oid in rows:
-            ## strip (5) from character varying(5), timestamp(5)
-            # with time zone, etc
-            attype = re.sub(r'\(.*\)', '', format_type)
+            column_info = self._get_column_info(name, format_type, default,
+                                                notnull, domains, enums, schema)
+            columns.append(column_info)
+        return columns
 
-            # strip '[]' from integer[], etc.
-            attype = re.sub(r'\[\]', '', attype)
+    def _get_column_info(self, name, format_type, default,
+                         notnull, domains, enums, schema):
+        ## strip (5) from character varying(5), timestamp(5)
+        # with time zone, etc
+        attype = re.sub(r'\(.*\)', '', format_type)
 
-            nullable = not notnull
-            is_array = format_type.endswith('[]')
-            charlen = re.search('\(([\d,]+)\)', format_type)
+        # strip '[]' from integer[], etc.
+        attype = re.sub(r'\[\]', '', attype)
+
+        nullable = not notnull
+        is_array = format_type.endswith('[]')
+        charlen = re.search('\(([\d,]+)\)', format_type)
+        if charlen:
+            charlen = charlen.group(1)
+        args = re.search('\((.*)\)', format_type)
+        if args:
+            args = tuple(args.group(1).split(','))
+        else:
+            args = ()
+        kwargs = {}
+
+        if attype == 'numeric':
             if charlen:
-                charlen = charlen.group(1)
-            args = re.search('\((.*)\)', format_type)
-            if args:
-                args = tuple(args.group(1).split(','))
+                prec, scale = charlen.split(',')
+                args = (int(prec), int(scale))
             else:
                 args = ()
-            kwargs = {}
+        elif attype == 'double precision':
+            args = (53, )
+        elif attype == 'integer':
+            args = ()
+        elif attype in ('timestamp with time zone',
+                        'time with time zone'):
+            kwargs['timezone'] = True
+            if charlen:
+                kwargs['precision'] = int(charlen)
+            args = ()
+        elif attype in ('timestamp without time zone',
+                        'time without time zone', 'time'):
+            kwargs['timezone'] = False
+            if charlen:
+                kwargs['precision'] = int(charlen)
+            args = ()
+        elif attype == 'bit varying':
+            kwargs['varying'] = True
+            if charlen:
+                args = (int(charlen),)
+            else:
+                args = ()
+        elif attype in ('interval','interval year to month',
+                            'interval day to second'):
+            if charlen:
+                kwargs['precision'] = int(charlen)
+            args = ()
+        elif charlen:
+            args = (int(charlen),)
 
-            if attype == 'numeric':
-                if charlen:
-                    prec, scale = charlen.split(',')
-                    args = (int(prec), int(scale))
+        while True:
+            if attype in self.ischema_names:
+                coltype = self.ischema_names[attype]
+                break
+            elif attype in enums:
+                enum = enums[attype]
+                coltype = ENUM
+                if "." in attype:
+                    kwargs['schema'], kwargs['name'] = attype.split('.')
                 else:
-                    args = ()
-            elif attype == 'double precision':
-                args = (53, )
-            elif attype == 'integer':
-                args = ()
-            elif attype in ('timestamp with time zone',
-                            'time with time zone'):
-                kwargs['timezone'] = True
-                if charlen:
-                    kwargs['precision'] = int(charlen)
-                args = ()
-            elif attype in ('timestamp without time zone',
-                            'time without time zone', 'time'):
-                kwargs['timezone'] = False
-                if charlen:
-                    kwargs['precision'] = int(charlen)
-                args = ()
-            elif attype == 'bit varying':
-                kwargs['varying'] = True
-                if charlen:
-                    args = (int(charlen),)
-                else:
-                    args = ()
-            elif attype in ('interval','interval year to month',
-                                'interval day to second'):
-                if charlen:
-                    kwargs['precision'] = int(charlen)
-                args = ()
-            elif charlen:
-                args = (int(charlen),)
+                    kwargs['name'] = attype
+                args = tuple(enum['labels'])
+                break
+            elif attype in domains:
+                domain = domains[attype]
+                attype = domain['attype']
+                # A table can't override whether the domain is nullable.
+                nullable = domain['nullable']
+                if domain['default'] and not default:
+                    # It can, however, override the default
+                    # value, but can't set it to null.
+                    default = domain['default']
+                continue
+            else:
+                coltype = None
+                break
 
-            while True:
-                if attype in self.ischema_names:
-                    coltype = self.ischema_names[attype]
-                    break
-                elif attype in enums:
-                    enum = enums[attype]
-                    coltype = ENUM
-                    if "." in attype:
-                        kwargs['schema'], kwargs['name'] = attype.split('.')
-                    else:
-                        kwargs['name'] = attype
-                    args = tuple(enum['labels'])
-                    break
-                elif attype in domains:
-                    domain = domains[attype]
-                    attype = domain['attype']
-                    # A table can't override whether the domain is nullable.
-                    nullable = domain['nullable']
-                    if domain['default'] and not default:
-                        # It can, however, override the default
-                        # value, but can't set it to null.
-                        default = domain['default']
-                    continue
-                else:
-                    coltype = None
-                    break
+        if coltype:
+            coltype = coltype(*args, **kwargs)
+            if is_array:
+                coltype = ARRAY(coltype)
+        else:
+            util.warn("Did not recognize type '%s' of column '%s'" %
+                      (attype, name))
+            coltype = sqltypes.NULLTYPE
+        # adjust the default value
+        autoincrement = False
+        if default is not None:
+            match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+            if match is not None:
+                autoincrement = True
+                # the default is related to a Sequence
+                sch = schema
+                if '.' not in match.group(2) and sch is not None:
+                    # unconditionally quote the schema name.  this could
+                    # later be enhanced to obey quoting rules /
+                    # "quote schema"
+                    default = match.group(1) + \
+                                ('"%s"' % sch) + '.' + \
+                                match.group(2) + match.group(3)
+                    print default
 
-            if coltype:
-                coltype = coltype(*args, **kwargs)
-                if is_array:
-                    coltype = ARRAY(coltype)
-            else:
-                util.warn("Did not recognize type '%s' of column '%s'" %
-                          (attype, name))
-                coltype = sqltypes.NULLTYPE
-            # adjust the default value
-            autoincrement = False
-            if default is not None:
-                match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
-                if match is not None:
-                    autoincrement = True
-                    # the default is related to a Sequence
-                    sch = schema
-                    if '.' not in match.group(2) and sch is not None:
-                        # unconditionally quote the schema name.  this could
-                        # later be enhanced to obey quoting rules /
-                        # "quote schema"
-                        default = match.group(1) + \
-                                    ('"%s"' % sch) + '.' + \
-                                    match.group(2) + match.group(3)
-
-            column_info = dict(name=name, type=coltype, nullable=nullable,
-                               default=default, autoincrement=autoincrement)
-            columns.append(column_info)
-        return columns
+        column_info = dict(name=name, type=coltype, nullable=nullable,
+                           default=default, autoincrement=autoincrement)
+        return column_info
 
     @reflection.cache
     def get_pk_constraint(self, connection, table_name, schema=None, **kw):

File test/dialect/test_postgresql.py

         eq_(ind, [{'unique': False, 'column_names': [u'y'], 'name': u'idx1'}])
         conn.close()
 
+class PostGISColumnReflection(fixtures.TestBase):
+    __only_on__ = 'postgresql'
+
+    class Geometry(object):
+        def __init__(self, geometry_type=None, srid=None):
+            self.geometry_type = geometry_type
+            self.srid = srid
+
+    ischema_names = None
+
+    @classmethod
+    def setup_class(cls):
+        ischema_names = postgresql.PGDialect.ischema_names
+        postgresql.PGDialect.ischema_names = ischema_names.copy()
+        postgresql.PGDialect.ischema_names['geometry'] = cls.Geometry
+        cls.ischema_names = ischema_names
+
+    @classmethod
+    def teardown_class(cls):
+        postgresql.PGDialect.ischema_names = cls.ischema_names
+        cls.ischema_names = None
+
+    def test_geometry(self):
+        dialect = postgresql.PGDialect()
+        column_info = dialect._get_column_info(
+                'geom', 'geometry', None, False,
+                {}, {}, 'public')
+        assert isinstance(column_info['type'], self.Geometry)
+        assert column_info['type'].geometry_type is None
+        assert column_info['type'].srid is None
+
+    def test_geometry_with_type(self):
+        dialect = postgresql.PGDialect()
+        column_info = dialect._get_column_info(
+                'geom', 'geometry(POLYGON)', None, False,
+                {}, {}, 'public')
+        assert isinstance(column_info['type'], self.Geometry)
+        assert column_info['type'].geometry_type == 'POLYGON'
+        assert column_info['type'].srid is None
+
+    def test_geometry_with_type_and_srid(self):
+        dialect = postgresql.PGDialect()
+        column_info = dialect._get_column_info(
+                'geom', 'geometry(POLYGON,4326)', None, False,
+                {}, {}, 'public')
+        assert isinstance(column_info['type'], self.Geometry)
+        assert column_info['type'].geometry_type == 'POLYGON'
+        assert column_info['type'].srid == '4326'
+
 class MiscTest(fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL):
 
     __only_on__ = 'postgresql'