Commits

Amaury Forgeot d'Arc  committed e3f11f2
  • Participants
  • Parent commits 7cfdee1

Comments (0)

Files changed (2)

File pypy/module/imp/importing.py

 def check_sys_modules_w(space, modulename):
     return space.finditem_str(space.sys.get('modules'), modulename)
 
+def _get_relative_name(space, modulename, level, w_globals):
+    w = space.wrap
+    ctxt_w_package = space.finditem(w_globals, w('__package__'))
+
+    ctxt_package = None
+    if ctxt_w_package is not None and ctxt_w_package is not space.w_None:
+        try:
+            ctxt_package = space.str_w(ctxt_w_package)
+        except OperationError, e:
+            if not e.match(space, space.w_TypeError):
+                raise
+            raise OperationError(space.w_ValueError, space.wrap(
+                "__package__ set to non-string"))
+
+    if ctxt_package is not None:
+        # __package__ is set, so use it
+        package_parts = ctxt_package.split('.')
+        while level > 1 and package_parts:
+            level -= 1
+            if package_parts:
+                package_parts.pop()
+        if not package_parts:
+            if len(ctxt_package) == 0:
+                msg = "Attempted relative import in non-package"
+            else:
+                msg = "Attempted relative import beyond toplevel package"
+            raise OperationError(space.w_ValueError, w(msg))
+
+        # Try to import parent package
+        try:
+            w_parent = absolute_import(space, ctxt_package, 0,
+                                       None, tentative=False)
+        except OperationError, e:
+            if not e.match(space, space.w_ImportError):
+                raise
+            if level > 0:
+                raise OperationError(space.w_SystemError, space.wrap(
+                    "Parent module '%s' not loaded, "
+                    "cannot perform relative import" % ctxt_package))
+            else:
+                space.warn("Parent module '%s' not found "
+                           "while handling absolute import" % ctxt_package,
+                           space.w_RuntimeWarning)
+
+        if modulename:
+            package_parts.append(modulename)
+        rel_level = len(package_parts)
+        rel_modulename = '.'.join(package_parts)
+    else:
+        # __package__ not set, so figure it out and set it
+        ctxt_w_name = space.finditem(w_globals, w('__name__'))
+        ctxt_w_path = space.finditem(w_globals, w('__path__'))
+
+        ctxt_name = None
+        if ctxt_w_name is not None:
+            try:
+                ctxt_name = space.str_w(ctxt_w_name)
+            except OperationError, e:
+                if not e.match(space, space.w_TypeError):
+                    raise
+
+        if not ctxt_name:
+            return None, 0
+
+        ctxt_name_prefix_parts = ctxt_name.split('.')
+        if level > 0:
+            n = len(ctxt_name_prefix_parts)-level+1
+            assert n>=0
+            ctxt_name_prefix_parts = ctxt_name_prefix_parts[:n]
+        if ctxt_name_prefix_parts and ctxt_w_path is None: # plain module
+            ctxt_name_prefix_parts.pop()
+
+        if level > 0 and not ctxt_name_prefix_parts:
+            msg = "Attempted relative import in non-package"
+            raise OperationError(space.w_ValueError, w(msg))
+
+        rel_modulename = '.'.join(ctxt_name_prefix_parts)
+
+        if ctxt_w_path is not None:
+            # __path__ is set, so __name__ is already the package name
+            space.setitem(w_globals, w("__package__"), ctxt_w_name)
+        else:
+            # Normal module, so work out the package name if any
+            if '.' not in ctxt_name:
+                space.setitem(w_globals, w("__package__"), space.w_None)
+            elif rel_modulename:
+                space.setitem(w_globals, w("__package__"), w(rel_modulename))
+
+        if modulename:
+            if rel_modulename:
+                rel_modulename += '.' + modulename
+            else:
+                rel_modulename = modulename
+
+        rel_level = len(ctxt_name_prefix_parts)
+
+    return rel_modulename, rel_level
+
+
 @unwrap_spec(name=str, level=int)
 def importhook(space, name, w_globals=None,
                w_locals=None, w_fromlist=None, level=-1):
         w_globals is not None and
         space.isinstance_w(w_globals, space.w_dict)):
 
