Commits

Anonymous committed e02dc1f

Add unittests for custom compare_type function

Comments (0)

Files changed (1)

tests/test_autogenerate.py

 import re
 import sys
 from unittest import TestCase
+from mock import Mock, patch
 
 from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
-    Numeric, CHAR, ForeignKey, DATETIME, \
+    Numeric, CHAR, ForeignKey, DATETIME, VARCHAR, \
     TypeDecorator, CheckConstraint, Unicode, Enum,\
     UniqueConstraint, Boolean, ForeignKeyConstraint,\
     PrimaryKeyConstraint
             [('remove_table', 'extra'), ('remove_table', 'user')]
         )
 
+    def test_uses_custom_compare_type_function(self):
+        my_compare_type = Mock()
+        my_compare_type.return_value = None
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': my_compare_type,
+                'target_metadata': self.m1,
+                'upgrade_token':"upgrades",
+                'downgrade_token':"downgrades",
+                'alembic_module_prefix':'op.',
+                'sqlalchemy_module_prefix':'sa.'
+            }
+        )
+        autogenerate._produce_migration_diffs(context, {}, set())
+
+        first_table = self.m1.tables['address']
+        first_column = first_table.columns['email_address']
+
+        # We'll just test the first call
+        _, args, _ = my_compare_type.mock_calls[0]
+        ctx, inspected_column, metadata_column, inspected_type, metadata_type = args
+        eq_(ctx, context)
+        eq_(metadata_column, first_column)
+        eq_(metadata_type, first_column.type)
+        eq_(inspected_column.name, first_column.name)
+        eq_(type(inspected_type), VARCHAR)
+
+    def test_fields_excluded_when_custom_compare_type_returns_False(self):
+        my_compare_type = Mock()
+        my_compare_type.return_value = False
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': my_compare_type,
+                'target_metadata': self.m1,
+                'upgrade_token':"upgrades",
+                'downgrade_token':"downgrades",
+                'alembic_module_prefix':'op.',
+                'sqlalchemy_module_prefix':'sa.'
+            }
+        )
+        template_args = {}
+        newtype = String(length=30)
+        with patch.object(self.m1.tables['address'].columns['email_address'], 'type', new=newtype):
+            autogenerate._produce_migration_diffs(context, template_args, set())
+
+        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+"""### commands auto generated by Alembic - please adjust! ###
+    pass
+    ### end Alembic commands ###""")
+
+    def test_fields_included_when_custom_compare_type_returns_True(self):
+        my_compare_type = Mock()
+        my_compare_type.return_value = True
+
+        context = MigrationContext.configure(
+            connection=self.bind.connect(),
+            opts={
+                'compare_type': my_compare_type,
+                'target_metadata': self.m1,
+                'upgrade_token':"upgrades",
+                'downgrade_token':"downgrades",
+                'alembic_module_prefix':'op.',
+                'sqlalchemy_module_prefix':'sa.'
+            }
+        )
+        template_args = {}
+        autogenerate._produce_migration_diffs(context, template_args, set())
+
+        eq_(re.sub(r"u'", "'", template_args['upgrades']),
+"""### commands auto generated by Alembic - please adjust! ###
+    op.alter_column('address', 'email_address',
+               existing_type=sa.VARCHAR(length=100),
+               type_=sa.String(length=100),
+               existing_nullable=False)
+    op.alter_column('address', 'id',
+               existing_type=sa.INTEGER(),
+               type_=sa.Integer(),
+               existing_nullable=False)
+    op.alter_column('extra', 'uid',
+               existing_type=sa.INTEGER(),
+               type_=sa.Integer(),
+               existing_nullable=True)
+    op.alter_column('extra', 'x',
+               existing_type=sa.CHAR(),
+               type_=sa.CHAR(),
+               existing_nullable=True)
+    op.alter_column('order', 'amount',
+               existing_type=sa.NUMERIC(precision=8, scale=2),
+               type_=sa.Numeric(precision=8, scale=2),
+               existing_nullable=False,
+               existing_server_default='0')
+    op.alter_column('order', 'order_id',
+               existing_type=sa.INTEGER(),
+               type_=sa.Integer(),
+               existing_nullable=False)
+    op.alter_column('user', 'a1',
+               existing_type=sa.TEXT(),
+               type_=sa.Text(),
+               existing_nullable=True)
+    op.alter_column('user', 'id',
+               existing_type=sa.INTEGER(),
+               type_=sa.Integer(),
+               existing_nullable=False)
+    op.alter_column('user', 'name',
+               existing_type=sa.VARCHAR(length=50),
+               type_=sa.String(length=50),
+               existing_nullable=True)
+    op.alter_column('user', 'pw',
+               existing_type=sa.VARCHAR(length=50),
+               type_=sa.String(length=50),
+               existing_nullable=True)
+    ### end Alembic commands ###""")
+
+
 class AutogenKeyTest(AutogenTest, TestCase):
     @classmethod
     def _get_db_schema(cls):