# dingus/dingus.py

# This library was written by Gary Bernhardt and is licensed under the BSD
# license. It grew out of an old version of the mock library by Michael Foord
# available at http://www.voidspace.org.uk/python/mock.html.


import sys
from functools import wraps

def DingusTestCase(object_under_test, exclude=None):
    """Build a test-case base class that isolates ``object_under_test``.

    The returned class's ``setup`` hook replaces every non-dunder global of
    the module that defines ``object_under_test`` with a fresh ``Dingus``.
    Two kinds of names are left alone: names bound to ``object_under_test``
    itself, and names listed in ``exclude``. The ``teardown`` hook restores
    the module's original globals. (``setup``/``teardown`` are presumably
    invoked by a nose/pytest-style runner -- confirm for your runner.)

    Args:
        object_under_test: a function/class whose ``__module__`` is loaded
            in ``sys.modules``.
        exclude: optional container of global names to keep un-replaced.

    Returns:
        A new class named ``<names>_DingusTestCase``, where ``<names>`` are
        the preserved names joined with ``_``.
    """
    exclude = [] if exclude is None else exclude

    def get_names_under_test():
        module = sys.modules[object_under_test.__module__]
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2
        # only. ``list()`` snapshots the entries before yielding.
        for name, value in list(module.__dict__.items()):
            if value is object_under_test or name in exclude:
                yield name

    class TestCase(object):
        def setup(self):
            module_name = object_under_test.__module__
            self._dingus_module = sys.modules[module_name]
            self._dingus_replace_module_globals(self._dingus_module)

        def teardown(self):
            self._dingus_restore_module(self._dingus_module)

        def _dingus_replace_module_globals(self, module):
            # Snapshot the globals first so teardown can put them all back.
            old_module_dict = module.__dict__.copy()
            module_keys = set(module.__dict__)

            # Dunder globals (__name__, __builtins__, ...) are never replaced.
            dunders = set(k for k in module_keys
                          if k.startswith('__') and k.endswith('__'))
            replaced_keys = module_keys - dunders - set(names_under_test)
            for key in replaced_keys:
                module.__dict__[key] = Dingus()
            # The snapshot is stashed in the module itself for teardown.
            module.__dict__['__dingused_dict__'] = old_module_dict

        def _dingus_restore_module(self, module):
            old_module_dict = module.__dict__['__dingused_dict__']
            # clear() also drops the '__dingused_dict__' stash.
            module.__dict__.clear()
            module.__dict__.update(old_module_dict)

    # Resolved after the class body; the methods above read it via closure.
    names_under_test = list(get_names_under_test())
    TestCase.__name__ = '%s_DingusTestCase' % '_'.join(names_under_test)
    return TestCase


# These sentinels are used for argument defaults because the user might want
# to pass in None, which in some cases means something different from passing
# nothing at all.
class NoReturnValue(object):
    """Sentinel class: no return value has been configured yet."""


class NoArgument(object):
    """Sentinel class: the caller did not supply this argument."""


def patch(object_path, new_object=NoArgument):
    """Return a patcher that replaces the attribute at ``object_path``.

    ``object_path`` is a dotted path such as ``'package.module.attr'``.
    Everything before the last dot names the object to import/look up, and
    the last component is the attribute to replace. When ``new_object`` is
    omitted, a fresh ``Dingus`` is used. The result works as a decorator or
    as a context manager.

    Raises:
        ValueError: if ``object_path`` contains no ``'.'``.
    """
    if '.' not in object_path:
        # Without this check the failure is an opaque tuple-unpacking error.
        raise ValueError(
            "patch() needs a dotted path like 'module.attribute', got %r"
            % (object_path,))
    module_name, attribute_name = object_path.rsplit('.', 1)
    return _Patcher(module_name, attribute_name, new_object)