-        ctxt_w_name = space.finditem(w_globals, w('__name__'))
-        ctxt_w_path = space.finditem(w_globals, w('__path__'))
+        rel_modulename, rel_level = _get_relative_name(space, modulename, level, w_globals)
 
-        ctxt_name = None
-        if ctxt_w_name is not None:
-            try:
-                ctxt_name = space.str_w(ctxt_w_name)
-            except OperationError, e:
-                if not e.match(space, space.w_TypeError):
-                    raise
+        if rel_modulename:
+            # if no level was set, ignore import errors, and
+            # fall back to absolute import at the end of the
+            # function.
+            if level == -1:
+                tentative = True
+            else:
+                tentative = False
 
-        if ctxt_name is not None:
-            ctxt_name_prefix_parts = ctxt_name.split('.')
-            if level > 0:
-                n = len(ctxt_name_prefix_parts)-level+1
-                assert n>=0
-                ctxt_name_prefix_parts = ctxt_name_prefix_parts[:n]
-            if ctxt_name_prefix_parts and ctxt_w_path is None: # plain module
-                ctxt_name_prefix_parts.pop()
-            if ctxt_name_prefix_parts:
-                rel_modulename = '.'.join(ctxt_name_prefix_parts)
-                if modulename:
-                    rel_modulename += '.' + modulename
-            baselevel = len(ctxt_name_prefix_parts)
-            if rel_modulename is not None:
-                # XXX What is this check about? There is no test for it
-                w_mod = check_sys_modules(space, w(rel_modulename))
+            w_mod = absolute_import(space, rel_modulename, rel_level,
+                                    fromlist_w, tentative=tentative)
+            if w_mod is not None:
+                space.timer.stop_name("importhook", modulename)
+                return w_mod
 
-                if (w_mod is None or
-                    not space.is_w(w_mod, space.w_None) or
-                    level > 0):
+        ## if level > 0:
+        ##     msg = "Attempted relative import in non-package"
+        ##     raise OperationError(space.w_ValueError, w(msg))
 
-                    # if no level was set, ignore import errors, and
-                    # fall back to absolute import at the end of the
-                    # function.
-                    if level == -1:
-                        tentative = True
-                    else:
-                        tentative = False
+        ## if not modulename:
+        ##     return None
 
-                    w_mod = absolute_import(space, rel_modulename,
-                                            baselevel, fromlist_w,
-                                            tentative=tentative)
-                    if w_mod is not None:
-                        space.timer.stop_name("importhook", modulename)
-                        return w_mod
-                else:
-                    rel_modulename = None
-
-        if level > 0:
-            msg = "Attempted relative import in non-package"
-            raise OperationError(space.w_ValueError, w(msg))
     w_mod = absolute_import_try(space, modulename, 0, fromlist_w)
     if w_mod is None or space.is_w(w_mod, space.w_None):
         w_mod = absolute_import(space, modulename, 0, fromlist_w, tentative=0)

File pypy/module/imp/test/test_import.py

         res = __import__('', mydict, None, ['bar'], 2)
         assert res is pkg
 
+    def test__package__(self):
+        # Regression test for http://bugs.python.org/issue3221.
+        def check_absolute():
+            exec "from os import path" in ns
+        def check_relative():
+            exec "from . import a" in ns
+
+        # Check both OK with __package__ and __name__ correct
+        ns = dict(__package__='pkg', __name__='pkg.notarealmodule')
+        check_absolute()
+        check_relative()
+
+        # Check both OK with only __name__ wrong
+        ns = dict(__package__='pkg', __name__='notarealpkg.notarealmodule')
+        check_absolute()
+        check_relative()
+
+        # Check relative fails with only __package__ wrong
+        ns = dict(__package__='foo', __name__='pkg.notarealmodule')
+        check_absolute() # XXX check warnings
+        raises(SystemError, check_relative)
+
+        # Check relative fails with __package__ and __name__ wrong
+        ns = dict(__package__='foo', __name__='notarealpkg.notarealmodule')
+        check_absolute() # XXX check warnings
+        raises(SystemError, check_relative)
+
+        # Check both fail with package set to a non-string
+        ns = dict(__package__=object())
+        raises(ValueError, check_absolute)
+        raises(ValueError, check_relative)
+
     def test_universal_newlines(self):
         import pkg_univnewlines
         assert pkg_univnewlines.a == 5