Commits

Anonymous committed 3bd3603

importutils: add_import handles modules, too

  • Participants
  • Parent commits 8408fba

Comments (0)

Files changed (2)

rope/refactor/importutils/__init__.py

     return module_imports.ModuleImports(pycore, pymodule)
 
 
-def add_import(pycore, pymodule, module_name, name):
+def add_import(pycore, pymodule, module_name, name=None):
     imports = get_module_imports(pycore, pymodule)
-    from_import = FromImport(module_name, 0, [(name, None)])
-    candidates = [from_import]
+    candidates = []
+    names = []
+    # from mod import name
+    if name is not None:
+        from_import = FromImport(module_name, 0, [(name, None)])
+        names.append(name)
+        candidates.append(from_import)
+    # from pkg import mod
     if '.' in module_name:
         pkg, mod = module_name.rsplit('.', 1)
         candidates.append(FromImport(pkg, 0, [(mod, None)]))
+        if name:
+            names.append(mod + '.' + name)
+        else:
+            names.append(mod)
+    # import mod
     normal_import = NormalImport([(module_name, None)])
+    if name:
+        names.append(module_name + '.' + name)
+    else:
+        names.append(module_name)
+
     candidates.append(normal_import)
 
     visitor = actions.AddingVisitor(pycore, candidates)
             selected_import = visitor.import_info
             break
     imports.add_import(selected_import)
-    if selected_import == normal_import:
-        imported_name = module_name + '.' + name
-    elif selected_import == from_import:
-        imported_name = name
-    else:
-        imported_name = mod + '.' + name
+    imported_name = names[candidates.index(selected_import)]
     return imports.get_changed_source(), imported_name

ropetest/refactor/importutilstest.py

         self.assertEquals('from pkg import var1, mod3\n', result)
         self.assertEquals('mod3.var2', name)
 
+    def test_adding_import_for_modules_instead_of_names(self):
+        self.pkg.get_child('__init__.py').write('var1 = None\n')
+        self.mod3.write('\n')
+        self.mod1.write('from pkg import var1\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'pkg.mod3', None)
+        self.assertEquals('from pkg import var1, mod3\n', result)
+        self.assertEquals('mod3', name)
+
+    def test_adding_import_for_modules_with_normal_duplicate_imports(self):
+        self.pkg.get_child('__init__.py').write('var1 = None\n')
+        self.mod3.write('\n')
+        self.mod1.write('import pkg.mod3\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'pkg.mod3', None)
+        self.assertEquals('import pkg.mod3\n', result)
+        self.assertEquals('pkg.mod3', name)
+
 
 def suite():
     result = unittest.TestSuite()