# wtforms / wtforms / ext / sqlalchemy / orm.py

"""
Tools for generating forms based on SQLAlchemy models.
"""
import inspect

import sqlalchemy

from wtforms import fields as f
from wtforms import validators
from wtforms.form import Form
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.ext.sqlalchemy.fields import QuerySelectMultipleField
from wtforms.ext.sqlalchemy.validators import Unique

__all__ = (
    'model_fields', 'model_form',
)

def converts(*args):
    """
    Decorator factory that tags a function with the type names it converts.

    The decorated function is returned unchanged apart from a new
    ``_converter_for`` attribute holding a ``frozenset`` of ``args``.
    """
    type_names = frozenset(args)

    def _inner(func):
        # Tag in place; the function object itself is handed back.
        setattr(func, '_converter_for', type_names)
        return func

    return _inner

class ModelConverterBase(object):
    """
    Base class that turns SQLAlchemy mapper properties into form fields.

    Converter callables are looked up by type-name string in
    ``self.converters``. Any attribute of the instance that carries a
    ``_converter_for`` attribute (a collection of type-name strings) is
    registered automatically.
    """

    def __init__(self, converters, use_mro=True):
        """
        Build the type-name -> converter lookup table.

        :param converters:
            Optional dict mapping type-name strings to converter callables.
            When given it is updated in place and kept as
            ``self.converters``; converters found on this instance replace
            entries registered under the same name.
        :param use_mro:
            If true, ``convert`` also tries the base classes of a column's
            type (in MRO order) when looking for a converter.
        """
        self.use_mro = use_mro

        if not converters:
            converters = {}

        for name in dir(self):
            obj = getattr(self, name)
            if hasattr(obj, '_converter_for'):
                for classname in obj._converter_for:
                    converters[classname] = obj

        self.converters = converters

    def convert(self, model, mapper, prop, field_args, db_session=None):
        """
        Convert a single mapper property into a form field.

        :param model:
            The mapped model class the property belongs to.
        :param mapper:
            The mapper of ``model``; passed through to the converter.
        :param prop:
            The mapper property to convert. Only ``ColumnProperty`` and
            ``RelationshipProperty`` instances are handled.
        :param field_args:
            Optional dict of keyword arguments that override the generated
            field arguments.
        :param db_session:
            Optional session. Needed to add ``Unique`` validators and to
            build fields for relationship properties.
        :returns:
            Whatever the selected converter returns, or ``None`` when the
            property is of an unsupported kind, no converter is registered
            for the column type, or the property is a relationship and no
            ``db_session`` was given.
        :raises TypeError:
            If a column property spans more than one column.
        """
        is_column = isinstance(prop, sqlalchemy.orm.properties.ColumnProperty)
        is_relationship = isinstance(prop,
            sqlalchemy.orm.properties.RelationshipProperty)

        if not is_column and not is_relationship:
            return
        elif is_column and len(prop.columns) != 1:
            raise TypeError('Do not know how to convert multiple-column '
                + 'properties currently')

        kwargs = {
            'validators': [],
            'filters': [],
            'default': None,
        }

        converter = None
        column = None

        if is_column:
            column = prop.columns[0]
            # Support sqlalchemy.schema.ColumnDefault, so users can benefit
            # from  setting defaults for fields, e.g.:
            #   field = Column(DateTimeField, default=datetime.utcnow)

            default = getattr(column, 'default', None)

            if default is not None:
                # Unwrap the default object through its ``arg`` attribute
                # when it has one: call it if it is callable, otherwise use
                # it as a plain value. Without the second case the wrapper
                # object itself would end up as the field default.
                default_arg = getattr(default, 'arg', None)

                if default_arg is not None:
                    if callable(default_arg):
                        default = default_arg(None)
                    else:
                        default = default_arg
            kwargs['default'] = default

            if column.nullable:
                kwargs['validators'].append(validators.Optional())
            else:
                kwargs['validators'].append(validators.Required())

            if db_session and column.unique:
                kwargs['validators'].append(Unique(lambda: db_session, model,
                    column))

            if self.use_mro:
                types = inspect.getmro(type(column.type))
            else:
                types = [type(column.type)]

            # First pass: match on the module-qualified type name.
            for col_type in types:
                type_string = '%s.%s' % (col_type.__module__,
                    col_type.__name__)
                if type_string.startswith('sqlalchemy'):
                    # Drop the first 11 characters ("sqlalchemy.").
                    type_string = type_string[11:]

                if type_string in self.converters:
                    converter = self.converters[type_string]
                    break
            else:
                # Second pass: fall back to the bare class name.
                for col_type in types:
                    if col_type.__name__ in self.converters:
                        converter = self.converters[col_type.__name__]
                        break
                else:
                    return

        if db_session and is_relationship:
            foreign_model = prop.mapper.class_

            # The field may be left blank only if every local column of the
            # relationship is nullable.
            nullable = True
            for pair in prop.local_remote_pairs:
                if not pair[0].nullable:
                    nullable = False

            kwargs.update({
                'allow_blank': nullable,
                'query_factory': lambda: db_session.query(foreign_model).all()
            })

            converter = self.converters[prop.direction.name]

        if converter is None:
            # Relationship property without a session: there is nothing to
            # build the query from, so produce no field instead of failing
            # with "'NoneType' object is not callable".
            return None

        if field_args:
            kwargs.update(field_args)

        return converter(model=model, mapper=mapper, prop=prop, column=column,
            field_args=kwargs)


