Commits

Ionel Cristian Mărieș committed c559a93

Reverted sync implementation.

Comments (0)

Files changed (1)

dcfields/inheritedfield.py

 import logging
-from types import TupleType
 logger = logging.getLogger(__name__)
 
 from django.core.exceptions import ObjectDoesNotExist
-from django.db.models import Model, Field, BooleanField, Manager, ManyToManyField
+from django.db.models import Model, Field, BooleanField, Manager, ManyToManyField, Q
 from django.db.models.query import QuerySet
 from django.db.models.fields.related import RelatedField, add_lazy_relation, \
-                        ReverseManyRelatedObjectsDescriptor, ManyToManyField, \
-                        ForeignKey
+                        ReverseManyRelatedObjectsDescriptor, ManyToManyField
 from django.db.models.base import ModelBase
 from django.db.models import signals
+from django.db.models.sql.constants import LOOKUP_SEP, QUERY_TERMS
 
 from django.utils.functional import curry
 
 import copy
 from collections import defaultdict
 
-INHERITED_FLAG_FIELD_NAME = "is_%s_inherited"
-INHERITED_VALUE_FIELD_NAME = "%s_inherited_value"
-LOCAL_VALUE_FIELD_NAME = "%s_local_value"
+INHERIT_FLAG_NAME = "is_%s_inherited"
+VALUE_FIELD_NAME = "%s_value"
 
 __all__ = (
-    'INHERITED_FLAG_FIELD_NAME', 'INHERITED_VALUE_FIELD_NAME',
-    'LOCAL_VALUE_FIELD_NAME', 'InheritedOnlyException',
+    'INHERIT_FLAG_NAME', 'VALUE_FIELD_NAME', 'InheritedOnlyException',
     'InheritedField', 'find_in_parent', 'find_on_model'
 )
 
-def log(class_name, msg, *args, **kwargs):
-    if class_name in ('TestModelChild', 'Program', 'Producer', 'TestModelParent'):
-        logger.debug(msg, *args, **kwargs)
-
 class InheritedOnlyException(Exception):
     pass
 
 class InheritedField(Field):
     """
     This field will:
-
+    
         - check if the model refered by ``parent_name`` contains ``field_name``
           and raise according exceptions if `validate=True` (default).
         - copy the field from the parent, or the parent's parent if the field in
           (default)
         - create a boolean flag ``is_<fieldname>_inherited`` in the current
           model.
-
+          
     This field implements a descriptor interface so it will work like a property
     with getters and setters.
 
         - set will save the value in `{fieldname}_value` and set the
           `is_{fieldname}_inherited` flag accordingly
     """
-    def __init__(self, parent_name, field_name=None, inherit_only=False, validate=True, sync=False):
+    def __init__(self, parent_name, field_name=None, inherit_only=False, validate=True):
         super(InheritedField, self).__init__()
-
+        
         self.parent_object_field_name = parent_name
         self.inherited_field_name_in_parent = field_name
         self.inherit_only = inherit_only
         self.validate = validate
-        self.sync = sync
 
     def get_field_display(self, instance, name):
-        if self.inherit_only or getattr(instance, self.inherited_flag_name):
-            if self.sync:
-                value = getattr(instance, INHERITED_VALUE_FIELD_NAME % name, None)
-            else:
-                rel = getattr(instance, self.parent_object_field_name)
-                value = getattr(rel, self.inherited_field_name_in_parent or name)
-            return u"%s *Inherited" % value
+        if self.inherit_only or getattr(instance, self.inherit_flag_name):
+            rel = getattr(instance, self.parent_object_field_name)
+            pname = self.inherited_field_name_in_parent or name
+            displayfname = "get_%s_display" % pname
+            return u"%s *Inherited" % (
+                getattr(rel, displayfname)() 
+                if hasattr(rel, displayfname) 
+                else getattr(rel, pname)
+            )
         else:
-            return getattr(instance, LOCAL_VALUE_FIELD_NAME % name)
+            return getattr(instance, VALUE_FIELD_NAME % name)
 
     def contribute_to_class(self, cls, name):
         self.name = self.attname = name
         cls._meta.add_virtual_field(self)
-
-        self.inherited_flag_name = INHERITED_FLAG_FIELD_NAME % name
-        self.inherited_value_field_name = INHERITED_VALUE_FIELD_NAME % name
-        self.local_value_field_name = LOCAL_VALUE_FIELD_NAME % name
+        
+        self.inherit_flag_name = INHERIT_FLAG_NAME % name
+        self.value_field_name = VALUE_FIELD_NAME % name
 
         if not self.inherit_only:
             flag_field = BooleanField(default=True)
             flag_field.creation_counter = self.creation_counter
