Commits

jean-philippe serafin  committed c201301

Improved ext.sqlalchemy that now deals with unique field & relationship

  • Participants
  • Parent commits e32e77e

Comments (0)

Files changed (4)

 ^docs/html
 ^dist/
 ^MANIFEST$
+^env/

File tests/ext_sqlalchemy.py

 #!/usr/bin/env python
 
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, ForeignKey
 from sqlalchemy.schema import MetaData, Table, Column
-from sqlalchemy.types import String, Integer
-from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.types import String, Integer, Date
+from sqlalchemy.orm import sessionmaker, relationship, backref
+from sqlalchemy.ext.declarative import declarative_base
 
 from unittest import TestCase
 
 from wtforms.ext.sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField
 from wtforms.form import Form
+from wtforms.fields import TextField
+from wtforms.ext.sqlalchemy.orm import model_form
+from wtforms.validators import Optional, Required, Length
+from wtforms.ext.sqlalchemy.validators import Unique
 
 
 class LazySelect(object):
     def _do_tables(self, mapper, engine):
         metadata = MetaData()
 
-        test_table = Table('test', metadata, 
+        test_table = Table('test', metadata,
             Column('id', Integer, primary_key=True, nullable=False),
             Column('name', String, nullable=False),
         )
 
-        pk_test_table = Table('pk_test', metadata, 
+        pk_test_table = Table('pk_test', metadata,
             Column('foobar', String, primary_key=True, nullable=False),
             Column('baz', String, nullable=False),
         )
         self.assert_(form.validate())
 
 
+class ModelFormTest(TestCase):
+    def setUp(self):
+        Model = declarative_base()
+
+        student_course = Table(
+            'student_course', Model.metadata,
+            Column('student_id', Integer, ForeignKey('student.id')),
+            Column('course_id', Integer, ForeignKey('course.id'))
+        )
+
+        class Course(Model):
+            __tablename__ = "course"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(255), nullable=False)
+
+        class School(Model):
+            __tablename__ = "school"
+            id = Column(Integer, primary_key=True)
+            name = Column(String(255), nullable=False)
+
+        class Student(Model):
+            __tablename__ = "student"
+            id = Column(Integer, primary_key=True)
+            full_name = Column(String(255), nullable=False, unique=True)
+            dob = Column(Date(), nullable=True)
+            current_school_id = Column(Integer, ForeignKey(School.id),
+                nullable=False)
+
+            current_school = relationship(School, backref=backref('students'))
+            courses = relationship("Course", secondary=student_course,
+                backref=backref("students", lazy='dynamic'))
+
+        self.School = School
+        self.Student = Student
+
+        engine = create_engine('sqlite:///:memory:', echo=False)
+        Session = sessionmaker(bind=engine)
+        self.metadata = Model.metadata
+        self.metadata.create_all(bind=engine)
+        self.sess = Session()
+
+    def test_nullable_field(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertTrue(issubclass(Optional,
+            student_form._fields['dob'].validators[0].__class__))
+
+    def test_required_field(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertTrue(issubclass(Required,
+            student_form._fields['full_name'].validators[0].__class__))
+
+    def test_unique_field(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertTrue(issubclass(Unique,
+            student_form._fields['full_name'].validators[1].__class__))
+
+    def test_include_pk(self):
+        form_class = model_form(self.Student, self.sess, exclude_pk=False)
+        student_form = form_class()
+        self.assertIn('id', student_form._fields)
+
+    def test_exclude_pk(self):
+        form_class = model_form(self.Student, self.sess, exclude_pk=True)
+        student_form = form_class()
+        self.assertNotIn('id', student_form._fields)
+
+    def test_exclude_fk(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertNotIn('current_school_id', student_form._fields)
+
+    def test_include_fk(self):
+        student_form = model_form(self.Student, self.sess, exclude_fk=False)()
+        self.assertIn('current_school_id', student_form._fields)
+
+    def test_convert_many_to_one(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertTrue(issubclass(QuerySelectField,
+            student_form._fields['current_school'].__class__))
+
+    def test_convert_one_to_many(self):
+        school_form = model_form(self.School, self.sess)()
+        self.assertTrue(issubclass(QuerySelectMultipleField,
+            school_form._fields['students'].__class__))
+
+    def test_convert_many_to_many(self):
+        student_form = model_form(self.Student, self.sess)()
+        self.assertTrue(issubclass(QuerySelectMultipleField,
+            student_form._fields['courses'].__class__))
+
+
+class UniqueValidatorTest(TestCase):
+    def setUp(self):
+        Model = declarative_base()
+
+        class User(Model):
+            __tablename__ = "user"
+            id = Column(Integer, primary_key=True)
+            username = Column(String(255), nullable=False, unique=True)
+
+        engine = create_engine('sqlite:///:memory:', echo=False)
+        Session = sessionmaker(bind=engine)
+        self.metadata = Model.metadata
+        self.metadata.create_all(bind=engine)
+        self.sess = Session()
+
+        self.sess.add(User(username='batman'))
+        self.sess.commit()
+
+        class UserForm(Form):
+            username = TextField('Username', [
+                Length(min=4, max=25),
+                Unique(lambda: self.sess, User, User.username)
+            ])
+
+        self.UserForm = UserForm
+
+    def test_validate(self):
+        from werkzeug.datastructures import MultiDict
+        user_form = self.UserForm(formdata=MultiDict([('username',
+            'spiderman')]))
+        self.assertTrue(user_form.validate())
+
+    def test_wrong(self):
+        from werkzeug.datastructures import MultiDict
+        user_form = self.UserForm(formdata=MultiDict([('username',
+            'batman')]))
+        self.assertFalse(user_form.validate())
+
+
 if __name__ == '__main__':
     from unittest import main
     main()

File wtforms/ext/sqlalchemy/orm.py

 """
 import inspect
 
+import sqlalchemy
+
 from wtforms import fields as f
 from wtforms import validators
 from wtforms.form import Form
-
+from wtforms.ext.sqlalchemy.fields import QuerySelectField
+from wtforms.ext.sqlalchemy.fields import QuerySelectMultipleField
+from wtforms.ext.sqlalchemy.validators import Unique
 
 __all__ = (
     'model_fields', 'model_form',
 
         self.converters = converters
 
-    def convert(self, model, mapper, prop, field_args):
-        if not hasattr(prop, 'columns'):
-            # XXX We don't support anything but ColumnProperty at the moment.
+class ModelConverterBase(object):
+    def __init__(self, converters, use_mro=True):
+        self.use_mro = use_mro
+
+        if not converters:
+            converters = {}
+
+        for name in dir(self):
+            obj = getattr(self, name)
+            if hasattr(obj, '_converter_for'):
+                for classname in obj._converter_for:
+                    converters[classname] = obj
+
+        self.converters = converters
+
+    def convert(self, model, db_session, mapper, prop, field_args):
+        if not isinstance(prop, sqlalchemy.orm.properties.ColumnProperty) and \
+                not isinstance(prop,
+                sqlalchemy.orm.properties.RelationshipProperty):
             return
-        elif len(prop.columns) != 1:
-            raise TypeError('Do not know how to convert multiple-column properties currently')
-
-        column = prop.columns[0]
-
-        # Support sqlalchemy.schema.ColumnDefault, so users can benefit from
-        # setting defaults for fields, e.g.:
-        #   field = Column(DateTimeField, default=datetime.utcnow)
-
-        default = getattr(column, 'default', None)
-
-        if default is not None:
-            # Only actually change default if it has an attribute named
-            # 'arg' that's callable.
-            callable_default = getattr(default, 'arg', None)
-
-            if callable_default and callable(callable_default):
-                default = callable_default(None)
+        elif isinstance(prop, sqlalchemy.orm.properties.ColumnProperty) and\
+            len(prop.columns) != 1:
+            raise TypeError('Do not know how to convert multiple-column '
+                + 'properties currently')
 
         kwargs = {
             'validators': [],
             'filters': [],
-            'default': default,
+            'default': None,
         }
 
+        converter = None
+        column = None
+
+        if isinstance(prop, sqlalchemy.orm.properties.ColumnProperty):
+            column = prop.columns[0]
+            # Support sqlalchemy.schema.ColumnDefault, so users can benefit
+            # from  setting defaults for fields, e.g.:
+            #   field = Column(DateTimeField, default=datetime.utcnow)
+
+            default = getattr(column, 'default', None)
+
+            if default is not None:
+                # Only actually change default if it has an attribute named
+                # 'arg' that's callable.
+                callable_default = getattr(default, 'arg', None)
+
+                if callable_default and callable(callable_default):
+                    default = callable_default(None)
+            kwargs['default'] = default
+
+            if column.nullable:
+                kwargs['validators'].append(validators.Optional())
+            else:
+                kwargs['validators'].append(validators.Required())
+
+            if column.unique:
+                kwargs['validators'].append(Unique(lambda: db_session, model,
+                    column))
+
+            if self.use_mro:
+                types = inspect.getmro(type(column.type))
+            else:
+                types = [type(column.type)]
+
+            for col_type in types:
+                type_string = '%s.%s' % (col_type.__module__,
+                    col_type.__name__)
+                if type_string.startswith('sqlalchemy'):
+                    type_string = type_string[11:]
+
+                if type_string in self.converters:
+                    converter = self.converters[type_string]
+                    break
+            else:
+                for col_type in types:
+                    if col_type.__name__ in self.converters:
+                        converter = self.converters[col_type.__name__]
+                        break
+                else:
+                    return
+
+        if isinstance(prop, sqlalchemy.orm.properties.RelationshipProperty):
+            foreign_model = prop.mapper.class_
+
+            nullable = True
+            for pair in prop.local_remote_pairs:
+                if not pair[0].nullable:
+                    nullable = False
+
+            kwargs.update({
+                'allow_blank': nullable,
+                'query_factory': lambda: db_session.query(foreign_model).all()
+            })
+
+            converter = self.converters[prop.direction.name]
+
         if field_args:
             kwargs.update(field_args)
 
-        if column.nullable:
-            kwargs['validators'].append(validators.Optional())
-
-        if self.use_mro:
-            types = inspect.getmro(type(column.type))
-        else:
-            types = [type(column.type)]
-
-        converter = None
-        for col_type in types:
-            type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
-            if type_string.startswith('sqlalchemy'):
-                type_string = type_string[11:]
-
-            if type_string in self.converters:
-                converter = self.converters[type_string]
-                break
-        else:
-            for col_type in types:
-                if col_type.__name__ in self.converters:
-                    converter = self.converters[col_type.__name__]
-                    break
-            else:
-                return
-        return converter(model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs)
+        return converter(model=model, mapper=mapper, prop=prop, column=column,
+            field_args=kwargs)
 
 
 class ModelConverter(ModelConverterBase):
         field_args['validators'].append(validators.IPAddress())
         return f.TextField(**field_args)
 
+    @converts('MANYTOMANY', 'ONETOMANY')
+    def conv_ManyToMany(self, field_args, **extra):
+        return QuerySelectMultipleField(**field_args)
 
-def model_fields(model, only=None, exclude=None, field_args=None, converter=None):
+    @converts('MANYTOONE')
+    def conv_ManyToOne(self, field_args, **extra):
+        return QuerySelectField(**field_args)
+
+
+def model_fields(model, db_session, only=None, exclude=None, field_args=None,
+    converter=None):
     """
     Generate a dictionary of fields for a given SQLAlchemy model.
 
 
     field_dict = {}
     for name, prop in properties:
-        field = converter.convert(model, mapper, prop, field_args.get(name))
+        field = converter.convert(model, db_session, mapper, prop,
+            field_args.get(name))
         if field is not None:
             field_dict[name] = field
 
     return field_dict
 
 
-def model_form(model, base_class=Form, only=None, exclude=None, field_args=None, converter=None):
+def model_form(model, db_session, base_class=Form, only=None, exclude=None,
+    field_args=None, converter=None, exclude_pk=True, exclude_fk=True,
+    type_name=None):
     """
     Create a wtforms Form for a given SQLAlchemy model class::
 
-        from wtforms.ext.sqlalchemy.orm import model_form
+        from wtalchemy.orm import model_form
         from myapp.models import User
         UserForm = model_form(User)
 
     :param model:
         A SQLAlchemy mapped model class.
+    :param db_session:
+        A SQLAlchemy Session.
     :param base_class:
         Base form class to extend from. Must be a ``wtforms.Form`` subclass.
     :param only:
     :param converter:
         A converter to generate the fields based on the model properties. If
         not set, ``ModelConverter`` is used.
+    :param exclude_pk:
+        An optional boolean to force primary key exclusion.
+    :param exclude_fk:
+        An optional boolean to force foreign keys exclusion.
+    :param type_name:
+        An optional string to set returned type name.
     """
-    field_dict = model_fields(model, only, exclude, field_args, converter)
-    return type(model.__name__ + 'Form', (base_class, ), field_dict)
+    class ModelForm(base_class):
+        """Sets object as form attribute."""
+        def __init__(self, *args, **kwargs):
+            if 'obj' in kwargs:
+                self._obj = kwargs['obj']
+            super(ModelForm, self).__init__(*args, **kwargs)
+
+    if not exclude:
+        exclude = []
+    model_mapper = model.__mapper__
+    for prop in model_mapper.iterate_properties:
+        if isinstance(prop, sqlalchemy.orm.properties.ColumnProperty) and \
+               prop.columns[0].primary_key:
+            if exclude_pk:
+                exclude.append(prop.key)
+        if isinstance(prop, sqlalchemy.orm.properties.RelationshipProperty) \
+            and  exclude_fk and prop.direction.name != 'MANYTOMANY':
+                for pair in prop.local_remote_pairs:
+                    exclude.append(pair[0].key)
+    type_name = type_name or model.__name__ + 'Form'
+    field_dict = model_fields(model, db_session, only, exclude, field_args,
+        converter)
+    return type(type_name, (ModelForm, ), field_dict)

File wtforms/ext/sqlalchemy/validators.py

+from wtforms import ValidationError
+from sqlalchemy.orm.exc import NoResultFound
+
+
+class Unique(object):
+    """Checks field value unicity against specified table field.
+
+    :param get_session:
+        A function that return a SQAlchemy Session.
+    :param model:
+        The model to check unicity against.
+    :param column:
+        The unique column.
+    :param message:
+        The error message.
+    """
+    field_flags = ('unique', )
+
+    def __init__(self, get_session, model, column, message=None):
+        self.get_session = get_session
+        self.model = model
+        self.column = column
+        self.message = message
+
+    def __call__(self, form, field):
+        try:
+            obj = self.get_session().query(self.model)\
+                .filter(self.column == field.data).one()
+            if not hasattr(form, '_obj') or not form._obj == obj:
+                if self.message is None:
+                    self.message = field.gettext(u'Already exists.')
+                raise ValidationError(self.message)
+        except NoResultFound:
+            pass