Source

django-sniplates / sniplates / templatetags / form.py

from django import template
from django.template.loader import get_template
from django.template.loader_tags import BlockNode, ExtendsNode, BlockContext

# Library instance on which this module's tags (@register.tag) and filters
# (@register.filter) are registered.
register = template.Library()

def resolve_dict(d, context):
    '''Return a new dict mapping each key of ``d`` to ``d[key].resolve(context)``.

    ``d`` is expected to hold compiled filter expressions (anything with a
    ``resolve(context)`` method); the input mapping is not modified.
    '''
    return {name: expr.resolve(context) for name, expr in d.items()}

def resolve_blocks(template, context, blocks=None):
    '''Collect the blocks of ``template`` and of every template it extends.

    ``template`` may be a template object or a template name; names are
    loaded with ``get_template``.  Blocks are added to a ``BlockContext``
    starting with ``template`` itself and then walking up the chain of
    ``{% extends %}`` parents (only the first extends node of each template
    is followed).  If ``blocks`` is given it is filled in place; either way
    the BlockContext is returned.

    NOTE(review): which of several same-named blocks wins is decided by
    ``BlockContext.add_blocks``/``get_block`` ordering (in Django's
    implementation the earliest-added, i.e. child, block is returned) —
    confirm for the Django version in use.
    '''
    block_context = BlockContext() if blocks is None else blocks
    current = template

    while True:
        if isinstance(current, basestring):
            current = get_template(current)

        nodelist = current.nodelist
        block_context.add_blocks(dict(
            (node.name, node)
            for node in nodelist.get_nodes_by_type(BlockNode)
        ))

        parents = nodelist.get_nodes_by_type(ExtendsNode)
        if not parents:
            return block_context

        # A template may only have one {% extends %}; continue with its parent.
        current = parents[0].get_parent(context)



@register.filter
def append(value, extra):
    '''Template filter: ``value|append:extra`` -> unicode(value) followed by unicode(extra).'''
    # join() always yields a plain unicode object, whatever the part types.
    return u''.join(unicode(part) for part in (value, extra))

@register.tag
def form(parser, token):
    '''{% form <form> <sniplate_lib> [name=value ...] %}

    Open a context providing blocks from the named sniplate library.  Each
    ``name=value`` option is compiled as a filter expression; repeating an
    option or giving a bare (non ``name=value``) argument is a syntax error.
    '''
    bits = token.split_contents()
    if len(bits) < 3:
        raise template.TemplateSyntaxError("%r tag takes at least 2 arguments: form, and the widget library" % bits[0])

    form_expr = parser.compile_filter(bits[1])
    lib_expr = parser.compile_filter(bits[2])

    options = {}
    for bit in bits[3:]:
        name, sep, raw_value = bit.partition('=')
        if not sep:
            # No bare flags are supported (yet).
            raise template.TemplateSyntaxError("Unexpected option: %r" % bit)
        if name in options:
            raise template.TemplateSyntaxError(
                "The %r option was specified more than once." % name
            )
        options[name] = parser.compile_filter(raw_value)

    return FormNode(lib_expr, form_expr, options)

class FormNode(template.Node):
    '''Node for ``{% form %}``: render the sniplate library's form block.

    While the form block renders, the context additionally contains
    ``form_tag`` (this node), ``form`` (the resolved form) and
    ``used_fields`` (an empty set), plus the resolved tag options.  The
    ``form`` option, if given, names the block to render (default ``form``)
    and is not pushed onto the context.
    '''
    def __init__(self, lib, form, options):
        # Compiled expression for the sniplate library (template or name).
        self.lib = lib
        # The unresolved form expression is kept separately so that every
        # render resolves it afresh; ``self.form`` is replaced with the
        # resolved form during render for code that reads ``form_tag.form``.
        self.form_var = form
        self.form = form
        # option name -> compiled expression.
        self.options = options

    def get_template_for_field(self, field):
        '''Return the library block used to render ``field``.

        Candidates are tried in order: ``<FieldClass>__<name>``, ``<name>``,
        ``<FieldClass>``; the first truthy result of ``get_block`` wins,
        otherwise the ``field`` block is returned (presumably None if the
        library has no such block).
        '''
        field_name = field.name
        field_type = field.field.__class__.__name__
        candidates = ('__'.join([field_type, field_name]), field_name, field_type)
        for block_name in candidates:
            block = self.templates.get_block(block_name)
            if block:
                return block

        # Default fallback
        return self.templates.get_block('field')

    def render(self, context):
        lib = self.lib.resolve(context)
        options = resolve_dict(self.options, context)
        # Resolve into a local: overwriting the expression itself would make
        # the next render of this (cached, shared) node fail.
        form = self.form_var.resolve(context)
        self.form = form

        # Grab the template snippets; other tags read them via form_tag.
        self.templates = resolve_blocks(lib, context)

        # Find the parent snippet
        block_name = options.pop('form', 'form')
        block = self.templates.get_block(block_name)
        if block is None:
            raise template.TemplateSyntaxError(
                "Sniplate library %r has no %r block" % (lib, block_name)
            )

        # Add ourself to the context
        options.update({
            'form_tag': self,
            'form': form,
            'used_fields': set(),
        })

        context.update(options)
        try:
            return block.render(context)
        finally:
            # Always undo the push, even if rendering raised.
            context.pop()

