Lenard Lindstrom avatar Lenard Lindstrom committed 91a0d35

Add program to list Pygame modules accessed by a unit test module

Comments (0)

Files changed (6)

+# Program check_test.py
+# Requires Python 2.4
+
+"""A program for listing the modules accessed by a Pygame unit test module
+
+Usage:
+
+python check_test.py <test module>
+
+e.g.
+
+python check_test.py surface_test.py
+
+The returned list will show which Pygame modules were imported and accessed.
+Each module name is followed by a list of attributes accessed.
+
+"""
+
+import sys
+import os
+import trackmod
+trackmod.begin(pattern=['pygame', 'pygame.*'],
+               continuous=True,
+               submodule_accesses=False)
+skip = set(['pygame.locals', 'pygame.constants',
+            'pygame.base', 'pygame.threads'])
+
+sys.path.append('.')
+os.chdir('test')
+test_file = sys.argv[1]
+del sys.argv[1]
+try:
+    execfile(test_file)
+finally:
+    trackmod.end()
+    print "=== Pygame package submodule accesses ==="
+    print
+    accesses = [(n, a) for n, a in trackmod.get_accesses().iteritems()
+                       if n not in skip]
+    accesses.sort(key=lambda t: t[0])
+    for name, attributes in accesses:
+        print "%s (%s)" % (name, ', '.join(attributes))

trackmod/__init__.py

+# package trackmod
+# For Python 2.4 and up.
+
+"""A package for tracking module use
+
+Exports:
+    begin(repfile=None) ==> None
+    end() ==> None
+    get_previous_imports() ==> List of names
+    get_my_imports() ==> List of names
+    get_imports() ==> List of names
+    get_unaccessed_modules() ==> List of names
+    get_accessed_modules() ==> List of names
+    get_accesses() ==> Dictionary of attribute names by module name
+    write_report(repfile) ==> None
+
+"""
+
+from trackmod import reporter  # Keep this first.
+import sys
+import atexit
+
+from trackmod import importer, module
+
+try:
+    installed
+except NameError:
+    installed = False
+else:
+    # reloaded; reload submodules.
+    reload(importer)  # implicit reporter and module reload.
+
+def print_(*args, **kwds):
+    stream = kwds.get('file', sys.stdout)
+    sep = kwds.get('sep', ' ')
+    end = kwds.get('end', '\n')
+
+    if args:
+        stream.write(sep.join([str(arg) for arg in args]))
+    if end:
+        stream.write(end)
+
+def _write_report(repfile):
+    def report(*args, **kwds):
+        print_(file=repfile, *args, **kwds)
+
+    report("=== module usage report ===")
+    report("\n-- modules already imported (ignored) --")
+    for name in get_previous_imports():
+        report(name)
+    report("\n-- modules added by", __name__.split('.')[0], "(ignored) --")
+    for name in get_my_imports():
+        report(name)
+    report("\n-- modules imported but not accessed --")
+    for name in get_unaccessed_modules():
+        report(name)
+    report("\n-- modules accessed --")
+    accesses = sorted(get_accesses().iteritems())
+    for name, attrs in accesses:
+        report(name, " (", ', '.join(attrs), ")", sep='')
+    report("\n=== end of report ===")
+
+def get_previous_imports():
+    """Return a new sorted name list of previously imported modules"""
+    return reporter.get_previous_imports()
+
+def get_my_imports():
+    """Return a new sorted name list of module imported by this package"""
+    return reporter.get_my_imports()
+
+def get_imports():
+    """Return a new sorted name list of imported modules"""
+    return reporter.get_imports()
+
+def get_unaccessed_modules():
+    """Return a new sorted name list of unaccessed imported modules"""
+    return reporter.get_unaccessed_modules()
+    
+def get_accessed_modules():
+    """Return a new sorted name list of accessed modules"""
+    return reporter.get_accessed_modules()
+
+def get_accesses():
+    """Return a new dictionary of lists of attributes by module name"""
+    return reporter.get_accesses()
+
+def write_report(repfile=None):
+    """Write a module import and access report to repfile
+
+    repfile may be an open file object of a file path. If not previded
+    then writes to standard output. Data collection is terminated if not
+    already stopped by an end() call. If no data is collected, begin() not
+    called, then a runtime error is raised.
+
+    """
+    try:
+        if collecting:
+            end()
+    except NameError:
+        raise RuntimeError("No import data was collected")
+    if repfile is None:
+        _write_report(sys.stdout)
+    else:
+        try:
+            repfile.write
+        except AttributeError:
+            rf = open(repfile, 'w')
+            try:
+                _write_report(rf)
+            finally:
+                rf.close()
+        else:
+            _write_report(repfile)
+
+def begin(repfile=None,
+          pattern=None,
+          continuous=False,
+          submodule_accesses=True):
+    """Start collecting import and module access information
+
+    repfile (default no file) is the destination for an
+    end-of-run module import and access report. It can be either a file
+    path or an open file object.
+
+    pattern (default ['*']) is a list of modules on which to collect data. It
+    is a list of one or more dotted full module names. An asterisk '*' is a
+    wild card an matches everything. Examples:
+      ['pygame']               Will on report on top level pygame package
+      ['pygame', 'numpy']      Only top level pygame and numpy modules
+      ['pygame', 'pygame.surface']
+                               pygame and pygame.surface
+      ['pygame', 'pygame.*']   pygame and all its submodules
+      ['*']                    everything
+
+    continous (default False) indicates whether per-module attribute access
+    recording should stop with the first access or be continuous. Set False
+    to stop after the first access, True for continuous recording.
+
+    submodule_accesses (default True) indicates whether submodules imports
+    are to be included as an access on the containing package.
+
+    """
+    global installed, collecting
+
+    if not installed:
+        sys.meta_path.insert(0, importer)
+        installed = True
+        if repfile is not None:
+            atexit.register(write_report, repfile)
+    try:
+        if collecting:
+            return
+    except NameError:
+        collecting = True
+    if continuous:
+        module.set_report_mode('continuous')
+    importer.begin(pattern, submodule_accesses)
+
+def end():
+    global collecting
+    collecting = False
+    reporter.end()
+    importer.end()
+    module.set_report_mode('quit')
+
+reporter.begin()  # Keep this last.
+
+
+

