Source

pytest-patches / asser-reloader-package

# HG changeset patch
# Parent 6796a80b42499aecc47e952e9f195c3274199997
diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py
--- a/_pytest/assertion/rewrite.py
+++ b/_pytest/assertion/rewrite.py
@@ -153,12 +153,30 @@ class AssertionRewritingHook(object):
             mod.__file__ = co.co_filename
             # Normally, this attribute is 3.2+.
             mod.__cached__ = pyc
+            mod.__loader__ = self
             py.builtin.exec_(co, mod.__dict__)
         except:
             del sys.modules[name]
             raise
         return sys.modules[name]
 
+    def is_package(self, name):
+        """Return whether ``name`` is a package; None if it is not found."""
+        if name in sys.modules:
+            return getattr(sys.modules[name], '__path__', None) is not None
+        if name in self.modules:
+            # presumably (code, pyc); rewritten modules always come from files
+            fname = self.modules[name][0].co_filename
+            return os.path.basename(fname).startswith('__init__.')
+        # imp.find_module handles top-level (undotted) names only.
+        try:
+            fd, _, desc = imp.find_module(name)
+        except ImportError:
+            return None
+        if fd is not None:
+            fd.close()
+        return desc[2] == imp.PKG_DIRECTORY  # type, not path: safe for builtins
+
 def _write_pyc(co, source_path, pyc):
     # Technically, we don't have to have the same pyc format as (C)Python, since
     # these "pycs" should never be seen by builtin import. However, there's
diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py
--- a/testing/test_assertrewrite.py
+++ b/testing/test_assertrewrite.py
@@ -374,3 +374,33 @@ def test_rewritten():
         b = content.encode("utf-8")
         testdir.tmpdir.join("test_newlines.py").write(b, "wb")
         assert testdir.runpytest().ret == 0
+
+
+class TestAssertionRewriteHookDetails(object):
+    def test_loader_is_package_false_for_module(self, testdir):
+        testdir.makepyfile(test_fun="""
+            def test_loader():
+                assert not __loader__.is_package(__name__)
+            """)
+        result = testdir.runpytest()
+        result.stdout.fnmatch_lines([
+            "* 1 passed*",
+        ])
+
+    def test_loader_is_package_true_for_package(self, testdir):
+        testdir.makepyfile(test_fun="""
+            def test_loader():
+                assert not __loader__.is_package(__name__)
+
+            def test_fun():
+                assert __loader__.is_package('fun')
+
+            def test_missing():
+                assert not __loader__.is_package('pytest_not_there')
+            """)
+        testdir.mkpydir('fun')  # package dir (with __init__.py) for test_fun
+        result = testdir.runpytest()
+        result.stdout.fnmatch_lines([
+            "* 3 passed*",
+        ])
+