Commits

James Crasta committed 8a0ae49

Improve SQLAlchemy test coverage, remove hacky _obj kluge.

Comments (0)

Files changed (2)

tests/ext_sqlalchemy.py

 
 from sqlalchemy import create_engine, ForeignKey
 from sqlalchemy.schema import MetaData, Table, Column, ColumnDefault
-from sqlalchemy.types import String, Integer, Date
+from sqlalchemy.types import String, Integer, Numeric, Date, Text, Enum
 from sqlalchemy.orm import sessionmaker, relationship, backref
 from sqlalchemy.ext.declarative import declarative_base
 
 
 from wtforms.compat import text_type, iteritems
 from wtforms.ext.sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField
-from wtforms.form import Form
+from wtforms import Form, fields
 from wtforms.fields import TextField
-from wtforms.ext.sqlalchemy.orm import model_form
+from wtforms.ext.sqlalchemy.orm import model_form, ModelConversionError
 from wtforms.validators import Optional, Required, Length
 from wtforms.ext.sqlalchemy.validators import Unique
 
         self._fill(sess)
 
         class F(Form):
-            a = QuerySelectField(get_label=(lambda model: model.name), query_factory=lambda:sess.query(self.Test), widget=LazySelect())
-            b = QuerySelectField(allow_blank=True, query_factory=lambda:sess.query(self.PKTest), widget=LazySelect())
+            a = QuerySelectField(get_label=(lambda model: model.name), query_factory=lambda: sess.query(self.Test), widget=LazySelect())
+            b = QuerySelectField(allow_blank=True, query_factory=lambda: sess.query(self.PKTest), widget=LazySelect())
 
         form = F()
         self.assertEqual(form.a.data, None)
             __tablename__ = "course"
             id = Column(Integer, primary_key=True)
             name = Column(String(255), nullable=False)
+            # These are for better model form testing
+            cost = Column(Numeric(5, 2), nullable=False)
+            description = Column(Text, nullable=False)
+            level = Column(Enum('Primary', 'Secondary'))
 
         class School(Model):
             __tablename__ = "school"
             id = Column(Integer, primary_key=True)
             full_name = Column(String(255), nullable=False, unique=True)
             dob = Column(Date(), nullable=True)
-            current_school_id = Column(Integer, ForeignKey(School.id),
-                nullable=False)
+            current_school_id = Column(Integer, ForeignKey(School.id), nullable=False)
 
             current_school = relationship(School, backref=backref('students'))
-            courses = relationship("Course", secondary=student_course,
-                backref=backref("students", lazy='dynamic'))
+            courses = relationship(
+                "Course",
+                secondary=student_course,
+                backref=backref("students", lazy='dynamic')
+            )
 
         self.School = School
         self.Student = Student
+        self.Course = Course
 
         engine = create_engine('sqlite:///:memory:', echo=False)
         Session = sessionmaker(bind=engine)
         self.assertTrue(issubclass(QuerySelectMultipleField,
             student_form._fields['courses'].__class__))
 
+    def test_convert_basic(self):
+        self.assertRaises(TypeError, model_form, None)
+        self.assertRaises(ModelConversionError, model_form, self.Course)
+        form_class = model_form(self.Course, exclude=['students'])
+        form = form_class()
+        self.assertEqual(len(list(form)), 4)
+        assert isinstance(form.cost, fields.DecimalField)
+
 
 class ModelFormColumnDefaultTest(TestCase):
-
     def setUp(self):
         Model = declarative_base()
 

wtforms/ext/sqlalchemy/orm.py

     return _inner
 
 
+class ModelConversionError(Exception):
+    def __init__(self, message):
+        Exception.__init__(self, message)
+
+
 class ModelConverterBase(object):
     def __init__(self, converters, use_mro=True):
         self.use_mro = use_mro
 
         converter = None
         column = None
+        types = None
 
         if not hasattr(prop, 'direction'):
             column = prop.columns[0]
                 types = [type(column.type)]
 
             for col_type in types:
-                type_string = '%s.%s' % (col_type.__module__,
-                    col_type.__name__)
+                type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
                 if type_string.startswith('sqlalchemy'):
                     type_string = type_string[11:]
 
                         converter = self.converters[col_type.__name__]
                         break
                 else:
-                    return
+                    raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0]))
+        else:
+            # We have a property with a direction.
+            if not db_session:
+                raise ModelConversionError("Cannot convert field %s, need DB session." % prop.key)
 
-        if db_session and hasattr(prop, 'direction'):
             foreign_model = prop.mapper.class_
 
             nullable = True
         self._string_common(field_args=field_args, **extra)
         return f.TextField(**field_args)
 
-    @converts('Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary')
+    @converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary')
     def conv_Text(self, field_args, **extra):
         self._string_common(field_args=field_args, **extra)
         return f.TextAreaField(**field_args)
 
 
 def model_fields(model, db_session=None, only=None, exclude=None,
-    field_args=None, converter=None):
+        field_args=None, converter=None):
     """
     Generate a dictionary of fields for a given SQLAlchemy model.
 
     See `model_form` docstring for description of parameters.
     """
-    if not hasattr(model, '_sa_class_manager'):
-        raise TypeError('model must be a sqlalchemy mapped model')
-
     mapper = model._sa_class_manager.mapper
     converter = converter or ModelConverter()
     field_args = field_args or {}
     :param type_name:
         An optional string to set returned type name.
     """
-    class ModelForm(base_class):
-        """Sets object as form attribute."""
-        def __init__(self, *args, **kwargs):
-            if 'obj' in kwargs:
-                self._obj = kwargs['obj']
-            super(ModelForm, self).__init__(*args, **kwargs)
+    if not hasattr(model, '_sa_class_manager'):
+        raise TypeError('model must be a sqlalchemy mapped model')
 
     if not exclude:
         exclude = []
         if not hasattr(prop, 'direction') and prop.columns[0].primary_key:
             if exclude_pk:
                 exclude.append(prop.key)
-        if hasattr(prop, 'direction') and  exclude_fk and \
+        if hasattr(prop, 'direction') and exclude_fk and \
                 prop.direction.name != 'MANYTOMANY':
             for pair in prop.local_remote_pairs:
                 exclude.append(pair[0].key)
     type_name = type_name or str(model.__name__ + 'Form')
-    field_dict = model_fields(model, db_session, only, exclude, field_args,
-        converter)
-    return type(type_name, (ModelForm, ), field_dict)
+    field_dict = model_fields(model, db_session, only, exclude, field_args, converter)
+    return type(type_name, (base_class, ), field_dict)