Anonymous avatar Anonymous committed a3ecbde

importutils: added add_import for better import additions

Comments (0)

Files changed (4)

rope/refactor/importutils/__init__.py

 
 
 def get_module_imports(pycore, pymodule):
-    """A shortcut creating a `module_imports.ModuleImports` object"""
+    """A shortcut for creating a `module_imports.ModuleImports` object"""
     return module_imports.ModuleImports(pycore, pymodule)
+
+
+def add_import(pycore, pymodule, module_name, name):
+    imports = get_module_imports(pycore, pymodule)
+    normal_import = NormalImport([(module_name, None)])
+    from_import = FromImport(module_name, 0, [(name, None)])
+    visitor = actions.AddingVisitor(pycore, [from_import, normal_import])
+    selected_import = normal_import
+    for import_statement in imports.get_import_statements():
+        if import_statement.accept(visitor):
+            selected_import = visitor.import_info
+            break
+    imports.add_import(selected_import)
+    if isinstance(selected_import, NormalImport):
+        imported_name = module_name + '.' + name
+    else:
+        imported_name = name
+    return imports.get_changed_source(), imported_name

rope/refactor/importutils/actions.py

 
 
 class AddingVisitor(ImportInfoVisitor):
+    """A class for adding imports
 
-    def __init__(self, pycore, import_info):
+    Given a list of `ImportInfo`\s, it tries to add each import to the
+    module and returns `True` and gives up when an import can be added
+    to older ones.
+
+    """
+
+    def __init__(self, pycore, import_list):
         self.pycore = pycore
-        self.import_info = import_info
+        self.import_list = import_list
+        self.import_info = None
+
+    def dispatch(self, import_):
+        for import_info in self.import_list:
+            self.import_info = import_info
+            if ImportInfoVisitor.dispatch(self, import_):
+                return True
 
     # TODO: Handle adding relative and absolute imports
     def visitNormalImport(self, import_stmt, import_info):

rope/refactor/importutils/module_imports.py

         return result
 
     def add_import(self, import_info):
-        visitor = actions.AddingVisitor(self.pycore, import_info)
+        visitor = actions.AddingVisitor(self.pycore, [import_info])
         for import_statement in self.get_import_statements():
             if import_statement.accept(visitor):
                 break
         added_imports = []
         for import_stmt in imports:
             visitor = actions.AddingVisitor(self.pycore,
-                                            import_stmt.import_info)
+                                            [import_stmt.import_info])
             for added_import in added_imports:
                 if added_import.accept(visitor):
                     import_stmt.empty_import()

ropetest/refactor/importutilstest.py

 import unittest
 
-from rope.refactor.importutils import ImportTools, importinfo
+from rope.refactor.importutils import ImportTools, importinfo, add_import
 from ropetest import testutils
 
 
         self.assertEquals(1, len(imports))
 
 
+class AddImportTest(unittest.TestCase):
+
+    def setUp(self):
+        super(AddImportTest, self).setUp()
+        self.project = testutils.sample_project()
+        self.pycore = self.project.get_pycore()
+
+        self.mod1 = testutils.create_module(self.project, 'mod1')
+        self.mod2 = testutils.create_module(self.project, 'mod2')
+        self.pkg = testutils.create_package(self.project, 'pkg')
+        self.mod3 = testutils.create_module(self.project, 'mod3', self.pkg)
+
+    def tearDown(self):
+        testutils.remove_project(self.project)
+        super(AddImportTest, self).tearDown()
+
+    def test_normal_imports(self):
+        self.mod2.write('myvar = None\n')
+        self.mod1.write('\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'mod2', 'myvar')
+        self.assertEquals('import mod2\n', result)
+        self.assertEquals('mod2.myvar', name)
+
+    def test_not_reimporting_a_name(self):
+        self.mod2.write('myvar = None\n')
+        self.mod1.write('from mod2 import myvar\n')
+        pymod = self.pycore.get_module('mod1')
+        result, name = add_import(self.pycore, pymod, 'mod2', 'myvar')
+        self.assertEquals('from mod2 import myvar\n', result)
+        self.assertEquals('myvar', name)
+
+
+def suite():
+    result = unittest.TestSuite()
+    result.addTests(unittest.makeSuite(ImportUtilsTest))
+    result.addTests(unittest.makeSuite(AddImportTest))
+    return result
+
 if __name__ == '__main__':
     unittest.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.