class _Patcher:
    def __init__(self, module_name, attribute_name, new_object):
        self.module_name = module_name
        self.attribute_name = attribute_name
        self.module = _importer(self.module_name)
        self.new_object = Dingus() if new_object is NoArgument else new_object

    def __call__(self, fn):
        @wraps(fn)
        def new_fn(*args, **kwargs):
            self.patch_object()
            try:
                return fn(*args, **kwargs)
            finally:
                self.restore_object()
        return new_fn

    def __enter__(self):
        self.patch_object()

    def __exit__(self, exc_type, exc_value, traceback):
        self.restore_object()

    def patch_object(self):
        self.original_object = getattr(self.module, self.attribute_name)
        setattr(self.module, self.attribute_name, self.new_object)

    def restore_object(self):
        setattr(self.module, self.attribute_name, self.original_object)


def _importer(target):
    components = target.split('.')
    import_path = components.pop(0)
    thing = __import__(import_path)

    for comp in components:
        import_path += ".%s" % comp
        thing = _dot_lookup(thing, comp, import_path)
    return thing


def _dot_lookup(thing, comp, import_path):
    try:
        return getattr(thing, comp)
    except AttributeError:
        __import__(import_path)
        return getattr(thing, comp)


class DontCare(object):
    """Wildcard for call filtering: matches any single argument value.

    The class object itself (not an instance) is what ``CallList`` filters
    compare against.
    """


class Call(tuple):
    """One logged call: a 4-tuple ``(name, args, kwargs, return_value)``.

    The fields are readable both by index and as attributes of the same names.
    """

    def __new__(cls, name, args, kwargs, return_value):
        fields = (name, args, kwargs, return_value)
        return tuple.__new__(cls, fields)

    def __init__(self, *args):
        # __new__ already stored the values; mirror them as attributes.
        self.name, self.args, self.kwargs, self.return_value = self

    def __getnewargs__(self):
        # Arguments that pickle/copy pass back to __new__ when rebuilding.
        return self.name, self.args, self.kwargs, self.return_value


class CallList(list):
    """A list of ``Call`` records that can be filtered by calling it.

    ``calls('name', arg, key=val)`` returns a new ``CallList`` containing only
    the matching calls; see ``__call__``.
    """

    @staticmethod
    def _match_args(call, args):
        # An empty positional filter matches any arguments.
        if not args:
            return True
        if len(args) != len(call.args):
            return False
        # ``DontCare`` in the filter matches whatever was passed there.
        return all(expected in (DontCare, actual)
                   for expected, actual in zip(args, call.args))

    @staticmethod
    def _match_kwargs(call, kwargs):
        # An empty keyword filter matches any keyword arguments.
        if not kwargs:
            return True
        if len(kwargs) != len(call.kwargs):
            return False
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2
        # only.
        return all(name in kwargs and kwargs[name] in (DontCare, value)
                   for name, value in call.kwargs.items())

    def one(self):
        """Return the only call if there is exactly one, else None."""
        if len(self) == 1:
            return self[0]
        else:
            return None

    def once(self):
        """Alias of ``one``."""
        return self.one()

    def __call__(self, __name=NoArgument, *args, **kwargs):
        """Return the calls that match every given filter, as a new CallList.

        The optional first positional argument must equal the call's name.
        Positional ``args`` and keyword ``kwargs``, when given, must match the
        call's arguments in count and value, with ``DontCare`` matching any
        single value.
        """
        return CallList([call for call in self
                         if (__name is NoArgument or __name == call.name)
                         and self._match_args(call, args)
                         and self._match_kwargs(call, kwargs)])


def returner(return_value):
    """Create a ``Dingus`` that returns ``return_value`` whenever it is called."""
    fake = Dingus(return_value=return_value)
    return fake