-
+            
             # Adjust the creation_counter
-            # cls.add_to_class(self.inherited_flag_name, flag_field)
-            flag_field.contribute_to_class(cls, self.inherited_flag_name)
+            # cls.add_to_class(self.inherit_flag_name, flag_field)
+            flag_field.contribute_to_class(cls, self.inherit_flag_name)
 
             signals.class_prepared.connect(
                 curry(self.add_value_field, name=name),
                 sender=cls,
                 weak=False
             )
-
+        
         setattr(cls, name, self)
         display_name = 'get_%s_display' % name
         setattr(cls, display_name, curry(self.get_field_display, name=name))
         getattr(cls, display_name).__dict__['short_description'] = name.replace('_', ' ')
 
-        if not hasattr(cls, 'FIELD_INHERITANCE_MAP'):
+        if not hasattr(cls, 'FIELD_INHERITANCE_MAP'): #TODO: test
             cls.FIELD_INHERITANCE_MAP = {}
 
-        cls.FIELD_INHERITANCE_MAP[name] = (self.parent_object_field_name, self.inherited_field_name_in_parent or name, self.sync)
-        if not self.sync:
-            signals.class_prepared.connect(self.patch_manager, sender=cls)
+        cls.FIELD_INHERITANCE_MAP[name] = (self.parent_object_field_name, self.inherited_field_name_in_parent or name)
+        signals.class_prepared.connect(self.patch_manager, sender=cls)
 
     def patch_manager(self, sender, **kwargs):
         if not hasattr(sender.objects, 'original_get_query_set'):
             _get_query_set = sender.objects.get_query_set
             def get_query_set(qs):
                 model = qs.model
-
+                
                 if hasattr(model, 'FIELD_INHERITANCE_REL'):
                     related = model.FIELD_INHERITANCE_REL
                 else:
                     related = set()
-                    for field, (parent, target_field, sync) in model.FIELD_INHERITANCE_MAP.iteritems():
-                        if not sync:
-                            chain = []
-                            find_in_parent(None, model, parent, target_field, validate=False, chain=chain)
-                            related.add('__'.join(chain))
+                    for field, (parent, target_field) in model.FIELD_INHERITANCE_MAP.iteritems():
+                        chain = []
+                        find_in_parent(model, parent, target_field, validate=False, chain=chain)
+                        related.add('__'.join(chain))
                     model.FIELD_INHERITANCE_REL = related
-
+                    
                 return _get_query_set().select_related(*related)
-
+            
             sender.objects.original_get_query_set = _get_query_set
             sender.objects.get_query_set = get_query_set.__get__(sender.objects)
             logger.debug("Patching %s's get_query_set method to slap a select_related on the returned qs.", sender.__name__)
 
-    def pre_save_hook(self, sender, **kwargs):
-        logger.debug("Entering pre_save_hook")
-        logger.debug(sender.REVERSE_FIELD_INHERITANCE_MAP)
-        for name in sender.REVERSE_FIELD_INHERITANCE_MAP:
-            logger.debug("Inheritance map: %s", name)
-        logger.debug("Exitting pre_save_hook")
-
     def add_value_field(self, sender, name=None, robust=True, **kwargs):
-        def contribute(cls, field):
-            if not hasattr(cls, 'REVERSE_FIELD_INHERITANCE_MAP'): #TODO: test
-                cls.REVERSE_FIELD_INHERITANCE_MAP = {}
-
-            cls.REVERSE_FIELD_INHERITANCE_MAP[self.inherited_field_name_in_parent or name] = sender
-            if self.sync:
-                signals.pre_save.connect(self.pre_save_hook, sender=cls)
+        def contribute(field):
+            if hasattr(field, '_choices') and field._choices:
+                self._choices = field._choices
 
             if isinstance(field, ReverseManyRelatedObjectsDescriptor):
                 field = field.field
-
-            logger.debug(
-                "Cloning field %s for class %s",
-                field, sender.__name__
-            )
-
-            # local value field
+            
             xfield = copy.deepcopy(field)
             xfield.blank = True
-            if isinstance(xfield, (ManyToManyField, ForeignKey)):
+            if isinstance(xfield, ManyToManyField):
                 xfield.rel.through = None
-                xfield.rel.related_name = '%s_inherited_set' % cls._meta.object_name
-
+                
             xfield.creation_counter = self.creation_counter