trackmod/importer.py

+# module trackmod.importer
+
+"""A sys.meta_path importer for tracking module usage."""
+
+import sys
+from trackmod import module, namereg
+
+try:
+    collect_data
+except NameError:
+    pass
+else:
+    # reload: reload imported modules.
+    reload(module)  # implicit reload of reporter
+    reload(namereg)
+
+no_modules = []  # Contains nothing.
+modules_of_interest = no_modules
+add_submodule_accesses = True
+
+
+class Loader(object):
+    def __init__(self, fullname, module):
+        self.fullname = fullname
+        self.module = module
+
+    def load_module(self, fullname):
+        assert fullname == self.fullname, (
+            "loader called with wrong module %s: expecting %s" %
+              (fullname, self.fullname))
+        sys.modules[fullname] = self.module
+        return self.module
+
+def find_module(fullname, path=None):
+    if fullname in modules_of_interest and fullname not in sys.modules:
+        # reload doesn't "get" any tracked TrackerModule attributes.
+        m = module.TrackerModule(fullname)
+
+        # Add m to modules so reload works and to prevent infinite recursion.
+        sys.modules[fullname] = m
+        try:
+            try:
+                reload(m)
+            except ImportError, e:
+                return None;
+        finally:
+            del sys.modules[fullname]
+
+        # Add parent package access.
+        if add_submodule_accesses:
+            parts = fullname.rsplit('.', 1)
+            if len(parts) == 2:
+                try:
+                    pkg = sys.modules[parts[0]]
+                except KeyError:
+                    pass
+                else:
+                    try:
+                        getattr(pkg, parts[1])
+                    except AttributeError:
+                        pass
+
+        return Loader(fullname, m)
+    else:
+        return None
+
+def end():
+    global modules_of_interest
+    modules_of_interest = no_modules
+
+def begin(pattern=None, submodule_accesses=True):
+    global modules_of_interest, collect_data, add_submodule_accesses
+    if pattern is None:
+        pattern = ['*']
+    modules_of_interest = namereg.NameRegistry(pattern)
+    add_submodule_accesses = submodule_accesses

