Commits

Mike Bayer committed 2692238

- Improvements to the system by which SQL types generate within
``__repr__()``, particularly with regards to the MySQL integer/numeric/
character types which feature a wide variety of keyword arguments.
The ``__repr__()`` is important for use with Alembic autogenerate
for when Python code is rendered in a migration script.
[ticket:2893]

  • Participants
  • Parent commits f701f87

Comments (0)

Files changed (7)

File doc/build/changelog/changelog_09.rst

     :version: 0.9.0b2
 
     .. change::
+        :tags: bug, mysql
+        :tickets: 2893
+
+        Improvements to the system by which SQL types generate within
+        ``__repr__()``, particularly with regards to the MySQL integer/numeric/
+        character types which feature a wide variety of keyword arguments.
+        The ``__repr__()`` is important for use with Alembic autogenerate
+        for when Python code is rendered in a migration script.
+
+    .. change::
         :tags: feature, postgresql
         :tickets: 2581
         :pullreq: github:50

File lib/sqlalchemy/dialects/mysql/base.py

 
 
 class _NumericType(object):
-    """Base for MySQL numeric types."""
+    """Base for MySQL numeric types.
+
+    This is the base both for NUMERIC as well as INTEGER, hence
+    it's a mixin.
+
+    """
 
     def __init__(self, unsigned=False, zerofill=False, **kw):
         self.unsigned = unsigned
         self.zerofill = zerofill
         super(_NumericType, self).__init__(**kw)
 
+    def __repr__(self):
+        return util.generic_repr(self,
+                to_inspect=[_NumericType, sqltypes.Numeric])
 
 class _FloatType(_NumericType, sqltypes.Float):
     def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
         super(_FloatType, self).__init__(precision=precision, asdecimal=asdecimal, **kw)
         self.scale = scale
 
+    def __repr__(self):
+        return util.generic_repr(self,
+                to_inspect=[_FloatType, _NumericType, sqltypes.Float])
 
 class _IntegerType(_NumericType, sqltypes.Integer):
     def __init__(self, display_width=None, **kw):
         self.display_width = display_width
         super(_IntegerType, self).__init__(**kw)
 
+    def __repr__(self):
+        return util.generic_repr(self,
+                to_inspect=[_IntegerType, _NumericType, sqltypes.Integer])
 
 class _StringType(sqltypes.String):
     """Base for MySQL string types."""
 
     def __init__(self, charset=None, collation=None,
-                 ascii=False, binary=False,
+                 ascii=False, binary=False, unicode=False,
                  national=False, **kw):
         self.charset = charset
 
         kw.setdefault('collation', kw.pop('collate', collation))
 
         self.ascii = ascii
-        # We have to munge the 'unicode' param strictly as a dict
-        # otherwise 2to3 will turn it into str.
-        self.__dict__['unicode'] = kw.get('unicode', False)
-        # sqltypes.String does not accept the 'unicode' arg at all.
-        if 'unicode' in kw:
-            del kw['unicode']
+        self.unicode = unicode
         self.binary = binary
         self.national = national
         super(_StringType, self).__init__(**kw)
 
+    def __repr__(self):
+        return util.generic_repr(self,
+                to_inspect=[_StringType, sqltypes.String])
 
 class NUMERIC(_NumericType, sqltypes.NUMERIC):
     """MySQL NUMERIC type."""
         _StringType.__init__(self, length=length, **kw)
         sqltypes.Enum.__init__(self, *values)
 