-            xfield.contribute_to_class(sender, INHERITED_VALUE_FIELD_NAME % name)
-
-            # inherited value field
-            yfield = copy.deepcopy(field)
-            yfield.blank = True
-            if isinstance(yfield, (ManyToManyField, ForeignKey)):
-                yfield.rel.through = None
-                yfield.rel.related_name = '%s_local_set' % cls._meta.object_name
-
-            yfield.creation_counter = self.creation_counter
-            yfield.contribute_to_class(sender, LOCAL_VALUE_FIELD_NAME % name)
-
-        result = find_in_parent(
-            None, sender,
+            xfield.contribute_to_class(sender, VALUE_FIELD_NAME % name)
+            
+        value_field = find_in_parent(
+            sender,
             self.parent_object_field_name,
             self.inherited_field_name_in_parent or name,
             robust and self.validate,
-            callback=contribute
+            callback = contribute
         )
-        if result:
-            parent, value_field = result
-            contribute(parent, value_field)
-
     def __get__(self, instance, instance_type=None):
-        if self.inherit_only or getattr(instance, self.inherited_flag_name):
-            if self.sync:
-                return getattr(instance, self.inherited_value_field_name, None)
-            else:
-                rel = getattr(instance, self.parent_object_field_name)
-                if rel:
-                    return getattr(rel, self.inherited_field_name_in_parent or self.name)
-        return getattr(instance, self.local_value_field_name, None)
+        if self.inherit_only or getattr(instance, self.inherit_flag_name):
+            rel = getattr(instance, self.parent_object_field_name)
+            if rel:
+                return getattr(rel, self.inherited_field_name_in_parent or self.name)
+        return getattr(instance, self.value_field_name, None)
 
     def __set__(self, instance, value):
         if self.inherit_only:
             rel = getattr(instance, self.parent_object_field_name)
             if rel:
                 parent_value = getattr(rel, self.inherited_field_name_in_parent or self.name)
-                setattr(instance, self.inherited_flag_name, value == parent_value)
+                setattr(instance, self.inherit_flag_name, value == parent_value)
             else:
-                setattr(instance, self.inherited_flag_name, False)
+                setattr(instance, self.inherit_flag_name, False)
         except ObjectDoesNotExist:
-            setattr(instance, self.inherited_flag_name, False)
-        setattr(instance, self.local_value_field_name, value)
+            setattr(instance, self.inherit_flag_name, False)
+        setattr(instance, self.value_field_name, value)
 
-def find_on_model(origin, model, field_name, validate=True, callback=None, chain=None):
+def find_on_model(model, field_name, validate=True, callback=None, chain=None):
     target_fields = [
         target for target in model._meta.fields
             if target.name == field_name
     if hasattr(model, 'FIELD_INHERITANCE_MAP'):
         MAP = getattr(model, 'FIELD_INHERITANCE_MAP')
         if field_name in MAP:
-            rel, field, sync = MAP[field_name]
-            result = find_in_parent(origin, model, rel, field, validate, callback, chain)
-            if result:
-                parent_model, field = result
-                return field
+            rel, field = MAP[field_name]
+            return find_in_parent(model, rel, field, validate, callback, chain)
         elif validate:
             raise TypeError("InheritedField: %s does not exist in %s." %
                                 (field_name, model))
         raise TypeError("InheritedField: %s does not exist in %s." %
                                     (field_name, model))
 
-def find_in_parent(origin, model_class, relation_name, field_name, validate=True, callback=None, chain=None):
+def find_in_parent(model_class, relation_name, field_name, validate=True, callback=None, chain=None):
     """
     This function will take a model class and search for the relation so that:
 
                     # hook
                     def resolve_related_class(xfield, xmodel, xcls):
                         xfield.rel.to = xmodel
-                        field_instance = find_on_model(origin or xmodel, xmodel, field_name, validate, callback, chain)
-                        callback(origin or xmodel, field_instance)
+                        field_instance = find_on_model(xmodel, field_name, validate, callback, chain)
+                        if field_instance:
+                            callback(field_instance)
 
                     add_lazy_relation(model_class, ifield, ifield.rel.to,
                                         resolve_related_class)
                     return
                 else:
-                    return (ifield.rel.to, find_on_model(origin or ifield.rel.to, ifield.rel.to, field_name, validate, callback, chain))
+                    return find_on_model(ifield.rel.to, field_name, validate, callback, chain)
             else:
                 if validate:
                     raise TypeError(
         raise TypeError("InheritedField: %s does not exist on %s." %
                         (relation_name, model_class))
 
-
 class InheritedFieldQuerySet(QuerySet):
     def is_inherited(self, parts):
         _parts = parts[:]
 
 class InheritedFieldManager(Manager):
     def get_query_set(self):
-        return InheritedFieldQuerySet(self.model, using=self._db)
+        return InheritedFieldQuerySet(self.model, using=self._db)