class Dingus(object):
    """A permissive stand-in object that records how it is used.

    Reading an attribute that has not been set, indexing, calling, or
    applying an infix operator (``+``, ``|``, ``<<``, ...) to a Dingus all
    succeed. By default each produces a child Dingus whose full name shows
    how it was reached, e.g. ``<Dingus foo.bar()[0]>``. Augmented operators
    (``+=``, ...) return the Dingus itself.

    The following are appended to ``self.calls`` as ``Call`` records:
    calls, item reads/writes, attribute deletions, and operator uses. A call
    on a child is also logged on its parent under the child's name.

    Constructor keyword arguments become attributes. The exception is
    ``name__returns=value``, which creates a child ``name`` whose calls
    return ``value``.
    """

    def __init__(self, dingus_name=None, full_name=None, **kwargs):
        self._parent = None
        self.reset()
        name = 'dingus_%i' % id(self) if dingus_name is None else dingus_name
        full_name = name if full_name is None else full_name
        self._short_name = name
        self._full_name = full_name
        self.__name__ = name

        suffix = '__returns'
        # ``items()`` works on Python 2 and 3; ``iteritems()`` is Python 2
        # only.
        for attr_name, attr_value in kwargs.items():
            if attr_name.endswith(suffix):
                # Strip only the trailing suffix (str.replace would remove
                # every occurrence of it).
                attr_name = attr_name[:-len(suffix)]
                child = self._create_child(attr_name)
                child.return_value = attr_value
                setattr(self, attr_name, child)
            else:
                setattr(self, attr_name, attr_value)

        self._replace_init_method()

    @classmethod
    def many(cls, count):
        """Return a tuple of ``count`` fresh, independent instances."""
        return tuple(cls() for _ in range(count))

    def _fake_init(self, *args, **kwargs):
        # Explicit ``dingus.__init__(...)`` calls are routed to (and recorded
        # by) the ``__init__`` child instead of re-initializing the object.
        return self.__getattr__('__init__')(*args, **kwargs)

    def _replace_init_method(self):
        # An instance attribute shadows the class __init__ for explicit calls;
        # construction via Dingus(...) still uses the class method.
        self.__init__ = self._fake_init

    def _create_child(self, name):
        # Call ("()") and item ("[...]") children are appended without a dot,
        # so full names read like Python expressions, e.g. foo.bar()[0].
        separator = ('' if (name.startswith('()') or name.startswith('['))
                     else '.')
        full_name = self._full_name + separator + name
        child = self.__class__(name, full_name)
        child._parent = self
        return child

    def reset(self):
        """Forget the configured return value, recorded calls and children."""
        self._return_value = NoReturnValue
        self.calls = CallList()
        self._children = {}

    def _get_return_value(self):
        # Lazily create the "()" child the first time it is needed.
        if self._return_value is NoReturnValue:
            self._return_value = self._create_child('()')
        return self._return_value

    def _set_return_value(self, value):
        self._return_value = value

    return_value = property(_get_return_value, _set_return_value)

    def __call__(self, *args, **kwargs):
        return_value = self.return_value
        self._log_call('()', args, kwargs, return_value)
        # Test identity, not truthiness: truthiness would go through __len__.
        if self._parent is not None:
            self._parent._log_call(self._short_name,
                                   args,
                                   kwargs,
                                   return_value)

        return return_value

    def _log_call(self, name, args, kwargs, return_value):
        self.calls.append(Call(name, args, kwargs, return_value))

    def _should_ignore_attribute(self, name):
        # These names raise AttributeError instead of auto-creating a child.
        # __getnewargs__ is part of the pickle protocol; __pyobjc_object__ is
        # presumably probed by the PyObjC bridge.
        return name in ['__pyobjc_object__', '__getnewargs__']

    def __getstate__(self):
        # Python 2's pickle cannot serialize an instancemethod
        # (http://bugs.python.org/issue558238), so drop the bound method
        # stored in the instance's __init__; __setstate__ recreates it.
        return [(attr, value) for attr, value in self.__dict__.items()
                if attr != "__init__"]

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._replace_init_method()

    def _existing_or_new_child(self, child_name, default_value=NoArgument):
        # Memoize children so repeated access returns the same object.
        if child_name not in self._children:
            value = (self._create_child(child_name)
                     if default_value is NoArgument
                     else default_value)
            self._children[child_name] = value

        return self._children[child_name]

    def _remove_child_if_exists(self, child_name):
        if child_name in self._children:
            del self._children[child_name]

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        if self._should_ignore_attribute(name):
            raise AttributeError(name)
        return self._existing_or_new_child(name)

    def __delattr__(self, name):
        # Deletion is only recorded; the attribute is not removed.
        self._log_call('__delattr__', (name,), {}, None)

    def __getitem__(self, index):
        child_name = '[%s]' % (index,)
        return_value = self._existing_or_new_child(child_name)
        self._log_call('__getitem__', (index,), {}, return_value)
        return return_value

    def __setitem__(self, index, value):
        child_name = '[%s]' % (index,)
        self._log_call('__setitem__', (index, value), {}, None)
        # Replace any existing child so later reads return ``value``.
        self._remove_child_if_exists(child_name)
        self._existing_or_new_child(child_name, value)

    def _create_infix_operator(name):
        def operator_fn(self, other):
            return_value = self._existing_or_new_child(name)
            self._log_call(name, (other,), {}, return_value)
            return return_value
        operator_fn.__name__ = name
        return operator_fn

    # 'truediv' and 'floordiv' back ``/`` (Python 3, or Python 2 with
    # ``from __future__ import division``) and ``//``, which ``__div__`` alone
    # does not cover.
    _BASE_OPERATOR_NAMES = ['add', 'and', 'div', 'floordiv', 'lshift', 'mod',
                            'mul', 'or', 'pow', 'rshift', 'sub', 'truediv',
                            'xor']

    def _infix_operator_names(base_operator_names):
        # This function has to have base_operator_names passed in because
        # Python's scoping rules prevent it from seeing the class-level
        # _BASE_OPERATOR_NAMES.

        reverse_operator_names = ['r%s' % name for name in base_operator_names]
        for operator_name in base_operator_names + reverse_operator_names:
            operator_fn_name = '__%s__' % operator_name
            yield operator_fn_name

    # Define each infix operator
    for operator_fn_name in _infix_operator_names(_BASE_OPERATOR_NAMES):
        exec('%s = _create_infix_operator("%s")' % (operator_fn_name,
                                                    operator_fn_name))

    def _augmented_operator_names(base_operator_names):
        # Augmented operators are things like +=. They behave differently
        # from normal infix operators because they return self instead of a
        # new object.

        return ['__i%s__' % operator_name
                for operator_name in base_operator_names]

    def _create_augmented_operator(name):
        def operator_fn(self, other):
            return_value = self
            self._log_call(name, (other,), {}, return_value)
            return return_value
        operator_fn.__name__ = name
        return operator_fn

    # Define each augmenting operator
    for operator_fn_name in _augmented_operator_names(_BASE_OPERATOR_NAMES):
        exec('%s = _create_augmented_operator("%s")' % (operator_fn_name,
                                                        operator_fn_name))

    # Otherwise the loop variable would remain a class attribute, so
    # ``dingus.operator_fn_name`` would be a string instead of a child Dingus.
    del operator_fn_name

    def __str__(self):
        return '<Dingus %s>' % self._full_name
    __repr__ = __str__

    def __len__(self):
        return 1

    def __iter__(self):
        # Yields exactly one item (the '__iter__' child), matching __len__.
        return iter([self._existing_or_new_child('__iter__')])


def exception_raiser(exception):
    """Build a function that raises ``exception`` every time it is called.

    The returned function accepts and ignores any positional and keyword
    arguments.
    """
    def raise_exception(*args, **kwargs):
        # Arguments are accepted only so that any call signature works.
        raise exception

    return raise_exception
# Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
# Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
# Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
# Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
# Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
# Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
# Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.