Source

django-customfields / dcfields / inheritedfield.py

Full commit
import copy
import logging
from collections import defaultdict
from types import TupleType

from django.core.exceptions import ObjectDoesNotExist
from django.db.models import Model, Field, BooleanField, Manager, ManyToManyField
from django.db.models import Q
from django.db.models import signals
from django.db.models.base import ModelBase
from django.db.models.fields import FieldDoesNotExist
from django.db.models.fields.related import RelatedField, add_lazy_relation, \
                        ReverseManyRelatedObjectsDescriptor, ManyToManyField, \
                        ForeignKey
from django.db.models.query import QuerySet
from django.db.models.sql.constants import QUERY_TERMS

from django.utils.functional import curry

try:
    from django.db.models.constants import LOOKUP_SEP
except ImportError:
    # Fallback for Django releases that only define LOOKUP_SEP next to
    # QUERY_TERMS.
    from django.db.models.sql.constants import LOOKUP_SEP

logger = logging.getLogger(__name__)

# Name templates for the helper attributes that accompany an InheritedField
# on its owning model. Each template is filled with the InheritedField's own
# attribute name.

# Boolean flag; while it is true the value is read from the parent object.
INHERITED_FLAG_FIELD_NAME = "is_%s_inherited"
# Clone of the parent's field; read instead of the parent when sync=True.
INHERITED_VALUE_FIELD_NAME = "%s_inherited_value"
# Clone of the parent's field that stores the value assigned locally.
LOCAL_VALUE_FIELD_NAME = "%s_local_value"

# Names exported by ``from <module> import *``.
# NOTE(review): the queryset/manager classes defined at the end of this module
# are not listed here -- confirm whether that is intentional.
__all__ = (
    'INHERITED_FLAG_FIELD_NAME', 'INHERITED_VALUE_FIELD_NAME',
    'LOCAL_VALUE_FIELD_NAME', 'InheritedOnlyException',
    'InheritedField', 'find_in_parent', 'find_on_model'
)

def log(class_name, msg, *args, **kwargs):
    """
    Emit ``msg`` at DEBUG level, but only for a fixed set of class names.

    ``args`` and ``kwargs`` are forwarded untouched to ``logger.debug``.
    Calls made for any other ``class_name`` are silently ignored.
    """
    traced_classes = ('TestModelChild', 'Program', 'Producer', 'TestModelParent')
    if class_name not in traced_classes:
        return
    logger.debug(msg, *args, **kwargs)

class InheritedOnlyException(Exception):
    """Raised when a value is assigned to an ``inherit_only`` InheritedField."""