+    def __repr__(self):
+        return util.generic_repr(self,
+                to_inspect=[ENUM, _StringType, sqltypes.Enum])
+
     def bind_processor(self, dialect):
         super_convert = super(ENUM, self).bind_processor(dialect)
 
 MSInteger = INTEGER
 
 colspecs = {
+    _IntegerType: _IntegerType,
+    _NumericType: _NumericType,
+    _FloatType: _FloatType,
     sqltypes.Numeric: NUMERIC,
     sqltypes.Float: FLOAT,
     sqltypes.Time: TIME,

File lib/sqlalchemy/sql/sqltypes.py

 
     """
 
-    def __init__(self, **kw):
-        name = kw.pop('name', None)
+    def __init__(self, name=None, schema=None, metadata=None,
+                inherit_schema=False, quote=None):
         if name is not None:
-            self.name = quoted_name(name, kw.pop('quote', None))
+            self.name = quoted_name(name, quote)
         else:
             self.name = None
-        self.schema = kw.pop('schema', None)
-        self.metadata = kw.pop('metadata', None)
-        self.inherit_schema = kw.pop('inherit_schema', False)
+        self.schema = schema
+        self.metadata = metadata
+        self.inherit_schema = inherit_schema
         if self.metadata:
             event.listen(
                 self.metadata,
         SchemaType.__init__(self, **kw)
 
     def __repr__(self):
-        return util.generic_repr(self, [
-                        ("native_enum", True),
-                        ("name", None)
-                    ])
+        return util.generic_repr(self,
+              to_inspect=[Enum, SchemaType],
+          )
 
     def _should_create_constraint(self, compiler):
         return not self.native_enum or \

File lib/sqlalchemy/util/langhelpers.py

 from .. import exc
 import hashlib
 from . import compat
+from . import _collections
 
 def md5_hex(x):
     if compat.py3k:
 
     """
     if to_inspect is None:
-        to_inspect = obj
+        to_inspect = [obj]
+    else:
+        to_inspect = _collections.to_list(to_inspect)
 
     missing = object()
 
-    def genargs():
+    pos_args = []
+    kw_args = _collections.OrderedDict()
+    vargs = None
+    for i, insp in enumerate(to_inspect):
         try:
-            (args, vargs, vkw, defaults) = \
-                inspect.getargspec(to_inspect.__init__)
+            (_args, _vargs, vkw, defaults) = \
+                inspect.getargspec(insp.__init__)
         except TypeError:
-            return
+            continue
+        else:
+            default_len = defaults and len(defaults) or 0
+            if i == 0:
+                if _vargs:
+                    vargs = _vargs
+                if default_len:
+                    pos_args.extend(_args[1:-default_len])
+                else:
+                    pos_args.extend(_args[1:])
+            else:
+                kw_args.update([
+                    (arg, missing) for arg in _args[1:-default_len]
+                ])
 
-        default_len = defaults and len(defaults) or 0
+            if default_len:
+                kw_args.update([
+                    (arg, default)
+                        for arg, default
+                        in zip(_args[-default_len:], defaults)
+                ])
+    output = []
 
-        if not default_len:
-            for arg in args[1:]:
-                yield repr(getattr(obj, arg, None))
-            if vargs is not None and hasattr(obj, vargs):
-                yield ', '.join(repr(val) for val in getattr(obj, vargs))
-        else:
-            for arg in args[1:-default_len]:
-                yield repr(getattr(obj, arg, None))
-            for (arg, defval) in zip(args[-default_len:], defaults):
-                try:
-                    val = getattr(obj, arg, missing)
-                    if val is not missing and val != defval:
-                        yield '%s=%r' % (arg, val)
-                except:
-                    pass
-        if additional_kw:
-            for arg, defval in additional_kw:
-                try:
-                    val = getattr(obj, arg, missing)
-                    if val is not missing and val != defval:
-                        yield '%s=%r' % (arg, val)
-                except:
-                    pass
-
-    return "%s(%s)" % (obj.__class__.__name__, ", ".join(genargs()))
+    output.extend(repr(getattr(obj, arg, None)) for arg in pos_args)
+
+    if vargs is not None and hasattr(obj, vargs):
+        output.extend([repr(val) for val in getattr(obj, vargs)])
+
+    for arg, defval in kw_args.items():
+        try:
+            val = getattr(obj, arg, missing)
+            if val is not missing and val != defval:
+                output.append('%s=%r' % (arg, val))
+        except:
+            pass
+
+    if additional_kw:
+        for arg, defval in additional_kw:
+            try:
+                val = getattr(obj, arg, missing)
+                if val is not missing and val != defval:
+                    output.append('%s=%r' % (arg, val))
+            except:
+                pass
+
+    return "%s(%s)" % (obj.__class__.__name__, ", ".join(output))
 
 
 class portable_instancemethod(object):

File test/base/test_utils.py

             "Foo(b=5, d=7)"
         )
 