@register.tag
def use(parser, token):
    '''{% use <block_name> [key=value ...] %}

    A simple helper for including blocks from the same sniplate library.
    Each ``key=value`` option is compiled as a filter expression; a bare
    positional argument after the block name is a syntax error.  The
    rendered node reads ``form_tag`` from the context, so this tag is meant
    to be used inside ``{% form %}``.
    '''
    bits = token.split_contents()
    tag_name = bits.pop(0)
    if not bits:
        raise template.TemplateSyntaxError(
            "%r tag requires a block name argument" % tag_name
        )
    widget = parser.compile_filter(bits.pop(0))

    options = {}
    for bit in bits:
        if '=' in bit:
            key, val = bit.split('=', 1)
            options[key] = parser.compile_filter(val)
        else:
            raise template.TemplateSyntaxError(
                "Unexpected positional argument option: %r" % bit
            )

    return UseNode(widget, options)

class UseNode(template.Node):
    '''Node for ``{% use %}``: render a named block of the enclosing form's library.

    The resolved options are pushed onto the context for the duration of
    the block's render and popped afterwards.
    '''
    def __init__(self, widget, options=None):
        # Compiled expression giving the block name.
        self.widget = widget
        # option name -> compiled expression.  ``use`` passes this as the
        # second positional argument; previously it was not accepted and
        # ``render`` read an attribute that was never set.
        self.options = options if options is not None else {}

    def render(self, context):
        widget = self.widget.resolve(context)
        options = resolve_dict(self.options, context)
        context.update(options)
        try:
            block = context['form_tag'].templates.get_block(widget)
            if block is None:
                raise template.TemplateSyntaxError(
                    "No sniplate block named %r" % (widget,)
                )
            return block.render(context)
        finally:
            # Always undo the push, even if lookup or rendering raised.
            context.pop()

@register.tag
def field(parser, token):
    '''{% field name [widget="..."] [key=value ...] %}

    The first argument is the field (a bound field or a field name); each
    ``key=value`` option is compiled as a filter expression.  A bare
    positional argument after the field is a syntax error.
    '''
    bits = token.split_contents()
    tag_name = bits.pop(0)
    if not bits:
        raise template.TemplateSyntaxError(
            "%r tag requires a field argument" % tag_name
        )
    field = parser.compile_filter(bits.pop(0))

    options = {}
    for bit in bits:
        if '=' in bit:
            key, val = bit.split('=', 1)
            options[key] = parser.compile_filter(val)
        else:
            raise template.TemplateSyntaxError(
                "Unexpected positional argument option: %r" % bit
            )

    return FieldNode(field, **options)

class FieldNode(template.Node):
    '''Node for ``{% field %}``: render one form field with a sniplate block.

    Reads ``form_tag``, ``form`` and ``used_fields`` from the context, so it
    is meant to be used inside ``{% form %}``.
    '''
    def __init__(self, field, **options):
        # Compiled expression for the field (a bound field or a field name).
        self.field = field
        # option name -> compiled expression; ``widget`` picks the block.
        self.options = options

    def render(self, context):
        field = self.field.resolve(context)
        options = resolve_dict(self.options, context)
        form_tag = context['form_tag']

        # A plain name is looked up on the enclosing form.
        if isinstance(field, basestring):
            field = context['form'][field]

        # Record the field so later tags can skip already-rendered fields.
        context['used_fields'].add(field.name)

        # An explicit widget overrides the type/name based block lookup.
        widget = options.pop('widget', None)
        if widget:
            tmpl = form_tag.templates.get_block(widget)
        else:
            tmpl = form_tag.get_template_for_field(field)
        if tmpl is None:
            raise template.TemplateSyntaxError(
                "No sniplate block found to render field %r" % field.name
            )

        # Update the context
        options['field'] = field
        context.update(options)
        try:
            return tmpl.render(context)
        finally:
            # Always undo the push, even if rendering raised.
            context.pop()

class FieldsNode(template.Node):
    def __init__(self, fields, **options):
        # Extra option expressions, resolved against the context in render().
        self.options = options
        # Field expressions to render; when empty, render() falls back to
        # every form field not yet listed in ``used_fields``.
        self.fields = fields

    def get_field(self, field):
        '''Resolve ``field`` against ``self.context`` and return the field object.

        A string result is treated as a field name and looked up on
        ``self.context['form']``; any other result is returned as-is.
        ``self.context`` must already have been set (render() does this).
        '''
        resolved = field.resolve(self.context)
        if not isinstance(resolved, basestring):
            return resolved
        return self.context['form'][resolved]

    def render(self, context):
        self.context = context
        if not self.fields:
            fields = [
                field
                for field in context['form']
                if not field.name in context['used_fields']
            ]
        else:
            fields = [ self.get_field(field) for field in self.fields ]
        options = resolve_dict(self.options, context)