class InheritedField(Field):
    """
    Virtual field whose value comes from a related ("parent") object unless
    it has been overridden locally.

    Unless `inherit_only=True`, declaring this field will:

        - look up ``field_name`` (defaults to this field's own name) on the
          model referred to by ``parent_name``; with `validate=True`
          (default) a ``TypeError`` is raised when the relation or the field
          cannot be found.
        - clone the field found in the parent (or in the parent's parent if
          the parent's field is an InheritedField too) into
          ``<fieldname>_inherited_value`` and ``<fieldname>_local_value`` on
          the current model.
        - create a boolean flag ``is_<fieldname>_inherited`` (default
          ``True``) on the current model.

    This field implements a descriptor interface so it will work like a
    property with getters and setters.

        - get returns the parent's value when the field is `inherit_only` or
          ``is_<fieldname>_inherited`` is true; with `sync=True` the value is
          read from ``<fieldname>_inherited_value`` instead of the parent
          object. Otherwise ``<fieldname>_local_value`` is returned.
        - set raises `InheritedOnlyException` if the field is `inherit_only`
        - set stores the value in ``<fieldname>_local_value`` and sets
          ``is_<fieldname>_inherited`` to whether the value equals the
          parent's current value.
    """
    def __init__(self, parent_name, field_name=None, inherit_only=False, validate=True, sync=False):
        """
        ``parent_name`` is the name of the relation field pointing at the
        parent object; ``field_name`` is the name of the field on the parent
        (``None`` means "same name as this field").
        """
        super(InheritedField, self).__init__()

        self.parent_object_field_name = parent_name
        self.inherited_field_name_in_parent = field_name
        self.inherit_only = inherit_only
        self.validate = validate
        self.sync = sync

    def get_field_display(self, instance, name):
        """
        Return the effective value of the field for display purposes.

        Inherited values are rendered as ``u"<value> *Inherited"``; local
        values are returned unchanged. Like ``__get__``, an empty parent
        relation falls back to the local value instead of raising.
        """
        if self.inherit_only or getattr(instance, self.inherited_flag_name):
            if self.sync:
                value = getattr(instance, INHERITED_VALUE_FIELD_NAME % name, None)
                # Wrap in a tuple so tuple values are formatted as one item.
                return u"%s *Inherited" % (value,)
            rel = getattr(instance, self.parent_object_field_name)
            if rel:
                value = getattr(rel, self.inherited_field_name_in_parent or name)
                return u"%s *Inherited" % (value,)
        return getattr(instance, LOCAL_VALUE_FIELD_NAME % name, None)

    def contribute_to_class(self, cls, name):
        """
        Install the descriptor, the ``is_<name>_inherited`` flag and the
        ``get_<name>_display`` helper on ``cls``, and register the
        ``class_prepared`` hooks that finish the setup.
        """
        self.name = self.attname = name
        cls._meta.add_virtual_field(self)

        self.inherited_flag_name = INHERITED_FLAG_FIELD_NAME % name
        self.inherited_value_field_name = INHERITED_VALUE_FIELD_NAME % name
        self.local_value_field_name = LOCAL_VALUE_FIELD_NAME % name

        if not self.inherit_only:
            flag_field = BooleanField(default=True)
            # Reuse this field's creation_counter for the generated flag.
            flag_field.creation_counter = self.creation_counter
            flag_field.contribute_to_class(cls, self.inherited_flag_name)

            # The value fields are cloned from the parent once the class is
            # prepared; weak=False keeps the curried callable alive.
            signals.class_prepared.connect(
                curry(self.add_value_field, name=name),
                sender=cls,
                weak=False
            )

        # Install the descriptor itself.
        setattr(cls, name, self)
        display_name = 'get_%s_display' % name
        setattr(cls, display_name, curry(self.get_field_display, name=name))
        # Set through __dict__ rather than plain attribute assignment.
        getattr(cls, display_name).__dict__['short_description'] = name.replace('_', ' ')

        if not hasattr(cls, 'FIELD_INHERITANCE_MAP'):
            cls.FIELD_INHERITANCE_MAP = {}

        # field name -> (relation name, field name in parent, sync)
        cls.FIELD_INHERITANCE_MAP[name] = (self.parent_object_field_name, self.inherited_field_name_in_parent or name, self.sync)
        if not self.sync:
            signals.class_prepared.connect(self.patch_manager, sender=cls)

    def patch_manager(self, sender, **kwargs):
        """
        ``class_prepared`` receiver: wrap ``sender.objects.get_query_set`` so
        that every queryset it returns has ``select_related()`` applied for
        the relations the non-synced inherited fields read from.

        The manager is patched only once; the unwrapped method is kept as
        ``original_get_query_set``.
        """
        if not hasattr(sender.objects, 'original_get_query_set'):
            _get_query_set = sender.objects.get_query_set
            def get_query_set(manager):
                model = manager.model

                if hasattr(model, 'FIELD_INHERITANCE_REL'):
                    # Relation paths were already computed for this model.
                    related = model.FIELD_INHERITANCE_REL
                else:
                    related = set()
                    for parent, target_field, sync in model.FIELD_INHERITANCE_MAP.values():
                        if not sync:
                            chain = []
                            find_in_parent(None, model, parent, target_field, validate=False, chain=chain)
                            # An empty chain means no relation was found;
                            # do not hand an empty path to select_related().
                            if chain:
                                related.add('__'.join(chain))
                    model.FIELD_INHERITANCE_REL = related

                return _get_query_set().select_related(*related)

            sender.objects.original_get_query_set = _get_query_set
            # Bind the wrapper to the manager instance.
            sender.objects.get_query_set = get_query_set.__get__(sender.objects)
            logger.debug("Patching %s's get_query_set method to slap a select_related on the returned qs.", sender.__name__)

    def pre_save_hook(self, sender, **kwargs):
        """
        ``pre_save`` receiver connected on the parent model when
        ``sync=True``. Currently it only logs the reverse inheritance map;
        the actual synchronisation is still a TODO.
        """
        logger.debug("Entering pre_save_hook")
        logger.debug(sender.REVERSE_FIELD_INHERITANCE_MAP)
        for name in sender.REVERSE_FIELD_INHERITANCE_MAP:
            logger.debug("Inheritance map: %s", name)
            # TODO
        logger.debug("Exitting pre_save_hook")

    def add_value_field(self, sender, name=None, robust=True, **kwargs):
        """
        ``class_prepared`` receiver: locate the inherited field on the parent
        and clone it onto ``sender`` as ``<name>_inherited_value`` and
        ``<name>_local_value``.

        Validation errors are raised only when both ``robust`` and
        ``self.validate`` are true.
        """
        def contribute(cls, field):
            # ``cls`` is the parent model, ``field`` the field found on it.
            if field is None:
                # Nothing was found (validation disabled) or the lookup was
                # deferred; there is nothing to clone right now.
                logger.debug(
                    "No field found to clone for %s on class %s",
                    name, sender.__name__
                )
                return

            if not hasattr(cls, 'REVERSE_FIELD_INHERITANCE_MAP'): #TODO: test
                cls.REVERSE_FIELD_INHERITANCE_MAP = {}

            cls.REVERSE_FIELD_INHERITANCE_MAP[self.inherited_field_name_in_parent or name] = sender
            if self.sync:
                signals.pre_save.connect(self.pre_save_hook, sender=cls)

            if isinstance(field, ReverseManyRelatedObjectsDescriptor):
                # The class attribute is a descriptor; clone the real field.
                field = field.field

            logger.debug(
                "Cloning field %s for class %s",
                field, sender.__name__
            )

            # inherited value field
            xfield = copy.deepcopy(field)
            xfield.blank = True
            if isinstance(xfield, (ManyToManyField, ForeignKey)):
                xfield.rel.through = None
                xfield.rel.related_name = '%s_inherited_set' % cls._meta.object_name

            xfield.creation_counter = self.creation_counter
            xfield.contribute_to_class(sender, INHERITED_VALUE_FIELD_NAME % name)

            # local value field
            yfield = copy.deepcopy(field)
            yfield.blank = True
            if isinstance(yfield, (ManyToManyField, ForeignKey)):
                yfield.rel.through = None
                yfield.rel.related_name = '%s_local_set' % cls._meta.object_name

            yfield.creation_counter = self.creation_counter
            yfield.contribute_to_class(sender, LOCAL_VALUE_FIELD_NAME % name)

        result = find_in_parent(
            None, sender,
            self.parent_object_field_name,
            self.inherited_field_name_in_parent or name,
            robust and self.validate,
            callback=contribute
        )
        # A falsy result means the lookup was deferred (``contribute`` was
        # handed over as the callback) or nothing was found.
        if result:
            parent, value_field = result
            contribute(parent, value_field)

    def __get__(self, instance, instance_type=None):
        """
        Return the parent's value when the field is inherited, otherwise the
        local value (``None`` if the local attribute does not exist).
        """
        # NOTE(review): there is deliberately no ``instance is None``
        # shortcut. Class-level access raises AttributeError below, which
        # makes ``hasattr(model, name)`` in find_on_model() false so that it
        # follows FIELD_INHERITANCE_MAP instead -- confirm before changing.
        if self.inherit_only or getattr(instance, self.inherited_flag_name):
            if self.sync:
                return getattr(instance, self.inherited_value_field_name, None)
            else:
                rel = getattr(instance, self.parent_object_field_name)
                if rel:
                    return getattr(rel, self.inherited_field_name_in_parent or self.name)
        return getattr(instance, self.local_value_field_name, None)

    def __set__(self, instance, value):
        """
        Store ``value`` as the local value and update the inherited flag.

        The flag becomes true only when a parent exists and its current value
        equals ``value``. Raises `InheritedOnlyException` for `inherit_only`
        fields.
        """
        if self.inherit_only:
            raise InheritedOnlyException(
                "Can't set value for field %s on %s (field is inherit_only). Try to set it on %s.%s." %
                (self.name, instance, self.parent_object_field_name, self.inherited_field_name_in_parent or self.name))
        try:
            rel = getattr(instance, self.parent_object_field_name)
            if rel:
                parent_value = getattr(rel, self.inherited_field_name_in_parent or self.name)
                setattr(instance, self.inherited_flag_name, value == parent_value)
            else:
                setattr(instance, self.inherited_flag_name, False)
        except ObjectDoesNotExist:
            # The parent relation could not be loaded; treat the value as local.
            setattr(instance, self.inherited_flag_name, False)
        setattr(instance, self.local_value_field_name, value)

