Commits

Anonymous committed 24453f5

importutils: add_imports handles from pkg import mod for pkg.mod.var

Comments (0)

Files changed (2)

rope/refactor/importutils/__init__.py

 
 def add_import(pycore, pymodule, module_name, name):
     imports = get_module_imports(pycore, pymodule)
+    from_import = FromImport(module_name, 0, [(name, None)])
+    candidates = [from_import]
+    if '.' in module_name:
+        pkg, mod = module_name.rsplit('.')
+        candidates.append(FromImport(pkg, 0, [(mod, None)]))
     normal_import = NormalImport([(module_name, None)])
-    from_import = FromImport(module_name, 0, [(name, None)])
-    visitor = actions.AddingVisitor(pycore, [from_import, normal_import])
+    candidates.append(normal_import)
+
+    visitor = actions.AddingVisitor(pycore, candidates)
     selected_import = normal_import
     for import_statement in imports.get_import_statements():
         if import_statement.accept(visitor):
             selected_import = visitor.import_info
             break
     imports.add_import(selected_import)
-    if isinstance(selected_import, NormalImport):
+    if selected_import == normal_import:
         imported_name = module_name + '.' + name
+    elif selected_import == from_import:
+        imported_name = name
     else:
-        imported_name = name
+        imported_name = mod + '.' + name
     return imports.get_changed_source(), imported_name

ropetest/refactor/importutilstest.py

         self.assertEquals('from mod2 import myvar\n', result)
         self.assertEquals('myvar', name)
 
+    def test_adding_import_when_siblings_are_imported(self):
+        self.mod2.write('var1 = None\nvar2 = None\n')
+        self.mod1.write('from mod2 import var1\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'mod2', 'var2')
+        self.assertEquals('from mod2 import var1, var2\n', result)
+        self.assertEquals('var2', name)
+
+    def test_adding_import_when_the_package_is_imported(self):
+        self.pkg.get_child('__init__.py').write('var1 = None\n')
+        self.mod3.write('var2 = None\n')
+        self.mod1.write('from pkg import var1\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'pkg.mod3', 'var2')
+        self.assertEquals('from pkg import var1, mod3\n', result)
+        self.assertEquals('mod3.var2', name)
+
 
 def suite():
     result = unittest.TestSuite()