+    def test_multi_kw(self):
+        class Foo(object):
+            def __init__(self, a, b, c=3, d=4):
+                self.a = a
+                self.b = b
+                self.c = c
+                self.d = d
+        class Bar(Foo):
+            def __init__(self, e, f, g=5, **kw):
+                self.e = e
+                self.f = f
+                self.g = g
+                super(Bar, self).__init__(**kw)
+
+        eq_(
+            util.generic_repr(
+                Bar('e', 'f', g=7, a=6, b=5, d=9),
+                to_inspect=[Bar, Foo]
+            ),
+            "Bar('e', 'f', g=7, a=6, b=5, d=9)"
+        )
+
+        eq_(
+            util.generic_repr(
+                Bar('e', 'f', a=6, b=5),
+                to_inspect=[Bar, Foo]
+            ),
+            "Bar('e', 'f', a=6, b=5)"
+        )
+
+    def test_multi_kw_repeated(self):
+        class Foo(object):
+            def __init__(self, a=1, b=2):
+                self.a = a
+                self.b = b
+        class Bar(Foo):
+            def __init__(self, b=3, c=4, **kw):
+                self.c = c
+                super(Bar, self).__init__(b=b, **kw)
+
+        eq_(
+            util.generic_repr(
+                Bar(a='a', b='b', c='c'),
+                to_inspect=[Bar, Foo]
+            ),
+            "Bar(b='b', c='c', a='a')"
+        )
+
+
     def test_discard_vargs(self):
         class Foo(object):
             def __init__(self, a, b, *args):

File test/dialect/mysql/test_types.py

            ]
 
         for type_, args, kw, res in columns:
+            type_inst = type_(*args, **kw)
             self.assert_compile(
-                type_(*args, **kw),
+                type_inst,
+                res
+            )
+            # test that repr() copies out all arguments
+            print "mysql.%r" % type_inst
+            self.assert_compile(
+                eval("mysql.%r" % type_inst),
                 res
             )
 
             (mysql.ENUM, ["foo", "bar"], {'unicode':True},
              '''ENUM('foo','bar') UNICODE'''),
 
-            (String, [20], {"collation":"utf8"}, 'VARCHAR(20) COLLATE utf8')
+            (String, [20], {"collation": "utf8"}, 'VARCHAR(20) COLLATE utf8')
 
 
            ]
 
         for type_, args, kw, res in columns:
+            type_inst = type_(*args, **kw)
+            self.assert_compile(
+                type_inst,
+                res
+            )
+            # test that repr() copies out all arguments
             self.assert_compile(
-                type_(*args, **kw),
+                eval("mysql.%r" % type_inst)
+                    if type_ is not String
+                    else eval("%r" % type_inst),
                 res
             )
 

File test/sql/test_types.py

         # depending on backend.
         assert "('x'," in e.print_sql()
 
+    def test_repr(self):
+        e = Enum("x", "y", name="somename", convert_unicode=True,
+                        quote=True, inherit_schema=True)
+        eq_(
+            repr(e),
+            "Enum('x', 'y', name='somename', inherit_schema=True)"
+        )
+
 class BinaryTest(fixtures.TestBase, AssertsExecutionResults):
     __excluded_on__ = (
         ('mysql', '<', (4, 1, 1)),  # screwy varbinary types