def find_on_model(origin, model, field_name, validate=True, callback=None, chain=None):
    """
    Find ``field_name`` on ``model``.

    The lookup order is:

        1. a concrete field in ``model._meta.fields``;
        2. any attribute of the model class with that name;
        3. an entry in ``model.FIELD_INHERITANCE_MAP``, in which case the
           search continues in that field's own parent via
           ``find_in_parent`` (``origin``, ``callback`` and ``chain`` are
           passed along unchanged).

    Returns the field (or class attribute) that was found. Returns ``None``
    when the search continued in a parent and produced no field -- the
    lookup may have been deferred until the related model is loaded -- or
    when nothing was found and ``validate`` is false.

    Raises ``TypeError`` if ``validate`` is true and ``field_name`` is not
    known to the model at all.
    """
    for candidate in model._meta.fields:
        if candidate.name == field_name:
            return candidate

    # NOTE(review): an InheritedField descriptor raises AttributeError on
    # class-level access, so hasattr() is false for it and the inheritance
    # map below is consulted instead -- confirm before changing either side.
    if hasattr(model, field_name):
        return getattr(model, field_name)

    inheritance_map = getattr(model, 'FIELD_INHERITANCE_MAP', {})
    if field_name in inheritance_map:
        rel, field, sync = inheritance_map[field_name]
        result = find_in_parent(origin, model, rel, field, validate, callback, chain)
        if result:
            parent_model, field = result
            return field
        # The field is declared as inherited, so it is not "missing": a
        # falsy result here means the parent lookup was deferred or, with
        # validation off, came up empty.
        return None

    if validate:
        raise TypeError("InheritedField: %s does not exist in %s." %
                                    (field_name, model))
    return None

