Commits

David Stanek committed c14c184

added initial support for multibinders

Comments (0)

Files changed (4)

snakeguice/errors.py

 
 class AssistError(SnakeGuiceError):
     """Raised when an issue with assisted injection is found."""
+
+
+class MultiBindingError(SnakeGuiceError):
+    """Raised when a issue with multi-binding is found."""

snakeguice/injector.py

         binding = self.get_binding(key)
         if binding:
             provider = binding.scope.scope(key, binding.provider)
-            provider_inst = self.get_instance(provider)
-            return provider_inst.get()
+            if isinstance(provider, type):
+                provider = self.get_instance(provider)
+            return provider.get()
         else:
             return self.create_object(cls)
 

snakeguice/multibinder.py

+from snakeguice import providers
+from snakeguice.binder import Key
+from snakeguice.decorators import inject
+from snakeguice.interfaces import Injector
+from snakeguice.errors import MultiBindingError
+
+
+class _Hashable(object):
+
+    def __init__(self, interface):
+        self._interface = interface
+
+    def __hash__(self):
+        return id(self._interface)
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+
+class _MultiBinder(object):
+
+    def __init__(self, binder, interface):
+        self._binder = binder
+        self._interface = interface
+        self._provider = self._get_or_create_provider()
+
+    def _get_or_create_provider(self):
+        key = Key(self.multibinding_type(self._interface))
+        binding = self._binder.get_binding(key)
+        if not binding:
+            self._binder.bind(self.multibinding_type(self._interface),
+                    to_provider=self._create_provider())
+            binding = self._binder.get_binding(key)
+        return binding.provider
+
+    def _dsl_to_provider(self, to, to_provider, to_instance):
+        if to:
+            #TODO: add some validation
+            return providers.create_simple_provider(to)
+        elif to_provider:
+            #TODO: add some validation
+            return to_provider
+        elif to_instance:
+            #TODO: add some validation
+            return providers.create_instance_provider(to_instance)
+        else:
+            raise MultiBindingError('incorrect arguments to %s.add_binding'
+                    % self.__class__.__name__)
+
+
+class List(_Hashable):
+    """Used for binding lists."""
+
+
+class ListBinder(_MultiBinder):
+
+    multibinding_type = List
+
+    def add_binding(self, to=None, to_provider=None, to_instance=None):
+        provider = self._dsl_to_provider(to, to_provider, to_instance)
+        self._provider.add_provider(provider)
+
+    def _create_provider(self):
+        class DynamicMultiBindingProvider(object):
+            providers = []
+
+            @inject(injector=Injector)
+            def __init__(self, injector):
+                self._injector = injector
+
+            @classmethod
+            def add_provider(cls, provider):
+                cls.providers.append(provider)
+
+            def get(self):
+                return [self._injector.get_instance(p).get()
+                        for p in self.providers]
+
+        return DynamicMultiBindingProvider
+
+
+class Dict(_Hashable):
+    """Used for binding lists."""
+
+
+class DictBinder(_MultiBinder):
+
+    multibinding_type = Dict
+
+    def add_binding(self, key, to=None, to_provider=None, to_instance=None):
+        provider = self._dsl_to_provider(to, to_provider, to_instance)
+        self._provider.add_provider(key, provider)
+
+    def _create_provider(self):
+        binder_self = self
+
+        class DynamicMultiBindingProvider(object):
+            providers = {}
+
+            @inject(injector=Injector)
+            def __init__(self, injector):
+                self._injector = injector
+
+            @classmethod
+            def add_provider(cls, key, provider):
+                if key in cls.providers:
+                    msg = ('duplicate binding for %r in Dict(%s) found'
+                            % (key, binder_self.interface.__class__.__name__))
+                    raise MultiBindingError(msg)
+                cls.providers[key] = provider
+
+            def get(self):
+                return dict([(k, self._injector.get_instance(p).get())
+                        for k, p in self.providers.items()])
+
+        return DynamicMultiBindingProvider
+
+