trackmod/module.py

+# module trackmod.module
+
+"""Implements a module usage tracker module type"""
+
+import threading
+
+
+ModuleType = type(threading)
+getattribute = ModuleType.__getattribute__
+accesses = set()
+accesses_lock = threading.RLock()
+
+
+class Module(ModuleType):
+    # A heap subtype of the module type.
+    #
+    # Allows __class__ to be changed. Otherwise it is just the same.
+    # To preserve the module docs this description is a comment.
+    
+    pass
+
+
+class TrackerModule(ModuleType):
+    # A heap subtype of the module type that tracks attribute gets.
+    #
+    # Allows __class__ to be changed. Otherwise it is just the same.
+    # To preserve the module docs this description is a comment.
+
+    # Attributes to ignore in reporting. The module name is the one
+    # attribute guarenteed to not be recorded. The class is used by
+    # the reporter. The path is just noise.
+    ignored_attributes = set(['__name__', '__class__', '__path__'])
+    
+    def __getattribute__(self, attr):
+        if attr in TrackerModule.ignored_attributes:
+            return getattribute(self, attr)
+        report(self, attr)
+        return getattribute(self, attr)
+
+
+def report_continuous(module, attr):
+    accesses_lock.acquire()
+    try:
+        # Safe: no recursive call on __name__ attribute.
+        accesses.add((module.__name__, attr))
+    finally:
+        accesses_lock.release()
+
+def report_quit(module, attr):
+    module.__class__ = Module
+
+def report_oneshot(module, attr):
+    report_continuous(module, attr)
+    report_quit(module, attr)
+
+report = report_oneshot
+
+
+def set_report_mode(mode=None):
+    """Set whether access checking is oneshot or continuous
+
+    if mode (default 'oneshot') is 'oneshot' or None then a TrackerModule
+    module will stop recording attribute accesses after the first non-trivial
+    access. If 'continuous' then all attribute accesses are recorded. If
+    'quit' then access recording stops and further calls to this function
+    have no effect.
+
+    """
+    global report
+    
+    if report is report_quit:
+        return
+    if mode is None:
+        mode = 'oneshot'
+    if mode == 'oneshot':
+        report = report_oneshot
+    elif mode == 'continuous':
+        report = report_continuous
+    elif mode == 'quit':
+        report = report_quit
+    else:
+        raise ValueError("Unknown mode %s" % mode)
+
+
+def get_accesses():
+    accesses_lock.acquire()
+    try:
+        return sorted(accesses)
+    finally:
+        accesses_lock.release()
+

trackmod/namereg.py

+# module trackmod.namereg
+
+class NameRegistry(object):
+    
+    class AllRegistered(object):
+        terminal = True
+        def register(self, names):
+            return
+        def __contains__(self, name):
+            return True
+    all_registered = AllRegistered()
+
+    class AllFound(object):
+        def __init__(self, value):
+            self.value = value
+        def __getitem__(self, key):
+            return self.value
+    all_found = AllFound(all_registered)
+
+    def __init__(self, names=None):
+        self.names = {}
+        if names is not None:
+            self.add(names)
+        self.terminal = False
+
+    def add(self, names):
+        if names is None:
+            self.terminal = True
+            return
+        for name in names:
+            parts = name.split('.', 1)
+            first = parts[0]
+            if first == '*':
+                self.names = self.all_found
+                return
+            else:
+                try:
+                    sub_registry = self.names[first]
+                except KeyError:
+                    sub_registry = NameRegistry()
+                    self.names[first] = sub_registry
+                if len(parts) == 2:
+                    sub_registry.add(parts[1:])
+                else:
+                    sub_registry.terminal = True
+
+    def __contains__(self, name):
+        parts = name.split('.', 1)
+        try:
+            sub_registry = self.names[parts[0]]
+        except KeyError:
+            return False
+        # This uses a conditional or.
+        if len(parts) == 1:
+            return sub_registry.terminal
+        return parts[1] in sub_registry
+