def find_in_parent(origin, model_class, relation_name, field_name, validate=True, callback=None, chain=None):
    """
    Follow the relation ``relation_name`` on ``model_class`` and look for
    ``field_name`` on the related model.

        - if `validate` is `True`, `TypeError` is raised when the relation
          does not exist on ``model_class``, is not a ``RelatedField``, or
          (via ``find_on_model``) the field isn't found in the parent.
        - if ``chain`` is a list, ``relation_name`` is appended to it once
          the relation has been identified.
        - ``origin`` is the model handed to ``callback``; when it is falsy
          the related model is used instead.

    Returns ``(related_model, field)`` when the related model class is
    already available, where ``field`` is whatever ``find_on_model``
    returned. Returns ``None`` when nothing was found and `validate` is
    false.

    Note that this will not always return the field instance as the related
    model may still be referenced by name. In that case a lazy hook is
    registered, ``None`` is returned, and ``callback(origin_or_related_model,
    field)`` is called from the hook if a callback was given and a field was
    found. Use `callback` to do stuff with the field.
    """
    relation = None
    for ifield in model_class._meta.fields:
        if ifield.name == relation_name:
            relation = ifield
            break

    if relation is None:
        if validate:
            raise TypeError("InheritedField: %s does not exist on %s." %
                            (relation_name, model_class))
        return None

    if not isinstance(relation, RelatedField):
        if validate:
            raise TypeError(
                "InheritedField: %s is a %s instead of a RelatedField."%
                (relation_name, type(relation)))
        return None

    # Only real relations belong in the chain: callers join it into a
    # relation path.
    if chain is not None:
        chain.append(relation_name)

    target = relation.rel.to
    if isinstance(target, basestring):
        # the model class isn't instantiated yet so we need to add a hook
        def resolve_related_class(xfield, xmodel, xcls):
            xfield.rel.to = xmodel
            field_instance = find_on_model(origin or xmodel, xmodel, field_name, validate, callback, chain)
            # Callers may pass no callback, and there is nothing to report
            # when no field was found.
            if callback is not None and field_instance is not None:
                callback(origin or xmodel, field_instance)

        add_lazy_relation(model_class, relation, target,
                            resolve_related_class)
        return None

    return (target, find_on_model(origin or target, target, field_name, validate, callback, chain))