class ModelConverter(ModelConverterBase):
    """
    Default converter set mapping column types and relationship directions
    to form fields.

    Each converter method is registered through ``@converts`` under one or
    more type-name strings and receives the prepared ``field_args`` dict
    (plus ``model``, ``mapper``, ``prop`` and ``column`` as keywords).
    """

    def __init__(self, extra_converters=None, use_mro=True):
        """
        :param extra_converters:
            Optional dict of additional type-name -> converter entries,
            handed to the base class.
        :param use_mro:
            Forwarded to the base class; defaults to ``True``, which is what
            the base class used when this argument could not be passed.
        """
        super(ModelConverter, self).__init__(extra_converters,
            use_mro=use_mro)

    @classmethod
    def _string_common(cls, column, field_args, **extra):
        """
        Append a ``Length(max=...)`` validator when the column type has a
        truthy ``length``.
        """
        if column.type.length:
            field_args['validators'].append(validators.Length(max=column.type.length))

    @converts('String', 'Unicode')
    def conv_String(self, field_args, **extra):
        """Return a ``TextField``, length-limited when the column is."""
        self._string_common(field_args=field_args, **extra)
        return f.TextField(**field_args)

    @converts('Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary')
    def conv_Text(self, field_args, **extra):
        """Return a ``TextAreaField``, length-limited when the column is."""
        self._string_common(field_args=field_args, **extra)
        return f.TextAreaField(**field_args)

    @converts('Boolean')
    def conv_Boolean(self, field_args, **extra):
        """Return a ``BooleanField``."""
        return f.BooleanField(**field_args)

    @converts('Date')
    def conv_Date(self, field_args, **extra):
        """Return a ``DateField``."""
        return f.DateField(**field_args)

    @converts('DateTime')
    def conv_DateTime(self, field_args, **extra):
        """Return a ``DateTimeField``."""
        return f.DateTimeField(**field_args)

    @converts('Integer', 'SmallInteger')
    def handle_integer_types(self, column, field_args, **extra):
        """
        Return an ``IntegerField``; adds ``NumberRange(min=0)`` when the
        column type has a truthy ``unsigned`` attribute.
        """
        unsigned = getattr(column.type, 'unsigned', False)
        if unsigned:
            field_args['validators'].append(validators.NumberRange(min=0))
        return f.IntegerField(**field_args)

    @converts('Numeric', 'Float')
    def handle_decimal_types(self, column, field_args, **extra):
        """
        Return a ``DecimalField``.

        ``places`` is taken from the column type's ``scale`` (2 when the
        type has no such attribute) and left unset when ``scale`` is
        ``None``.
        """
        places = getattr(column.type, 'scale', 2)
        if places is not None:
            field_args['places'] = places
        return f.DecimalField(**field_args)

    @converts('databases.mysql.MSYear')
    def conv_MSYear(self, field_args, **extra):
        """Return a ``TextField`` validated to the range 1901-2155."""
        field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
        return f.TextField(**field_args)

    @converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
    def conv_PGInet(self, field_args, **extra):
        """Return a ``TextField`` with an ``IPAddress`` validator."""
        field_args.setdefault('label', u'IP Address')
        field_args['validators'].append(validators.IPAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.MACADDR')
    def conv_PGMacaddr(self, field_args, **extra):
        """Return a ``TextField`` with a ``MacAddress`` validator."""
        field_args.setdefault('label', u'MAC Address')
        field_args['validators'].append(validators.MacAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.UUID')
    def conv_PGUuid(self, field_args, **extra):
        """Return a ``TextField`` with a ``UUID`` validator."""
        field_args.setdefault('label', u'UUID')
        field_args['validators'].append(validators.UUID())
        return f.TextField(**field_args)

    @converts('MANYTOONE')
    def conv_ManyToOne(self, field_args, **extra):
        """Return a ``QuerySelectField`` for a many-to-one relationship."""
        return QuerySelectField(**field_args)

    @converts('MANYTOMANY', 'ONETOMANY')
    def conv_ManyToMany(self, field_args, **extra):
        """Return a ``QuerySelectMultipleField`` for a to-many relationship."""
        return QuerySelectMultipleField(**field_args)


def model_fields(model, db_session=None, only=None, exclude=None,
    field_args=None, converter=None):
    """
    Generate a dictionary of fields for a given SQLAlchemy model.

    See `model_form` docstring for description of parameters.

    :returns:
        A dict mapping property names to the fields produced by
        ``converter.convert``; properties for which it returns ``None`` are
        left out.
    :raises TypeError:
        If ``model`` has no ``_sa_class_manager`` attribute.
    """
    if not hasattr(model, '_sa_class_manager'):
        raise TypeError('model must be a sqlalchemy mapped model')

    mapper = model._sa_class_manager.mapper
    converter = converter or ModelConverter()
    field_args = field_args or {}

    properties = ((p.key, p) for p in mapper.iterate_properties)
    # ``only`` takes precedence: ``exclude`` is consulted only without it.
    if only:
        properties = (x for x in properties if x[0] in only)
    elif exclude:
        properties = (x for x in properties if x[0] not in exclude)

    field_dict = {}
    for name, prop in properties:
        # Argument order must match
        # ``convert(model, mapper, prop, field_args, db_session)``.
        field = converter.convert(model, mapper, prop,
            field_args.get(name), db_session)
        if field is not None:
            field_dict[name] = field

    return field_dict


def model_form(model, db_session=None, base_class=Form, only=None,
    exclude=None, field_args=None, converter=None, exclude_pk=True,
    exclude_fk=True, type_name=None):
    """
    Create a wtforms Form for a given SQLAlchemy model class::

        from wtforms.ext.sqlalchemy.orm import model_form
        from myapp.models import User
        UserForm = model_form(User)

    :param model:
        A SQLAlchemy mapped model class.
    :param db_session:
        An optional SQLAlchemy Session.
    :param base_class:
        Base form class to extend from. Must be a ``wtforms.Form`` subclass.
    :param only:
        An optional iterable with the property names that should be included in
        the form. Only these properties will have fields.
    :param exclude:
        An optional iterable with the property names that should be excluded
        from the form. All other properties will have fields. The iterable
        is copied, never modified.
    :param field_args:
        An optional dictionary of field names mapping to keyword arguments used
        to construct each field object.
    :param converter:
        A converter to generate the fields based on the model properties. If
        not set, ``ModelConverter`` is used.
    :param exclude_pk:
        An optional boolean to force primary key exclusion.
    :param exclude_fk:
        An optional boolean to force foreign keys exclusion.
    :param type_name:
        An optional string to set returned type name.
    :returns:
        A new form class named ``type_name`` (default: the model's name plus
        ``'Form'``) deriving from ``base_class``.
    """
    class ModelForm(base_class):
        """Sets object as form attribute."""
        def __init__(self, *args, **kwargs):
            if 'obj' in kwargs:
                self._obj = kwargs['obj']
            super(ModelForm, self).__init__(*args, **kwargs)

    # Work on a private list: the caller's ``exclude`` may be a tuple or any
    # other iterable, and must not be changed by the appends below.
    exclude = list(exclude) if exclude else []

    model_mapper = model.__mapper__
    for prop in model_mapper.iterate_properties:
        if isinstance(prop, sqlalchemy.orm.properties.ColumnProperty) and \
                prop.columns[0].primary_key:
            if exclude_pk:
                exclude.append(prop.key)
        if isinstance(prop, sqlalchemy.orm.properties.RelationshipProperty) \
                and exclude_fk and prop.direction.name != 'MANYTOMANY':
            # Exclude the local columns backing the relationship.
            for pair in prop.local_remote_pairs:
                exclude.append(pair[0].key)

    type_name = type_name or model.__name__ + 'Form'
    field_dict = model_fields(model, db_session, only, exclude, field_args,
        converter)
    return type(type_name, (ModelForm, ), field_dict)
# Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
# Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
# Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
# Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
# Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
# Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
# Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.