trackmod/reporter.py

+# module trackmod.reporter
+
+# Keep this first.
+def listmods():
+    return [n for n, m in sys.modules.iteritems() if m is not None]
+
+import sys
+previous_imports = listmods()  #  Keep this after sys but before other imports.
+import threading
+
+import module
+
+
+# This module is does not need explicit thread protection since all calls
+# to the data entry methods are made while the import lock is acquired.
+collect_data = True
+my_imports = None
+accesses = None
+failed_imports = None
+
+try:
+    next
+except NameError:
+    def next(iterator):
+        return iterator.next()
+
+class Largest(object):
+    """This object is always greater than any other non Largest object"""
+    def __lt__(self, other):
+        return False
+    def __le__(self, other):
+        return self == other
+    def __eq__(self, other):
+        return isinstance(other, Largest)
+    def __ne__(self, other):
+        not self == other
+    def __gt__(self, other):
+        return True
+    def __ge__(self, other):
+        return True
+
+def process_accessed():
+    acc_names = dict(accessed)
+    for name, attr in accessed:
+        parts = name.split('.')
+        for i in range(1, len(parts)):
+            subname = '.'.join(parts[0:i])
+            if subname not in acc_names:
+                acc_names[subname] = parts[i]
+    return set(acc_names.iteritems())
+
+def begin():
+    global previous_imports, my_imports, accesses, failed_imports
+    my_imports = list(set(listmods()) - set(previous_imports))
+    accesses = {}
+    failed_imports = set()
+
+def end():
+    global collect_data
+    collect_data = False
+
+def add_import(name):
+    """Add a module to the import list
+
+    Expects to be called in the order in which modules are created:
+    package, submodule, etc.
+
+    """
+    if collect_data:
+        accesses[name] = set()
+ 
+def remove_import(name):
+    del accesses[name]
+    failed_imports.add(name)
+
+def add_access(name, attr):
+    if collect_data:
+        accesses[name].add(attr)
+
+def get_previous_imports():
+    """Return a new sorted name list of previously imported modules"""
+    return sorted(previous_imports)
+
+def get_my_imports():
+    """Return a new sorted name list of module imported by this package"""
+    return sorted(my_imports)
+
+def get_imports():
+    """Return a new sorted name list of imported modules"""
+    tracked_types = (module.Module, module.TrackerModule)
+    return sorted(n for n, m in list(sys.modules.iteritems())
+                    if isinstance(m, tracked_types))
+
+def get_unaccessed_modules():
+    """Return a new sorted name list of unaccessed imported modules"""
+    unaccessed = []
+    iaccessed = iter(get_accessed_modules())
+    accessed_name = ''
+    for imports_name in get_imports():
+        while accessed_name < imports_name:
+            try:
+                accessed_name = next(iaccessed)
+            except StopIteration:
+                accessed_name = Largest()
+        if imports_name < accessed_name:
+            unaccessed.append(imports_name)
+    return unaccessed
+
+def get_accessed_modules():
+    """Return a new sorted name list of accessed modules"""
+    accessed = []
+    previous_name = ''
+    for name, ignored in module.get_accesses():
+        if name != previous_name:
+            accessed.append(name)
+            previous_name = name
+    return accessed
+
+def get_accesses():
+    """Return a new dictionary of sorted lists of attributes by module name"""
+    accesses = {}
+    previous_name = ''
+    for name, attribute in module.get_accesses():
+        if name != previous_name:
+            attributes = []
+            accesses[name] = attributes
+            previous_name = name
+        attributes.append(attribute)
+    return accesses
+
+
+
+
+
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.