class InheritedFieldQuerySet(QuerySet):
    """
    QuerySet whose ``filter()`` understands InheritedField names.

    A keyword lookup on an inherited field is rewritten into a ``Q`` object
    matching rows whose effective value matches: either the row has a local
    value (``is_<name>_inherited`` false) that matches, or the row inherits
    and the parent's field matches.
    """
    def is_inherited(self, parts):
        """
        Tell whether the lookup path ``parts`` (a list of names without the
        lookup type) ends in a field listed in a FIELD_INHERITANCE_MAP.
        """
        if not parts:
            return False
        try:
            last_field_name, model = self.traverse_models(parts, self.model)
        except (FieldDoesNotExist, AttributeError):
            # The path contains something this class cannot follow (for
            # example a name get_field() does not know, or a field without
            # ``rel``); leave the lookup to the stock QuerySet.
            return False
        return last_field_name in getattr(model, 'FIELD_INHERITANCE_MAP', {})

    def traverse_models(self, parts, model):
        """
        Follow every name in ``parts`` except the last one as a relation,
        starting at ``model``.

        Returns ``(last_name, model_owning_last_name)``. ``parts`` is not
        modified.
        """
        for relation_name in parts[:-1]:
            # Resolve each hop on the model reached so far.
            model = model._meta.get_field(relation_name).rel.to
        return (parts[-1], model)

    def split_field(self, field):
        """
        Split a filter keyword into ``(path_parts, lookup_type)``;
        ``lookup_type`` is ``None`` when the keyword carries none.
        """
        parts = field.split(LOOKUP_SEP)
        lookup = None
        # Never strip the only component: a lone name is a field name even
        # if it happens to equal a lookup type.
        if len(parts) > 1 and parts[-1] in QUERY_TERMS:
            lookup = parts.pop()
        return (parts, lookup)

    def patch_child(self, parts, lookup, value):
        """
        Build the ``Q`` matching rows that hold a *local* value: the
        inherited flag is false and ``<name>_local_value`` matches ``value``.
        """
        prefix = parts[:-1]
        field_name = parts[-1]
        flag_lookup = LOOKUP_SEP.join(prefix + [INHERITED_FLAG_FIELD_NAME % field_name])
        value_parts = prefix + [LOCAL_VALUE_FIELD_NAME % field_name]
        if lookup:
            value_parts.append(lookup)
        value_lookup = LOOKUP_SEP.join(value_parts)
        logger.debug("Local lookups: %s / %s", flag_lookup, value_lookup)
        return Q(**{flag_lookup: False, value_lookup: value})

    def patch_parent(self, parts, lookup, value):
        """
        Build the ``Q`` matching ``value`` against the field on the parent
        object, as recorded in FIELD_INHERITANCE_MAP.
        """
        last_field_name, model = self.traverse_models(parts, self.model)
        # Map entries are (relation name, field name in parent, sync).
        parents_name, parents_field_name = model.FIELD_INHERITANCE_MAP[last_field_name][:2]
        lookup_parts = parts[:-1] + [parents_name, parents_field_name]
        if lookup:
            lookup_parts.append(lookup)
        parent_lookup = LOOKUP_SEP.join(lookup_parts)
        logger.debug("Parent lookup: %s", parent_lookup)
        return Q(**{parent_lookup: value})

    def _has_inheritance_flag(self, parts):
        """True if the model owning parts[-1] has its ``is_<name>_inherited`` field."""
        last_field_name, model = self.traverse_models(parts, self.model)
        flag_name = INHERITED_FLAG_FIELD_NAME % last_field_name
        return any(f.name == flag_name for f in model._meta.fields)

    def patch(self, parts, lookup, value):
        """
        Combine the local and the parent condition into a single ``Q``.
        """
        parent = self.patch_parent(parts, lookup, value)
        if not self._has_inheritance_flag(parts):
            # No flag field on the model (as for inherit_only fields, which
            # get neither a flag nor a local value): only the parent counts.
            return parent
        flag_lookup = LOOKUP_SEP.join(parts[:-1] + [INHERITED_FLAG_FIELD_NAME % parts[-1]])
        child = self.patch_child(parts, lookup, value)
        # The parent's value only counts for rows that actually inherit.
        return child | (Q(**{flag_lookup: True}) & parent)

    def filter(self, *args, **kwargs):
        """
        Like ``QuerySet.filter`` for keyword lookups, with lookups on
        inherited fields rewritten by ``patch``.

        Positional ``Q`` objects are not supported and raise
        ``AssertionError``.
        """
        # We don't support Q objects, bail out if any. Raised explicitly so
        # the check is not stripped when Python runs with -O.
        if args:
            raise AssertionError(
                "InheritedFieldQuerySet.filter() does not support positional Q objects")
        patched = []
        untouched = {}
        for field, value in kwargs.items():
            parts, lookup = self.split_field(field)
            if self.is_inherited(parts):
                patched.append(self.patch(parts, lookup, value))
            else:
                untouched[field] = value
        return super(InheritedFieldQuerySet, self).filter(*patched, **untouched)

class InheritedFieldManager(Manager):
    """Manager that hands out ``InheritedFieldQuerySet`` instances."""

    def get_query_set(self):
        """Return an ``InheritedFieldQuerySet`` for this manager's model and database alias."""
        queryset = InheritedFieldQuerySet(self.model, using=self._db)
        return queryset