tests/system/test_multibinder.py

+from snakeguice import create_injector, inject
+from snakeguice.interfaces import Injector
+from snakeguice.multibinder import ListBinder, List, DictBinder, Dict
+from snakeguice import providers
+
+
+class ISnack(object):
+    """A snack interface."""
+
+
+class Twix(object):
+    """A concrete snack implementation."""
+
+
+class Snickers(object):
+    """A concrete snack implementation."""
+
+
+class Skittles(object):
+    """A concrete snack implementation."""
+
+
+class Lays(object):
+    """A concrete snack implementation."""
+
+
+class Tostitos(object):
+    """A concrete snack implementation."""
+
+
+class Ruffles(object):
+    """A concrete snack implementation."""
+
+
+class ListCandyModule(object):
+    """One to two modules adding to the multibinder."""
+
+    def configure(self, binder):
+        listbinder = ListBinder(binder, ISnack)
+        listbinder.add_binding(to=Twix)
+        provider = providers.create_simple_provider(Snickers)
+        listbinder.add_binding(to_provider=provider)
+        listbinder.add_binding(to_instance=Skittles())
+
+
+class ListChipsModule(object):
+    """One to two modules adding to the multibinder."""
+
+    def configure(self, binder):
+        listbinder = ListBinder(binder, ISnack)
+        listbinder.add_binding(to=Lays)
+        provider = providers.create_simple_provider(Tostitos)
+        listbinder.add_binding(to_provider=provider)
+        listbinder.add_binding(to_instance=Ruffles())
+
+
+class ListSnackMachine(object):
+
+    @inject(snacks=List(ISnack))
+    def __init__(self, snacks):
+        self.snacks = snacks
+
+
+class DictCandyModule(object):
+    """One to two modules adding to the multibinder."""
+
+    def configure(self, binder):
+        dictbinder = DictBinder(binder, ISnack)
+        dictbinder.add_binding('twix', to=Twix)
+        provider = providers.create_simple_provider(Snickers)
+        dictbinder.add_binding('snickers', to_provider=provider)
+        dictbinder.add_binding('skittles', to_instance=Skittles())
+
+
+class DictChipsModule(object):
+    """One to two modules adding to the multibinder."""
+
+    def configure(self, binder):
+        dictbinder = DictBinder(binder, ISnack)
+        dictbinder.add_binding('lays', to=Lays)
+        provider = providers.create_simple_provider(Tostitos)
+        dictbinder.add_binding('tostitos', to_provider=provider)
+        dictbinder.add_binding('ruffles', to_instance=Ruffles())
+
+
+class DictSnackMachine(object):
+
+    @inject(snacks=Dict(ISnack))
+    def __init__(self, snacks):
+        self.snacks = snacks
+
+
+SNACK_CLASSES = (Twix, Snickers, Skittles, Lays, Tostitos, Ruffles)
+
+
+class base_multibinder(object):
+
+    def test_that_the_injected_value_has_the_correct_number_of_elements(self):
+        assert len(self.snack_machine.snacks) == len(SNACK_CLASSES)
+
+
+class test_using_ListBinder(base_multibinder):
+
+    def setup(self):
+        injector = create_injector([ListCandyModule(), ListChipsModule()])
+        self.snack_machine = injector.get_instance(ListSnackMachine)
+
+    def test_that_the_elements_have_the_correct_type(self):
+        for n, snack in enumerate(self.snack_machine.snacks):
+            assert isinstance(snack, SNACK_CLASSES[n])
+
+
+class test_using_DictBinder(base_multibinder):
+
+    def setup(self):
+        injector = create_injector([DictCandyModule(), DictChipsModule()])
+        self.snack_machine = injector.get_instance(DictSnackMachine)
+
+    def test_that_the_elements_have_the_correct_type(self):
+        for k, v in self.snack_machine.snacks.items():
+            assert k == v.__class__.__name__.lower()
+            assert v.__class__ in SNACK_CLASSES