1. Gringo Suave
  2. dumbo

Commits

Mike Miller  committed 4b340ee Draft

move to package

  • Participants
  • Parent commits 668bc0f
  • Branches default

Comments (0)

Files changed (4)

File dumbo

-#!/usr/bin/env python
-'''
-    Dumbo, a pgadmin-style database viewer for the console.
-    (c) 2011, Mike Miller, ...
-    Released under the GPL, v2.
-        http://www.gnu.org/licenses/gpl-2.0.html
-
-    Usage:
-        %prog [options] [inifile]
-'''
-if True:  # foldable init
-    import sys, os, logging
-    from optparse import OptionParser
-    from  ConfigParser import RawConfigParser as rcp
-    import urwid as u
-    from utk import *
-
-    appname     = 'dumbo'
-    __version__ = '0.61a'
-    title       = ' %s %s ' % (appname.title(), __version__)
-    inipath     = ('dumbo.ini', '~/.config/dumbo.ini', '/etc/dumbo.ini')
-    debug       = False
-
-    # build up some defaults in case .ini file not found.
-    defdb = dict(type='postgres')
-    defconfig = dict(dumbo=dict(theme='dark', db=defdb, use256=True))
-    defpalette = [
-        # name          fg       bg       mono fg256   bg256
-        ('header',      'white', 'black', '', 'g90', 'g20'),
-        ('left',        '', '', '', 'g70', 'g10'),
-        ('right',       '', '', '', 'g80', ''),
-        ('footer',      'light gray', 'black', '', 'g70', 'g20'),
-        ('logo',        'dark gray', '', '', 'g30', ''),
-
-        ('body',        '', '', 'standout', 'g70', 'g10'),
-        ('selected',    'white', 'dark red', ('bold','underline'), 'g90', '#800'),
-        ('focus',       'white', 'dark blue', 'standout', 'g90', '#008'),
-        ('selected focus',
-                        'white', 'dark magenta', ('bold','standout','underline'), '#fff', '#707'),
-        ('dirmark',     'white', '', 'standout'),
-        # logging
-        ('debug',       'light blue', 'black', 'underline', '#69f', 'g20'),
-        ('info',        'light green', 'black', 'underline', '#0b0', 'g20'),
-        ('warning',     'yellow', 'black', 'underline', '#cc0', 'g20'),
-        ('error',       'light red', 'black', 'underline', '#c00', 'g20'),
-        ('critical',    'white', 'dark red', 'underline', '#fff', '#800'),
-    ]
-    formatter = logging.Formatter(
-        '%(asctime)s %(levelname)-7.7s %(funcName)s: %(message)s',
-        '%Y-%m-%d %H:%M:%S')
-
-    # Utility classes
-    # Derived classes from the TreeListBox in the utk module.
-    class SideTree(TreeListBox):
-        '''Display widget for leaf nodes.'''
-        def __init__(self, body, statxt):
-            super(SideTree, self).__init__(body)
-            self.statxt = statxt
-
-        def keypress(self, size, key):
-            return super(SideTree, self).keypress(size, key)
-
-    class LeafWidget(TreeWidget):
-        '''Display widget for leaf nodes.'''
-        def get_display_text(self):
-            return self.get_node().get_value()['name']
-
-    class BranchWidget(ParentWidget):
-        '''Display widget for interior/parent nodes.'''
-        def get_display_text(self):
-            return self.get_node().get_value()['name']
-
-    class LeafNode(TreeNode):
-        '''Data storage object for leaf nodes.'''
-        def load_widget(self):
-            return LeafWidget(self)
-
-    class BranchNode(ParentNode):
-        '''Data storage object for interior/parent nodes.'''
-        def load_widget(self):
-            return BranchWidget( self,
-                expanded=(False if self._parent else True) )
-
-        def load_child_keys(self):
-            data = self.get_value()
-            return range(len(data['children']))
-
-        def load_child_node(self, key):
-            '''Return either an LeafNode or BranchNode.'''
-            childdata = self.get_value()['children'][key]
-            childdepth = self.get_depth() + 1
-            if 'children' in childdata:
-                childclass = BranchNode
-            else:
-                childclass = LeafNode
-            return childclass(childdata, parent=self, key=key, depth=childdepth)
-
-    class mutable_list(list): pass  # in order to attach attribs
-
-    class _StatusBarHandler(logging.Handler):
-        'Log to Statusbar.'
-        def __init__(self, statusbar, level=logging.NOTSET):
-            logging.Handler.__init__(self, level=level)
-            self.statusbar = statusbar
-
-        def emit(self, record):
-            try:
-                msg = self.format(record)
-                if record.levelname in msg:
-                    part1, levelname, part2 = msg.partition(record.levelname)
-                    msg = [part1, (levelname.lower(), levelname), part2]
-                self.statusbar.set_text(msg)
-                u.AttrMap(self.statusbar, 'footer')
-            except:
-                self.handleError(record)
-
-    class DBField(object):
-        'Field that writes its value to Database on Enter key.'
-        def unhandled_input(self, size, input, get_value_method):
-            if input == 'enter':                    # submit data
-                dbinf = self._db_inf
-                gparent = dbinf['gparent']          # dumbo self.
-                if not gparent.db:  return input    # using dummy data
-
-                # get inner value and create sql assignment string
-                myvalue = get_value_method()
-                if isinstance(myvalue, basestring): # quote strings
-                    myvalue = "'%s'" % myvalue.replace("'", "''")
-                dbinf['pair'] = '%s = %s' % (dbinf['fdname'], myvalue)
-
-                # create query string and execute
-                q = gparent.config['dbtype']['query_update'] % dbinf
-                gparent.db.query(dbinf['dbname'], q, fetchresults=False)
-                self.set_edit_text(get_value_method())
-                self._orig = get_value_method() # in case blur before commit
-            elif input in ('tab','left','right','up','down'):  # restore
-                if hasattr(self, '_orig'): self.set_edit_text(self._orig)
-                return input
-            else:
-                return input
-
-    class EditDBStr(u.Edit, DBField):
-        'An Edit widget for strings that may write to a Database.'
-        def __init__(self, *args, **kwargs):
-            self.maxlen = kwargs.pop('maxlen', None)
-            self.__super.__init__(*args, **kwargs)
-            self._orig = self.get_edit_text() # keep for later in case no commit
-        def keypress(self, size, key):
-            key = self.__super.keypress(size, key)
-            myvalue = self.get_edit_text()
-            if self.maxlen and len(myvalue) >= self.maxlen:
-                self.set_edit_text(myvalue[:self.maxlen])
-            return self.unhandled_input(size, key, self.get_edit_text)
-
-    class EditDBInt(u.IntEdit, DBField):
-        'An Edit widget for integers that may write to a Database.'
-        def __init__(self, *args, **kwargs):
-            self.maxint = kwargs.pop('maxint', None)
-            self.__super.__init__(*args, **kwargs)
-            self._orig = self.value()  # keep for later in case no commit
-        def valid_char(self, ch):
-            return len(ch)==1 and ch in '-0123456789'
-        def keypress(self, size, key):
-            if self.edit_text == '-0' and key == 'backspace':  # before handler
-                self.edit_text = ''
-            key = self.__super.keypress(size, key)
-            if '-' in self.edit_text:
-                self.edit_text = '-' + self.edit_text.replace('-', '') # > 1
-                if self.edit_text == '-':   # long() doesn't like '-'
-                    self.edit_text = '-0'
-                    self.set_edit_pos(2)
-                elif self.edit_text == '-0':
-                    self.edit_text = ''
-            if self.maxint:
-                val = self.value()
-                if val >= self.maxint:
-                    self.set_edit_text(str(self.maxint - 1))
-                elif val <= (-self.maxint) - 1: #
-                    self.set_edit_text(str(-self.maxint))
-
-            return self.unhandled_input(size, key, self.value)
-
-
-class DbConnector(object):
-    'Keep track of DB connections.'
-    def __init__(self, module):
-        self.conns = {}
-        try:
-            self.module = __import__(module)
-        except ImportError, e:
-            raise RuntimeError, 'Database API module not found.'
-
-    def close(self, conn=None):
-        'Close one or all open connections.'
-        if conn:
-            self.cons[conn].close()
-            self.cons[conn] = None
-        else:
-            for c in self.conns:
-                log.debug('dbconn:' + repr(self.conns[c]))
-                if self.conns.get(c): self.conns[c].close()
-
-    def connect(self, name, **kwargs):
-        'Connecten-Sie to die Database.'
-        dsn = []
-        for item in kwargs.items():
-            if item[1]:  # not blank
-                dsn.append("%s='%s'" % item)
-        dsn = ' '.join(dsn)
-        log.debug(self.module.__name__ + ':' + `kwargs`)
-
-        conn = self.module.connect(dsn);
-        self.conns[name] = conn
-
-    def query(self, dbname, querystr, fetchresults=True):
-        'Query database.'
-        try:
-            import psycopg2
-        except ImportError:
-            class dummy(str): pass
-            psycopg2 = dummy()  # :/
-            psycopg2.ProgrammingError, psycopg2.InternalError = None, None
-        retval = True # success
-
-        cur = self.conns[dbname].cursor()
-        try:
-            log.info('"%s" against %s' % (querystr, dbname))
-            cur.execute(querystr)
-            if fetchresults:    results = cur.fetchall()
-            else:               results = None
-            log.debug('results: %.512r...' % results)
-            retval = results
-        except (psycopg2.ProgrammingError, psycopg2.InternalError), e:  # fix trans
-            cur.execute('rollback;')    # cannot recover without this
-            msg = '%s: %s' % (e.__class__.__name__, e)
-            log.error(msg.rstrip())
-            retval = False  # known err
-        except Exception, e:
-            msg = '%s: %s' % (e.__class__.__name__, e)
-            log.error(msg.rstrip())
-            retval = None   # generic err
-
-        return retval
-
-
-class DumboFrame(u.Frame):
-    '''The main window of the application.'''
-    def __init__(self, args):
-        self.palette = defpalette       # defaults to fall back on
-        self.config = defconfig
-        self.load_config(args)          # read in ini file
-        self.greeting = ' Ready!'
-        self.legendcache = {}           # keeps track of db/tbl/col names/types
-
-        # set up header and footer
-        menubar = u.AttrMap(u.Text(''), 'header')  # todo
-        exitbut = ( u'\u2715' if unichar_avail else 'x')
-        exitbut = Button(exitbut, align='right', on_press=self.on_exit_but)
-        exitbut = u.AttrMap(exitbut, 'header')
-        header = u.Columns( [menubar, ('fixed', 3, exitbut)], 0 )
-        # footer
-        self.statxt = u.Text(self.greeting)
-        footttl = u.AttrMap(u.Text(title, align='right'), 'footer')
-        footer = u.Columns([
-            ('weight', 1, self.statxt),
-            ('fixed', len(title), footttl)
-            ], 1)
-        footer = u.AttrMap(footer, 'footer')
-
-        # log to status bar
-        hdlr = _StatusBarHandler(self.statxt)
-        formatter = logging.Formatter(
-            ' %(levelname)s %(funcName)s: %(message)s')
-        hdlr.setFormatter(formatter)
-        log.addHandler(hdlr)
-
-        # populate sidebar
-        try:
-            self.db = DbConnector(self.config['dbtype']['module'])
-        except (TypeError, IndexError, RuntimeError, KeyError), e:
-            log.error('%s: %s' % (e.__class__.__name__, e))
-            log.warn('Unable to connect to database using config, falling ' +
-                'back to demo mode.')
-            self.db = None
-        db_objs = self.populate_sidebar()
-        self.sidebar = SideTree(TreeWalker(BranchNode(db_objs)), self.statxt)
-
-        # populate content pane
-        font = ( u.font.HalfBlock5x4Font() if unichar_avail
-            else u.font.Thin6x6Font() )
-        bt = u.BigText( ('logo', title), font)
-        self.logo = u.Padding(bt, 'right', width='clip')
-        self.contpane = u.ListBox([u.Text(''), self.logo])
-
-        mainbody = u.Columns([
-            ('weight', 1, u.AttrMap(self.sidebar, 'left')),
-            ('weight', 3, u.AttrMap(self.contpane, 'right')),
-        ], 1)  # 1 char spacer
-        super(DumboFrame, self).__init__(mainbody, header=header, footer=footer)
-
-    def on_exit_but(self, args):
-        raise u.ExitMainLoop()
-
-    def get_tbl_legend(self, dbname, tablename):
-        'Retreive the column metadata from a table.'
-        key = (dbname, tablename)
-        legend = self.legendcache.get(key)
-        if legend:
-            log.debug('Using cached legend: %s' % legend)
-        else:
-            try: # get table legend
-                q = self.config['dbtype']['query_tablecols'] % tablename
-                results = self.db.query(dbname, q)
-                legend = mutable_list( results )
-                for colmd in legend:
-                    log.debug( repr(colmd) )
-
-                # attempt retrieval of primary key info
-                legend.pkeys = None
-                q2 = self.config['dbtype'].get('query_tablecols2')
-                if q2:
-                    q2 = q2 % tablename
-                    results = self.db.query(dbname, q2)
-                    if results:
-                        legend.pkeys = results
-                self.legendcache[key] = legend
-            except Exception, e:
-                import traceback
-                log.error('Legend not found: %s' % traceback.format_exc())
-        return legend
-
-    def list_db_cluster(self):
-        'Find the dbs and tables at this connection.'
-        dbconns = self.db.conns
-        dbparams =  self.config['db']
-        primarydb = dbparams['dbname']
-        tbfilters = dbparams.pop('tablefilters', None)
-        if tbfilters:
-            if not type(tbfilters) is tuple:
-                tbfilters = (tbfilters,)
-        else:
-            tbfilters = tuple()
-        treedata = []
-
-        # list databases
-        self.db.connect(primarydb, **dbparams)
-        cur = dbconns[primarydb].cursor()
-        cur.execute(self.config['dbtype']['query_databases'])
-        for db in cur.fetchall():
-            dbname = str(db[0])
-            if dbname != primarydb:
-                dbconns[dbname] = None  # prepare for multiple connections.
-
-        # list tables under each database, pg needs a uniq conn for each.
-        for dbname in dbconns:
-            if not dbconns.get(dbname):     # set up new conns to database
-                dbparams['dbname'] = dbname
-                try:
-                    self.db.connect(dbname, **dbparams)
-                except Exception:
-                    continue
-
-            dbattrs = dict(name=dbname)
-            children = []
-            cur = dbconns[dbname].cursor()
-            cur.execute(self.config['dbtype']['query_tables'])
-            for tablename in cur.fetchall():
-                match = False
-                for dbf in tbfilters:    #  :/
-                    if dbf in tablename[0]:
-                        match = True
-                if not match:
-                    children.append(dict(name=str(tablename[0])))
-            dbattrs['children'] = sorted( children, key=(lambda k: k['name']) )
-            treedata.append(dbattrs)
-        return sorted( treedata, key=(lambda k: k['name']) )
-
-    def load_config(self, args):
-        'Search for and read in config file, then set up a few things.'
-        import csv
-        def convert_type(value):
-            'convert values where possible'
-            value = value.strip()
-            if value.isdigit():     value = int(value)
-            elif value == 'False':  value = False
-            elif value == 'True':   value = True
-            elif ',' in value:
-                values = csv.reader([value],
-                    quoting=csv.QUOTE_MINIMAL, skipinitialspace=True).next()
-                values = tuple( x.strip() for x in values )
-                value = values
-            return value
-
-        # find ini file
-        inifname = None
-        if args:
-            if os.access(args[0], os.R_OK):
-                inifname = args[0]
-            else:
-                print 'Error: filename "%s" not found.' % args[0]
-                sys.exit(3)
-        else:
-            for path in inipath:
-                path = os.path.expanduser(path)
-                if os.access(path, os.R_OK):
-                    inifname = path
-        if inifname:
-            cp = rcp()
-            cp.read(inifname)
-            config = dict(  (section, dict(cp.items(section)))  # conv to dicts
-                            for section in cp.sections()   )
-
-            # get theme info
-            theme = config[appname]['theme']
-            theme = config.get('theme_' + theme)
-            palette = []
-            for item in theme.items():
-                palette.append( (item[0],) + convert_type(item[1]) )
-
-            # get db info and copy to root
-            dbsection = 'db_' + config[appname]['db']
-            config['db'] = config.get(dbsection)
-            if config['db']:
-                if 'tablefilters' in config['db']:
-                    config['db']['tablefilters'] = (
-                        convert_type(config['db']['tablefilters']) )
-                dbtypesect = 'dbtype_' + config['db'].pop('type', '')
-                config['dbtype'] = config.get(dbtypesect)
-                if config['dbtype']:
-                    config[appname]['use256'] = (config[appname]['use256'].title() == 'True')
-                    self.config, self.palette = config, palette
-                else:
-                    raise RuntimeError, 'Database Type section "%s" not found. ' % dbtypesect
-            else:
-                raise RuntimeError, 'Database section "%s" not found. ' % dbsection
-        else:
-            log.error('ini file not found.')
-
-    def onActivate(self, input):
-        'An object in the sidebar has been activated.'
-        i = self.body.get_focus_column()
-        item, pos = self.sidebar.get_focus()
-        itemtype = 'Table'
-        if type(item) is BranchWidget:    # XP behavior, on open close others
-            itemtype = 'Database'
-            for method in ('next_sibling', 'prev_sibling'):
-                brother = item
-                while True:   # this sucks, no simple traversal
-                    brother = getattr(brother.get_node(), method)()
-                    if not brother:  break
-                    brother = brother.get_widget()
-                    if type(brother) is BranchWidget:
-                        brother.expanded = False
-                        brother.update_widget()
-            item.expanded = not item.expanded
-            item.update_widget()
-
-        if self.db:
-            currname = item.get_display_text()
-            txt = '%s "%s" activated.' % (itemtype, currname )
-            log.info(str(input).upper() + ', %s' % txt)
-
-            if itemtype == 'Table':
-                # get db name
-                dbname = item.get_node().get_parent().get_value()['name']
-
-                # get table legend and data
-                legend = self.get_tbl_legend(dbname, currname)
-                if legend.pkeys:
-                    pkeys = ','.join( (str(x[1]) for x in legend.pkeys ) )
-                    orderby = 'order by %s' % pkeys
-                elif legend.pkeys and ('id' in [ x[1] for x in legend.pkeys ]):
-                    orderby = 'order by id'
-                else:
-                    orderby = ''
-
-                q = self.config['dbtype']['query_table'] % (currname, orderby)
-                results = self.db.query(dbname, q)
-                if results:
-                    self.statxt.set_text(self.greeting)
-                elif results is False:  # fallback, try again
-                    q = 'select * from %s;' % currname
-                    results = self.db.query(dbname, q)
-                self.populate_cont(results or [], dbname, currname, legend)
-        else:
-            if type(item) is LeafWidget:
-                legend = mutable_list(
-                    [[1, 'id', 'character varying', None, 'YES', 255, None]] )
-                legend.pkeys = ((1, 'id'),)
-                for i in range(2,10):
-                    dummytype = legend[0][:]  # copy
-                    dummytype[0] = i
-                    dummytype[1] = 'col%02d' % i
-                    legend.append(dummytype)
-                dummy = ('The quick brown fox jumped over the lazy dog.'.split()
-                    ,) * 20
-                self.populate_cont(dummy, legend=legend)
-
-    def populate_cont(self, data, dbname=None, tbname=None, legend=None):
-        'Load up the content pane.'
-        rows = []
-        if legend:  # show name/type as column header
-            row = [ u.AttrMap( u.Text('%s\n%s' % (x[1],x[2])), 'logo' )
-                    for x in legend ]
-            rows.append( u.Columns(row, 1, 0, 4, ) )
-            if hasattr(legend, 'pkeys') and legend.pkeys:
-                pkeys = legend.pkeys
-            else:
-                pkeys = None
-            log.debug('pkeys:' + repr(pkeys))
-            # figure which col is which field, could be more efficient
-            nmfromi = dict(( (i,leg[1]) for i,leg in enumerate(legend) ))
-            ifromnm = dict(( (leg[1], i) for i,leg in enumerate(legend) ))
-
-            for row in data:
-                try:
-                    if not pkeys:  raise KeyError
-                    newrow = []
-                    # save primary key for later inside edit cell
-                    where = ''  #  pkstr = val, ...
-                    for j, pkindex in enumerate(pkeys):  # pkindex is 1-based
-                        pkindex, pkstr = pkindex
-                        if legend[pkindex-1][2] == 'integer':
-                            where += '%s = %s' % (pkstr, row[pkindex-1])
-                        else:
-                            where += "%s = '%s'" % (pkstr, row[pkindex-1])
-                        if j != len(pkeys) - 1:
-                            where += ' and '
-                    log.debug('where clause: ' + where)
-
-                    # populate pane, modest type enforcement
-                    for i, field in enumerate(row):
-                        if legend[i][2] == 'integer':
-                            maxint = 2**(legend[i][6]-1)-1
-                            editor = EditDBInt('', field, maxint=maxint)
-                        else:
-                            editor = EditDBStr('', str(field), maxlen=legend[i][5])
-                        editor._db_inf = { 'where': where,
-                            'dbname': dbname, 'tbname': tbname,
-                            'fdname': nmfromi[i], 'gparent': self,
-                        }
-                        newrow.append( u.AttrMap(editor, 'right', 'focus') )
-                except KeyError:  # fall back to read only.
-                    newrow = [ u.AttrMap(u.Text(str(x)), 'right') for x in row ]
-                rows.append( u.Columns(newrow, 1, None, 4, ) )
-        else:
-            for row in data:
-                row = [ u.AttrMap(u.Text(str(x)), 'right') for x in row ]
-                rows.append( u.Columns(row, 1, None, 4, ) )
-
-        self.contpane.body = u.SimpleListWalker(rows + [u.Text(''), u.Text(''),
-            self.logo])
-
-    def populate_sidebar(self):
-        'Load up the sidebar.'
-        if self.db:
-            treedata = self.list_db_cluster()
-            treedata = { 'name': 'Databases', 'children': treedata }
-        else:
-            log.warn('Database not found, using dummy data.')
-            children = [ dict(name='child%s' % i) for i in range(10) ]
-            children[3]['children'] = [ dict(name='child%s' % i) for i in range(10) ]
-            children[5]['children'] = [ dict(name='child%s' % i) for i in range(10) ]
-            treedata = {    'name': 'Databases (N/A)',
-                        'children': children
-            }
-        return treedata
-
-    def unhandled_input(self, input):
-        '''Handle input that wasn't already.'''
-        if self.db: dbconns = self.db.conns
-        else:       dbconns = ()
-        i = self.body.get_focus_column()
-        log.debug('focus in column:%s' % i)
-
-        if input in ('q', 'Q', 'esc'):
-            if self.db: self.db.close()
-            raise u.ExitMainLoop()
-        elif input in ('tab', 'right'):
-            try:
-                self.body.set_focus(i+1)
-                # keep focus out of legend, doesn't work with right key :/
-                if self.contpane.get_focus()[1] == 0:
-                    self.contpane.set_focus(1, coming_from='above')
-            except AssertionError:  self.body.set_focus(0)
-            log.debug('tab: col %s to %s' % (i, i+1))
-
-        elif i == 0:  # self.sidebar
-            if input == 'enter':
-                self.onActivate(input)
-            elif type(input) is tuple:  # mouse event
-                if input[0] == 'mouse release' and input[1] in (0,1):
-                    self.onActivate(input)
-        else:
-            log.debug(str(input))
-
-
-if __name__=='__main__':
-    parser = OptionParser(usage=__doc__.rstrip(), version=__version__)
-    parser.add_option(
-        '-l', '--logfile', metavar="F", default='log_dumbo.txt',
-        help='Log activity to this file. Default: %default.')
-    parser.add_option('-v', '--verbose', action='store_true',
-        help='Enable verbose output to log.')
-    parser.add_option('-V', '--very-verbose', action='store_true',
-        help='Enable ridiculous amounts of debugging output.')
-
-    (opts, args) = parser.parse_args()
-
-    # set up logging
-    log = logging.getLogger(__name__)
-    log.setLevel( (logging.DEBUG if debug else logging.WARN) )
-    if opts.verbose:        log.setLevel(logging.INFO)
-    if opts.very_verbose:   log.setLevel(logging.DEBUG)
-    hdlr = logging.FileHandler(opts.logfile)
-    hdlr.setFormatter(formatter)
-    log.addHandler(hdlr)
-
-    # get started
-    df = DumboFrame(args)
-
-    loop = u.MainLoop(df, df.palette, unhandled_input=df.unhandled_input)
-    if df.config and df.config[appname]['use256'] and hicolor_avail:
-        loop.screen.set_terminal_properties(colors=256)
-    try:
-        loop.run()
-    except Exception, e:    # try to close all db connections
-        print e.__class__.__name__, e, '.  Check log for details.\n'
-        import traceback
-        log.critical(traceback.format_exc())
-        if df.db:
-            df.db.close()
-

File dumbo/dumbo.py

View file
+#!/usr/bin/env python
+'''
+    Dumbo, a pgadmin-style database viewer for the console.
+    (c) 2011, Mike Miller, ...
+    Released under the GPL, v2.
+        http://www.gnu.org/licenses/gpl-2.0.html
+
+    Usage:
+        %prog [options] [inifile]
+'''
+if True:  # foldable init
+    import sys, os, logging
+    from optparse import OptionParser
+    from  ConfigParser import RawConfigParser as rcp
+    import urwid as u
+    from utk import *
+
+    appname     = 'dumbo'
+    __version__ = '0.61a'
+    title       = ' %s %s ' % (appname.title(), __version__)
+    inipath     = ('dumbo.ini', '~/.config/dumbo.ini', '/etc/dumbo.ini')
+    debug       = False
+
+    # build up some defaults in case .ini file not found.
+    defdb = dict(type='postgres')
+    defconfig = dict(dumbo=dict(theme='dark', db=defdb, use256=True))
+    defpalette = [
+        # name          fg       bg       mono fg256   bg256
+        ('header',      'white', 'black', '', 'g90', 'g20'),
+        ('left',        '', '', '', 'g70', 'g10'),
+        ('right',       '', '', '', 'g80', ''),
+        ('footer',      'light gray', 'black', '', 'g70', 'g20'),
+        ('logo',        'dark gray', '', '', 'g30', ''),
+
+        ('body',        '', '', 'standout', 'g70', 'g10'),
+        ('selected',    'white', 'dark red', ('bold','underline'), 'g90', '#800'),
+        ('focus',       'white', 'dark blue', 'standout', 'g90', '#008'),
+        ('selected focus',
+                        'white', 'dark magenta', ('bold','standout','underline'), '#fff', '#707'),
+        ('dirmark',     'white', '', 'standout'),
+        # logging
+        ('debug',       'light blue', 'black', 'underline', '#69f', 'g20'),
+        ('info',        'light green', 'black', 'underline', '#0b0', 'g20'),
+        ('warning',     'yellow', 'black', 'underline', '#cc0', 'g20'),
+        ('error',       'light red', 'black', 'underline', '#c00', 'g20'),
+        ('critical',    'white', 'dark red', 'underline', '#fff', '#800'),
+    ]
+    formatter = logging.Formatter(
+        '%(asctime)s %(levelname)-7.7s %(funcName)s: %(message)s',
+        '%Y-%m-%d %H:%M:%S')
+
+    # Utility classes
+    # Derived classes from the TreeListBox in the utk module.
+    class SideTree(TreeListBox):
+        '''Display widget for leaf nodes.'''
+        def __init__(self, body, statxt):
+            super(SideTree, self).__init__(body)
+            self.statxt = statxt
+
+        def keypress(self, size, key):
+            return super(SideTree, self).keypress(size, key)
+
+    class LeafWidget(TreeWidget):
+        '''Display widget for leaf nodes.'''
+        def get_display_text(self):
+            return self.get_node().get_value()['name']
+
+    class BranchWidget(ParentWidget):
+        '''Display widget for interior/parent nodes.'''
+        def get_display_text(self):
+            return self.get_node().get_value()['name']
+
+    class LeafNode(TreeNode):
+        '''Data storage object for leaf nodes.'''
+        def load_widget(self):
+            return LeafWidget(self)
+
+    class BranchNode(ParentNode):
+        '''Data storage object for interior/parent nodes.'''
+        def load_widget(self):
+            return BranchWidget( self,
+                expanded=(False if self._parent else True) )
+
+        def load_child_keys(self):
+            data = self.get_value()
+            return range(len(data['children']))
+
+        def load_child_node(self, key):
+            '''Return either an LeafNode or BranchNode.'''
+            childdata = self.get_value()['children'][key]
+            childdepth = self.get_depth() + 1
+            if 'children' in childdata:
+                childclass = BranchNode
+            else:
+                childclass = LeafNode
+            return childclass(childdata, parent=self, key=key, depth=childdepth)
+
+    class mutable_list(list): pass  # in order to attach attribs
+
+    class _StatusBarHandler(logging.Handler):
+        'Log to Statusbar.'
+        def __init__(self, statusbar, level=logging.NOTSET):
+            logging.Handler.__init__(self, level=level)
+            self.statusbar = statusbar
+
+        def emit(self, record):
+            try:
+                msg = self.format(record)
+                if record.levelname in msg:
+                    part1, levelname, part2 = msg.partition(record.levelname)
+                    msg = [part1, (levelname.lower(), levelname), part2]
+                self.statusbar.set_text(msg)
+                u.AttrMap(self.statusbar, 'footer')
+            except:
+                self.handleError(record)
+
+    class DBField(object):
+        'Field that writes its value to Database on Enter key.'
+        def unhandled_input(self, size, input, get_value_method):
+            if input == 'enter':                    # submit data
+                dbinf = self._db_inf
+                gparent = dbinf['gparent']          # dumbo self.
+                if not gparent.db:  return input    # using dummy data
+
+                # get inner value and create sql assignment string
+                myvalue = get_value_method()
+                if isinstance(myvalue, basestring): # quote strings
+                    myvalue = "'%s'" % myvalue.replace("'", "''")
+                dbinf['pair'] = '%s = %s' % (dbinf['fdname'], myvalue)
+
+                # create query string and execute
+                q = gparent.config['dbtype']['query_update'] % dbinf
+                gparent.db.query(dbinf['dbname'], q, fetchresults=False)
+                self.set_edit_text(get_value_method())
+                self._orig = get_value_method() # in case blur before commit
+            elif input in ('tab','left','right','up','down'):  # restore
+                if hasattr(self, '_orig'): self.set_edit_text(self._orig)
+                return input
+            else:
+                return input
+
+    class EditDBStr(u.Edit, DBField):
+        'An Edit widget for strings that may write to a Database.'
+        def __init__(self, *args, **kwargs):
+            self.maxlen = kwargs.pop('maxlen', None)
+            self.__super.__init__(*args, **kwargs)
+            self._orig = self.get_edit_text() # keep for later in case no commit
+        def keypress(self, size, key):
+            key = self.__super.keypress(size, key)
+            myvalue = self.get_edit_text()
+            if self.maxlen and len(myvalue) >= self.maxlen:
+                self.set_edit_text(myvalue[:self.maxlen])
+            return self.unhandled_input(size, key, self.get_edit_text)
+
+    class EditDBInt(u.IntEdit, DBField):
+        'An Edit widget for integers that may write to a Database.'
+        def __init__(self, *args, **kwargs):
+            self.maxint = kwargs.pop('maxint', None)
+            self.__super.__init__(*args, **kwargs)
+            self._orig = self.value()  # keep for later in case no commit
+        def valid_char(self, ch):
+            return len(ch)==1 and ch in '-0123456789'
+        def keypress(self, size, key):
+            if self.edit_text == '-0' and key == 'backspace':  # before handler
+                self.edit_text = ''
+            key = self.__super.keypress(size, key)
+            if '-' in self.edit_text:
+                self.edit_text = '-' + self.edit_text.replace('-', '') # > 1
+                if self.edit_text == '-':   # long() doesn't like '-'
+                    self.edit_text = '-0'
+                    self.set_edit_pos(2)
+                elif self.edit_text == '-0':
+                    self.edit_text = ''
+            if self.maxint:
+                val = self.value()
+                if val >= self.maxint:
+                    self.set_edit_text(str(self.maxint - 1))
+                elif val <= (-self.maxint) - 1: #
+                    self.set_edit_text(str(-self.maxint))
+
+            return self.unhandled_input(size, key, self.value)
+
+
+class DbConnector(object):
+    'Keep track of DB connections.'
+    def __init__(self, module):
+        self.conns = {}
+        try:
+            self.module = __import__(module)
+        except ImportError, e:
+            raise RuntimeError, 'Database API module not found.'
+
+    def close(self, conn=None):
+        'Close one or all open connections.'
+        if conn:
+            self.cons[conn].close()
+            self.cons[conn] = None
+        else:
+            for c in self.conns:
+                log.debug('dbconn:' + repr(self.conns[c]))
+                if self.conns.get(c): self.conns[c].close()
+
+    def connect(self, name, **kwargs):
+        'Connecten-Sie to die Database.'
+        dsn = []
+        for item in kwargs.items():
+            if item[1]:  # not blank
+                dsn.append("%s='%s'" % item)
+        dsn = ' '.join(dsn)
+        log.debug(self.module.__name__ + ':' + `kwargs`)
+
+        conn = self.module.connect(dsn);
+        self.conns[name] = conn
+
+    def query(self, dbname, querystr, fetchresults=True):
+        'Query database.'
+        try:
+            import psycopg2
+        except ImportError:
+            class dummy(str): pass
+            psycopg2 = dummy()  # :/
+            psycopg2.ProgrammingError, psycopg2.InternalError = None, None
+        retval = True # success
+
+        cur = self.conns[dbname].cursor()
+        try:
+            log.info('"%s" against %s' % (querystr, dbname))
+            cur.execute(querystr)
+            if fetchresults:    results = cur.fetchall()
+            else:               results = None
+            log.debug('results: %.512r...' % results)
+            retval = results
+        except (psycopg2.ProgrammingError, psycopg2.InternalError), e:  # fix trans
+            cur.execute('rollback;')    # cannot recover without this
+            msg = '%s: %s' % (e.__class__.__name__, e)
+            log.error(msg.rstrip())
+            retval = False  # known err
+        except Exception, e:
+            msg = '%s: %s' % (e.__class__.__name__, e)
+            log.error(msg.rstrip())
+            retval = None   # generic err
+
+        return retval
+
+
+class DumboFrame(u.Frame):
+    '''The main window of the application.'''
+    def __init__(self, args):
+        self.palette = defpalette       # defaults to fall back on
+        self.config = defconfig
+        self.load_config(args)          # read in ini file
+        self.greeting = ' Ready!'
+        self.legendcache = {}           # keeps track of db/tbl/col names/types
+
+        # set up header and footer
+        menubar = u.AttrMap(u.Text(''), 'header')  # todo
+        exitbut = ( u'\u2715' if unichar_avail else 'x')
+        exitbut = Button(exitbut, align='right', on_press=self.on_exit_but)
+        exitbut = u.AttrMap(exitbut, 'header')
+        header = u.Columns( [menubar, ('fixed', 3, exitbut)], 0 )
+        # footer
+        self.statxt = u.Text(self.greeting)
+        footttl = u.AttrMap(u.Text(title, align='right'), 'footer')
+        footer = u.Columns([
+            ('weight', 1, self.statxt),
+            ('fixed', len(title), footttl)
+            ], 1)
+        footer = u.AttrMap(footer, 'footer')
+
+        # log to status bar
+        hdlr = _StatusBarHandler(self.statxt)
+        formatter = logging.Formatter(
+            ' %(levelname)s %(funcName)s: %(message)s')
+        hdlr.setFormatter(formatter)
+        log.addHandler(hdlr)
+
+        # populate sidebar
+        try:
+            self.db = DbConnector(self.config['dbtype']['module'])
+        except (TypeError, IndexError, RuntimeError, KeyError), e:
+            log.error('%s: %s' % (e.__class__.__name__, e))
+            log.warn('Unable to connect to database using config, falling ' +
+                'back to demo mode.')
+            self.db = None
+        db_objs = self.populate_sidebar()
+        self.sidebar = SideTree(TreeWalker(BranchNode(db_objs)), self.statxt)
+
+        # populate content pane
+        font = ( u.font.HalfBlock5x4Font() if unichar_avail
+            else u.font.Thin6x6Font() )
+        bt = u.BigText( ('logo', title), font)
+        self.logo = u.Padding(bt, 'right', width='clip')
+        self.contpane = u.ListBox([u.Text(''), self.logo])
+
+        mainbody = u.Columns([
+            ('weight', 1, u.AttrMap(self.sidebar, 'left')),
+            ('weight', 3, u.AttrMap(self.contpane, 'right')),
+        ], 1)  # 1 char spacer
+        super(DumboFrame, self).__init__(mainbody, header=header, footer=footer)
+
+    def on_exit_but(self, args):
+        raise u.ExitMainLoop()
+
+    def get_tbl_legend(self, dbname, tablename):
+        'Retreive the column metadata from a table.'
+        key = (dbname, tablename)
+        legend = self.legendcache.get(key)
+        if legend:
+            log.debug('Using cached legend: %s' % legend)
+        else:
+            try: # get table legend
+                q = self.config['dbtype']['query_tablecols'] % tablename
+                results = self.db.query(dbname, q)
+                legend = mutable_list( results )
+                for colmd in legend:
+                    log.debug( repr(colmd) )
+
+                # attempt retrieval of primary key info
+                legend.pkeys = None
+                q2 = self.config['dbtype'].get('query_tablecols2')
+                if q2:
+                    q2 = q2 % tablename
+                    results = self.db.query(dbname, q2)
+                    if results:
+                        legend.pkeys = results
+                self.legendcache[key] = legend
+            except Exception, e:
+                import traceback
+                log.error('Legend not found: %s' % traceback.format_exc())
+        return legend
+
+    def list_db_cluster(self):
+        'Find the dbs and tables at this connection.'
+        dbconns = self.db.conns
+        dbparams =  self.config['db']
+        primarydb = dbparams['dbname']
+        tbfilters = dbparams.pop('tablefilters', None)
+        if tbfilters:
+            if not type(tbfilters) is tuple:
+                tbfilters = (tbfilters,)
+        else:
+            tbfilters = tuple()
+        treedata = []
+
+        # list databases
+        self.db.connect(primarydb, **dbparams)
+        cur = dbconns[primarydb].cursor()
+        cur.execute(self.config['dbtype']['query_databases'])
+        for db in cur.fetchall():
+            dbname = str(db[0])
+            if dbname != primarydb:
+                dbconns[dbname] = None  # prepare for multiple connections.
+
+        # list tables under each database, pg needs a uniq conn for each.
+        for dbname in dbconns:
+            if not dbconns.get(dbname):     # set up new conns to database
+                dbparams['dbname'] = dbname
+                try:
+                    self.db.connect(dbname, **dbparams)
+                except Exception:
+                    continue
+
+            dbattrs = dict(name=dbname)
+            children = []
+            cur = dbconns[dbname].cursor()
+            cur.execute(self.config['dbtype']['query_tables'])
+            for tablename in cur.fetchall():
+                match = False
+                for dbf in tbfilters:    #  :/
+                    if dbf in tablename[0]:
+                        match = True
+                if not match:
+                    children.append(dict(name=str(tablename[0])))
+            dbattrs['children'] = sorted( children, key=(lambda k: k['name']) )
+            treedata.append(dbattrs)
+        return sorted( treedata, key=(lambda k: k['name']) )
+
+    def load_config(self, args):
+        'Search for and read in config file, then set up a few things.'
+        import csv
+        def convert_type(value):
+            'convert values where possible'
+            value = value.strip()
+            if value.isdigit():     value = int(value)
+            elif value == 'False':  value = False
+            elif value == 'True':   value = True
+            elif ',' in value:
+                values = csv.reader([value],
+                    quoting=csv.QUOTE_MINIMAL, skipinitialspace=True).next()
+                values = tuple( x.strip() for x in values )
+                value = values
+            return value
+
+        # find ini file
+        inifname = None
+        if args:
+            if os.access(args[0], os.R_OK):
+                inifname = args[0]
+            else:
+                print 'Error: filename "%s" not found.' % args[0]
+                sys.exit(3)
+        else:
+            for path in inipath:
+                path = os.path.expanduser(path)
+                if os.access(path, os.R_OK):
+                    inifname = path
+        if inifname:
+            cp = rcp()
+            cp.read(inifname)
+            config = dict(  (section, dict(cp.items(section)))  # conv to dicts
+                            for section in cp.sections()   )
+
+            # get theme info
+            theme = config[appname]['theme']
+            theme = config.get('theme_' + theme)
+            palette = []
+            for item in theme.items():
+                palette.append( (item[0],) + convert_type(item[1]) )
+
+            # get db info and copy to root
+            dbsection = 'db_' + config[appname]['db']
+            config['db'] = config.get(dbsection)
+            if config['db']:
+                if 'tablefilters' in config['db']:
+                    config['db']['tablefilters'] = (
+                        convert_type(config['db']['tablefilters']) )
+                dbtypesect = 'dbtype_' + config['db'].pop('type', '')
+                config['dbtype'] = config.get(dbtypesect)
+                if config['dbtype']:
+                    config[appname]['use256'] = (config[appname]['use256'].title() == 'True')
+                    self.config, self.palette = config, palette
+                else:
+                    raise RuntimeError, 'Database Type section "%s" not found. ' % dbtypesect
+            else:
+                raise RuntimeError, 'Database section "%s" not found. ' % dbsection
+        else:
+            log.error('ini file not found.')
+
+    def onActivate(self, input):
+        'An object in the sidebar has been activated.'
+        i = self.body.get_focus_column()
+        item, pos = self.sidebar.get_focus()
+        itemtype = 'Table'
+        if type(item) is BranchWidget:    # XP behavior, on open close others
+            itemtype = 'Database'
+            for method in ('next_sibling', 'prev_sibling'):
+                brother = item
+                while True:   # this sucks, no simple traversal
+                    brother = getattr(brother.get_node(), method)()
+                    if not brother:  break
+                    brother = brother.get_widget()
+                    if type(brother) is BranchWidget:
+                        brother.expanded = False
+                        brother.update_widget()
+            item.expanded = not item.expanded
+            item.update_widget()
+
+        if self.db:
+            currname = item.get_display_text()
+            txt = '%s "%s" activated.' % (itemtype, currname )
+            log.info(str(input).upper() + ', %s' % txt)
+
+            if itemtype == 'Table':
+                # get db name
+                dbname = item.get_node().get_parent().get_value()['name']
+
+                # get table legend and data
+                legend = self.get_tbl_legend(dbname, currname)
+                if legend.pkeys:
+                    pkeys = ','.join( (str(x[1]) for x in legend.pkeys ) )
+                    orderby = 'order by %s' % pkeys
+                elif legend.pkeys and ('id' in [ x[1] for x in legend.pkeys ]):
+                    orderby = 'order by id'
+                else:
+                    orderby = ''
+
+                q = self.config['dbtype']['query_table'] % (currname, orderby)
+                results = self.db.query(dbname, q)
+                if results:
+                    self.statxt.set_text(self.greeting)
+                elif results is False:  # fallback, try again
+                    q = 'select * from %s;' % currname
+                    results = self.db.query(dbname, q)
+                self.populate_cont(results or [], dbname, currname, legend)
+        else:
+            if type(item) is LeafWidget:
+                legend = mutable_list(
+                    [[1, 'id', 'character varying', None, 'YES', 255, None]] )
+                legend.pkeys = ((1, 'id'),)
+                for i in range(2,10):
+                    dummytype = legend[0][:]  # copy
+                    dummytype[0] = i
+                    dummytype[1] = 'col%02d' % i
+                    legend.append(dummytype)
+                dummy = ('The quick brown fox jumped over the lazy dog.'.split()
+                    ,) * 20
+                self.populate_cont(dummy, legend=legend)
+
+    def populate_cont(self, data, dbname=None, tbname=None, legend=None):
+        'Load up the content pane.'
+        rows = []
+        if legend:  # show name/type as column header
+            row = [ u.AttrMap( u.Text('%s\n%s' % (x[1],x[2])), 'logo' )
+                    for x in legend ]
+            rows.append( u.Columns(row, 1, 0, 4, ) )
+            if hasattr(legend, 'pkeys') and legend.pkeys:
+                pkeys = legend.pkeys
+            else:
+                pkeys = None
+            log.debug('pkeys:' + repr(pkeys))
+            # figure which col is which field, could be more efficient
+            nmfromi = dict(( (i,leg[1]) for i,leg in enumerate(legend) ))
+            ifromnm = dict(( (leg[1], i) for i,leg in enumerate(legend) ))
+
+            for row in data:
+                try:
+                    if not pkeys:  raise KeyError
+                    newrow = []
+                    # save primary key for later inside edit cell
+                    where = ''  #  pkstr = val, ...
+                    for j, pkindex in enumerate(pkeys):  # pkindex is 1-based
+                        pkindex, pkstr = pkindex
+                        if legend[pkindex-1][2] == 'integer':
+                            where += '%s = %s' % (pkstr, row[pkindex-1])
+                        else:
+                            where += "%s = '%s'" % (pkstr, row[pkindex-1])
+                        if j != len(pkeys) - 1:
+                            where += ' and '
+                    log.debug('where clause: ' + where)
+
+                    # populate pane, modest type enforcement
+                    for i, field in enumerate(row):
+                        if legend[i][2] == 'integer':
+                            maxint = 2**(legend[i][6]-1)-1
+                            editor = EditDBInt('', field, maxint=maxint)
+                        else:
+                            editor = EditDBStr('', str(field), maxlen=legend[i][5])
+                        editor._db_inf = { 'where': where,
+                            'dbname': dbname, 'tbname': tbname,
+                            'fdname': nmfromi[i], 'gparent': self,
+                        }
+                        newrow.append( u.AttrMap(editor, 'right', 'focus') )
+                except KeyError:  # fall back to read only.
+                    newrow = [ u.AttrMap(u.Text(str(x)), 'right') for x in row ]
+                rows.append( u.Columns(newrow, 1, None, 4, ) )
+        else:
+            for row in data:
+                row = [ u.AttrMap(u.Text(str(x)), 'right') for x in row ]
+                rows.append( u.Columns(row, 1, None, 4, ) )
+
+        self.contpane.body = u.SimpleListWalker(rows + [u.Text(''), u.Text(''),
+            self.logo])
+
+    def populate_sidebar(self):
+        'Load up the sidebar.'
+        if self.db:
+            treedata = self.list_db_cluster()
+            treedata = { 'name': 'Databases', 'children': treedata }
+        else:
+            log.warn('Database not found, using dummy data.')
+            children = [ dict(name='child%s' % i) for i in range(10) ]
+            children[3]['children'] = [ dict(name='child%s' % i) for i in range(10) ]
+            children[5]['children'] = [ dict(name='child%s' % i) for i in range(10) ]
+            treedata = {    'name': 'Databases (N/A)',
+                        'children': children
+            }
+        return treedata
+
+    def unhandled_input(self, input):
+        '''Handle input that wasn't already.'''
+        if self.db: dbconns = self.db.conns
+        else:       dbconns = ()
+        i = self.body.get_focus_column()
+        log.debug('focus in column:%s' % i)
+
+        if input in ('q', 'Q', 'esc'):
+            if self.db: self.db.close()
+            raise u.ExitMainLoop()
+        elif input in ('tab', 'right'):
+            try:
+                self.body.set_focus(i+1)
+                # keep focus out of legend, doesn't work with right key :/
+                if self.contpane.get_focus()[1] == 0:
+                    self.contpane.set_focus(1, coming_from='above')
+            except AssertionError:  self.body.set_focus(0)
+            log.debug('tab: col %s to %s' % (i, i+1))
+
+        elif i == 0:  # self.sidebar
+            if input == 'enter':
+                self.onActivate(input)
+            elif type(input) is tuple:  # mouse event
+                if input[0] == 'mouse release' and input[1] in (0,1):
+                    self.onActivate(input)
+        else:
+            log.debug(str(input))
+
+
+if __name__=='__main__':
+    parser = OptionParser(usage=__doc__.rstrip(), version=__version__)
+    parser.add_option(
+        '-l', '--logfile', metavar="F", default='log_dumbo.txt',
+        help='Log activity to this file. Default: %default.')
+    parser.add_option('-v', '--verbose', action='store_true',
+        help='Enable verbose output to log.')
+    parser.add_option('-V', '--very-verbose', action='store_true',
+        help='Enable ridiculous amounts of debugging output.')
+
+    (opts, args) = parser.parse_args()
+
+    # set up logging
+    log = logging.getLogger(__name__)
+    log.setLevel( (logging.DEBUG if debug else logging.WARN) )
+    if opts.verbose:        log.setLevel(logging.INFO)
+    if opts.very_verbose:   log.setLevel(logging.DEBUG)
+    hdlr = logging.FileHandler(opts.logfile)
+    hdlr.setFormatter(formatter)
+    log.addHandler(hdlr)
+
+    # get started
+    df = DumboFrame(args)
+
+    loop = u.MainLoop(df, df.palette, unhandled_input=df.unhandled_input)
+    if df.config and df.config[appname]['use256'] and hicolor_avail:
+        loop.screen.set_terminal_properties(colors=256)
+    try:
+        loop.run()
+    except Exception, e:    # try to close all db connections
+        print e.__class__.__name__, e, '.  Check log for details.\n'
+        import traceback
+        log.critical(traceback.format_exc())
+        if df.db:
+            df.db.close()
+

File dumbo/utk.py

View file
+#!/usr/bin/env python
+#
+# Urwid web site: http://excess.org/urwid/
+# Generic TreeWidget/TreeWalker class
+#    Copyright (c) 2010  Rob Lanphier
+# Derived from Urwid example lazy directory browser / tree view:
+#    Copyright (C) 2004-2009  Ian Ward
+#
+#    This library is free software; you can redistribute it and/or
+#    modify it under the terms of the GNU Lesser General Public
+#    License as published by the Free Software Foundation; either
+#    version 2.1 of the License, or (at your option) any later version.
+#
+#    This library is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+#    Lesser General Public License for more details.
+#
+#    You should have received a copy of the GNU Lesser General Public
+#    License along with this library; if not, write to the Free Software
+#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
+#
+'''
+    Urwid tree view
+
+    Features:
+    - custom selectable widgets for trees
+    - custom list walker for displaying widgets in a tree fashion
+'''
+if True:  # foldable init
+    import os
+    import urwid as u
+
+    term = os.environ.get('TERM')
+    if term == 'xterm':
+        unichar_avail = True
+        hicolor_avail = True
+    else:
+        unichar_avail = False
+        hicolor_avail = False
+
+
+class Button(u.WidgetWrap):
+    '''
+        A simple button that can be aligned and fires on mouse release.
+        Subclassing didn't seem to work, so copied here from widget.py
+    '''
+    button_left = u.Text(' ')
+    button_right = u.Text(' ')
+    signals = ['click']
+
+    def __init__(self, label, align='right', on_press=None, user_data=None):
+        self._label = u.SelectableIcon('', 0)
+        self._label.set_align_mode(align)
+        cols = u.Columns([
+            ('weight', 1, self._label),
+            ('fixed', 1, self.button_right)
+            ], 0)
+        self.__super.__init__(cols)
+        if on_press:
+            u.connect_signal(self, 'click', on_press, user_data)
+        self.set_label(label)
+
+    def _repr_words(self):
+        return self.__super._repr_words() + [
+            repr(self.label)]
+
+    def set_label(self, label):
+        self._label.set_text(label)
+
+    def get_label(self):
+        return self._label.text
+    label = property(get_label)
+
+    def keypress(self, size, key):
+        from  command_map import command_map
+        if command_map[key] != 'activate':
+            return key
+
+        self._emit('click')
+
+    def mouse_event(self, size, event, button, x, y, focus):
+        'Click on release, not on press.'
+        if (button == 0) and ('release' in event):
+            self._emit('click')
+            return True
+        else:
+            return False
+
+
+class TreeWidgetError(RuntimeError):
+    pass
+
+
+class TreeWidget(u.WidgetWrap):
+    '''A widget representing something in the file tree.'''
+    def __init__(self, node):
+        self._node = node
+        self._innerwidget = None
+        self.selected = False
+
+        widget = self.get_indented_widget()
+
+        w = u.AttrWrap(widget, None)
+        self.__super.__init__(w)
+        # Compatibility fix for 0.9.9+
+        if not hasattr(self, 'get_w'):
+            self.get_w = self._retro_get_w
+        self.update_w()
+
+    def _retro_get_w(self):
+        '''
+        Implementation of get_w() if the base urwid install doesn't support it.
+        '''
+        return self._w
+
+    def get_indented_widget(self):
+        leftmargin = u.Text('')
+        widgetlist = [self.get_inner_widget()]
+        indent_cols = self.get_indent_cols()
+        if indent_cols > 0:
+            widgetlist.insert(0, ('fixed', indent_cols, leftmargin))
+        return u.Columns(widgetlist)
+
+    def get_indent_cols(self):
+        return 3 * self.get_node().get_depth()
+
+    def get_inner_widget(self):
+        if self._innerwidget is None:
+            self._innerwidget = self.load_inner_widget()
+        return self._innerwidget
+
+    def load_inner_widget(self):
+        return u.Text(self.get_display_text())
+
+    def get_node(self):
+        return self._node
+
+    def get_display_text(self):
+        return (self.get_node().get_key() + ': ' +
+                str(self.get_node().get_value()))
+
+    def selectable(self):
+        return True
+
+    def is_selected(self):
+        return self.selected
+
+    def set_selected(self, value=True):
+        self.selected = value
+
+    def keypress(self, size, key):
+        '''allow subclasses to intercept keystrokes'''
+        w = self.get_w()
+        try:
+            key = w.keypress(size, key)
+        except AttributeError:
+            # no biggie...we'll just handle the keypress here
+            pass
+        key = self.unhandled_keys(size, key)
+        return key
+
+    def unhandled_keys(self, size, key):
+        '''
+        Override this method to intercept keystrokes in subclasses.
+        Default behavior: Toggle selected on space, ignore other keys.
+        '''
+        if key == ' ':
+            self.selected = not self.selected
+            self.update_w()
+        else:
+            return key
+
+    def update_w(self):
+        '''Update the attributes of self.widget based on self.selected.
+        '''
+        if self.selected:
+            self.get_w().attr = 'selected'
+            self.get_w().focus_attr = 'selected focus'
+        else:
+            self.get_w().attr = 'body'
+            self.get_w().focus_attr = 'focus'
+
+    def next_inorder(self):
+        '''Return the next TreeWidget depth first from this one.'''
+        # first check if there's a child widget
+        firstchild = self.first_child()
+        if firstchild is not None:
+            return firstchild
+
+        # now we need to hunt for the next sibling
+        thisnode = self.get_node()
+        nextnode = thisnode.next_sibling()
+        depth = thisnode.get_depth()
+        while nextnode is None and depth > 0:
+            # keep going up the tree until we find an ancestor next sibling
+            thisnode = thisnode.get_parent()
+            nextnode = thisnode.next_sibling()
+            depth -= 1
+            assert depth == thisnode.get_depth()
+        if nextnode is None:
+            # we're at the end of the tree
+            return None
+        else:
+            return nextnode.get_widget()
+
+    def prev_inorder(self):
+        '''Return the previous TreeWidget depth first from this one.'''
+        thisnode = self._node
+        prevnode = thisnode.prev_sibling()
+        if prevnode is not None:
+            # we need to find the last child of the previous widget if its
+            # expanded
+            prevwidget = prevnode.get_widget()
+            lastchild = prevwidget.last_child()
+            if lastchild is None:
+                return prevwidget
+            else:
+                return lastchild
+        else:
+            # need to hunt for the parent
+            depth = thisnode.get_depth()
+            if prevnode is None and depth == 0:
+                return None
+            elif prevnode is None:
+                prevnode = thisnode.get_parent()
+            return prevnode.get_widget()
+
+    def first_child(self):
+        '''Default to have no children.'''
+        return None
+
+    def last_child(self):
+        '''Default to have no children.'''
+        return None
+
+
+class ParentWidget(TreeWidget):
+    '''Widget for an interior tree node.'''
+
+    def __init__(self, node, expanded=True):
+        self.__super.__init__(node)
+        self.expanded = expanded
+
+        self.update_widget()
+
+    def update_widget(self, focused=False):
+        '''Update display widget text.'''
+
+        if self.expanded:
+            if unichar_avail:   mark = u'\u25BC'
+            else:               mark = '-'
+        else:
+            if unichar_avail:   mark = u'\u25B6'
+            else:               mark = '+'
+
+        self._innerwidget.set_text(
+            [mark, ' ', self.get_display_text()] )
+
+    def keypress(self, size, key):
+        '''Handle expand & collapse requests.'''
+        if key in ('+', 'right'):
+            self.expanded = True
+            self.update_widget()
+        elif key in ('-', 'left'):
+            self.expanded = False
+            self.update_widget()
+        else:
+            self.update_widget()
+            return self.__super.keypress(size, key)
+
+    def mouse_event(self, size, event, button, col, row, focus):
+        if event != 'mouse press' or button!=1:
+            return False
+
+        if row == 0 and col == self.get_indent_cols():
+            self.expanded = not self.expanded
+            self.update_widget()
+            return True
+
+        return False
+
+    def first_child(self):
+        '''Return first child if expanded.'''
+        if not self.expanded:
+            return None
+        else:
+            if self._node.has_children():
+                firstnode = self._node.get_first_child()
+                return firstnode.get_widget()
+            else:
+                return None
+
+    def last_child(self):
+        '''Return last child if expanded.'''
+        if not self.expanded:
+            return None
+        else:
+            if self._node.has_children():
+                lastchild = self._node.get_last_child().get_widget()
+            else:
+                return None
+            # recursively search down for the last descendant
+            lastdescendant = lastchild.last_child()
+            if lastdescendant is None:
+                return lastchild
+            else:
+                return lastdescendant
+
+
+class TreeNode(object):
+    '''
+    Store tree contents and cache TreeWidget objects.
+    A TreeNode consists of the following elements:
+    *  key: accessor token for parent nodes
+    *  value: subclass-specific data
+    *  parent: a TreeNode which contains a pointer back to this object
+    *  widget: The widget used to render the object
+    '''
+    def __init__(self, value, parent=None, key=None, depth=None):
+        self._key = key
+        self._parent = parent
+        self._value = value
+        self._depth = depth
+        self._widget = None
+
+    def get_widget(self, reload=False):
+        ''' Return the widget for this node.'''
+        if self._widget is None or reload == True:
+            self._widget = self.load_widget()
+        return self._widget
+
+    def load_widget(self):
+        return TreeWidget(self)
+
+    def get_depth(self):
+        if self._depth is None and self._parent is None:
+            self._depth = 0
+        elif self._depth is None:
+            self._depth = self._parent.get_depth() + 1
+        return self._depth
+
+    def get_index(self):
+        if self.get_depth() == 0:
+            return None
+        else:
+            key = self.get_key()
+            parent = self.get_parent()
+            return parent.get_child_index(key)
+
+    def get_key(self):
+        return self._key
+
+    def set_key(self, key):
+        self._key = key
+
+    def change_key(self, key):
+        self.get_parent().change_child_key(self._key, key)
+
+    def get_parent(self):
+        if self._parent == None and self.get_depth() > 0:
+            self._parent = self.load_parent()
+        return self._parent
+
+    def load_parent(self):
+        '''Provide TreeNode with a parent for the current node.  This function
+        is only required if the tree was instantiated from a child node
+        (virtual function)'''
+        raise TreeWidgetError('virtual function.  Implement in subclass')
+
+    def get_value(self):
+        return self._value
+
+    def is_root(self):
+        return self.get_depth() == 0
+
+    def next_sibling(self):
+        if self.get_depth() > 0:
+            return self.get_parent().next_child(self.get_key())
+        else:
+            return None
+
+    def prev_sibling(self):
+        if self.get_depth() > 0:
+            return self.get_parent().prev_child(self.get_key())
+        else:
+            return None
+
+    def get_root(self):
+        root = self
+        while root.get_parent() is not None:
+            root = root.get_parent()
+        return root
+
+
+class ParentNode(TreeNode):
+    '''Maintain sort order for TreeNodes.'''
+    def __init__(self, value, parent=None, key=None, depth=None):
+        TreeNode.__init__(self, value, parent=parent, key=key, depth=depth)
+
+        self._child_keys = None
+        self._children = {}
+
+    def load_widget(self):
+        return ParentWidget(self)
+
+    def get_child_keys(self, reload=False):
+        '''Return a possibly ordered list of child keys'''
+        if self._child_keys is None or reload == True:
+            self._child_keys = self.load_child_keys()
+        return self._child_keys
+
+    def load_child_keys(self):
+        '''Provide ParentNode with an ordered list of child keys (virtual
+        function)'''
+        raise TreeWidgetError('virtual function.  Implement in subclass')
+
+    def get_child_widget(self, key):
+        '''Return the widget for a given key.  Create if necessary.'''
+
+        child = self.get_child_node(key)
+        return child.get_widget()
+
+    def get_child_node(self, key, reload=False):
+        '''Return the child node for a given key.  Create if necessary.'''
+        if key not in self._children or reload == True:
+            self._children[key] = self.load_child_node(key)
+        return self._children[key]
+
+    def load_child_node(self, key):
+        '''Load the child node for a given key (virtual function)'''
+        raise TreeWidgetError('virtual function.  Implement in subclass')
+
+    def set_child_node(self, key, node):
+        '''Set the child node for a given key.  Useful for bottom-up, lazy
+        population of a tree..'''
+        self._children[key]=node
+
+    def change_child_key(self, oldkey, newkey):
+        if newkey in self._children:
+            raise TreeWidgetError('%s is already in use' % newkey)
+        self._children[newkey] = self._children.pop(oldkey)
+        self._children[newkey].set_key(newkey)
+
+    def get_child_index(self, key):
+        try:
+            return self.get_child_keys().index(key)
+        except ValueError:
+            errorstring = ('Can\'t find key %s in ParentNode %s\n' +
+                           'ParentNode items: %s')
+            raise TreeWidgetError(errorstring % (key, self.get_key(),
+                                  str(self.get_child_keys())))
+
+    def next_child(self, key):
+        '''Return the next child node in index order from the given key.'''
+
+        index = self.get_child_index(key)
+        # the given node may have just been deleted
+        if index is None:
+            return None
+        index += 1
+
+        child_keys = self.get_child_keys()
+        if index < len(child_keys):
+            # get the next item at same level
+            return self.get_child_node(child_keys[index])
+        else:
+            return None
+
+    def prev_child(self, key):
+        '''Return the previous child node in index order from the given key.'''
+        index = self.get_child_index(key)
+        if index is None:
+            return None
+
+        child_keys = self.get_child_keys()
+        index -= 1
+
+        if index >= 0:
+            # get the previous item at same level
+            return self.get_child_node(child_keys[index])
+        else:
+            return None
+
+    def get_first_child(self):
+        '''Return the first TreeNode in the directory.'''
+        child_keys = self.get_child_keys()
+        return self.get_child_node(child_keys[0])
+
+    def get_last_child(self):
+        '''Return the last TreeNode in the directory.'''
+        child_keys = self.get_child_keys()
+        return self.get_child_node(child_keys[-1])
+
+    def has_children(self):
+        '''Does this node have any children?'''
+        return len(self.get_child_keys())>0
+
+
+class TreeWalker(u.ListWalker):
+    '''ListWalker-compatible class for browsing directories.
+
+    positions are TreeNodes.'''
+
+    def __init__(self, start_from):
+        '''start_from: TreeNode with the initial focus.'''
+        self.focus = start_from
+
+    def get_focus(self):
+        widget = self.focus.get_widget()
+        if hasattr(widget, 'update_widget'):  widget.update_widget(True)
+        return widget, self.focus
+
+    def set_focus(self, focus):
+        self.focus = focus
+        self._modified()
+
+    def get_next(self, start_from):
+        widget = start_from.get_widget()
+        target = widget.next_inorder()
+        if target is None:
+            return None, None
+        else:
+            return target, target.get_node()
+
+    def get_prev(self, start_from):
+        widget = start_from.get_widget()
+        target = widget.prev_inorder()
+        if target is None:
+            return None, None
+        else:
+            return target, target.get_node()
+
+
+class TreeListBox(u.ListBox):
+    def keypress(self, size, key):
+        key = self.__super.keypress(size, key)
+        return self.unhandled_input(size, key)
+
+    def unhandled_input(self, size, input):
+        '''Handle macro-navigation keys'''
+        if input == 'up':
+            self.move_focus_to_parent(size)
+        elif input == 'left':
+            self.collapse_focus_parent(size)
+        elif input == '-':
+            self.collapse_focus_parent(size)
+        elif input == 'home':
+            self.focus_home(size)
+        elif input == 'end':
+            self.focus_end(size)
+        else:
+            return input
+
+    def collapse_focus_parent(self, size):
+        '''Collapse parent directory.'''
+
+        widget, pos = self.body.get_focus()
+        self.move_focus_to_parent(size)
+
+        pwidget, ppos = self.body.get_focus()
+        if pos != ppos:
+            self.keypress(size, '-')
+
+    def move_focus_to_parent(self, size):
+        '''Move focus to parent of widget in focus.'''
+
+        widget, pos = self.body.get_focus()
+
+        parentpos = pos.get_parent()
+
+        if parentpos is None:
+            return
+
+        middle, top, bottom = self.calculate_visible( size )
+
+        row_offset, focus_widget, focus_pos, focus_rows, cursor = middle
+        trim_top, fill_above = top
+
+        for widget, pos, rows in fill_above:
+            row_offset -= rows
+            if pos == parentpos:
+                self.change_focus(size, pos, row_offset)
+                return
+
+        self.change_focus(size, pos.get_parent())
+
+    def focus_home(self, size):
+        '''Move focus to very top.'''
+
+        widget, pos = self.body.get_focus()
+        rootnode = pos.get_root()
+        self.change_focus(size, rootnode)
+
+    def focus_end( self, size ):
+        '''Move focus to far bottom.'''
+
+        maxrow, maxcol = size
+        widget, pos = self.body.get_focus()
+        rootnode = pos.get_root()
+        rootwidget = rootnode.get_widget()
+        lastwidget = rootwidget.last_child()
+        lastnode = lastwidget.get_node()
+
+        self.change_focus(size, lastnode, maxrow-1)

File utk.py

-#!/usr/bin/env python
-#
-# Urwid web site: http://excess.org/urwid/
-# Generic TreeWidget/TreeWalker class
-#    Copyright (c) 2010  Rob Lanphier
-# Derived from Urwid example lazy directory browser / tree view:
-#    Copyright (C) 2004-2009  Ian Ward
-#
-#    This library is free software; you can redistribute it and/or
-#    modify it under the terms of the GNU Lesser General Public
-#    License as published by the Free Software Foundation; either
-#    version 2.1 of the License, or (at your option) any later version.
-#
-#    This library is distributed in the hope that it will be useful,
-#    but WITHOUT ANY WARRANTY; without even the implied warranty of
-#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
-#    Lesser General Public License for more details.
-#
-#    You should have received a copy of the GNU Lesser General Public
-#    License along with this library; if not, write to the Free Software
-#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
-#
-'''
-    Urwid tree view
-
-    Features:
-    - custom selectable widgets for trees
-    - custom list walker for displaying widgets in a tree fashion
-'''
-if True:  # foldable init
-    import os
-    import urwid as u
-
-    term = os.environ.get('TERM')
-    if term == 'xterm':
-        unichar_avail = True
-        hicolor_avail = True
-    else:
-        unichar_avail = False
-        hicolor_avail = False
-
-
-class Button(u.WidgetWrap):
-    '''
-        A simple button that can be aligned and fires on mouse release.
-        Subclassing didn't seem to work, so copied here from widget.py
-    '''
-    button_left = u.Text(' ')
-    button_right = u.Text(' ')
-    signals = ['click']
-
-    def __init__(self, label, align='right', on_press=None, user_data=None):
-        self._label = u.SelectableIcon('', 0)
-        self._label.set_align_mode(align)
-        cols = u.Columns([
-            ('weight', 1, self._label),
-            ('fixed', 1, self.button_right)
-            ], 0)
-        self.__super.__init__(cols)
-        if on_press:
-            u.connect_signal(self, 'click', on_press, user_data)
-        self.set_label(label)
-
-    def _repr_words(self):
-        return self.__super._repr_words() + [
-            repr(self.label)]
-
-    def set_label(self, label):
-        self._label.set_text(label)
-
-    def get_label(self):
-        return self._label.text
-    label = property(get_label)
-
-    def keypress(self, size, key):
-        from  command_map import command_map
-        if command_map[key] != 'activate':
-            return key
-
-        self._emit('click')
-
-    def mouse_event(self, size, event, button, x, y, focus):
-        'Click on release, not on press.'
-        if (button == 0) and ('release' in event):
-            self._emit('click')
-            return True
-        else:
-            return False
-
-
-class TreeWidgetError(RuntimeError):
-    pass
-
-
-class TreeWidget(u.WidgetWrap):
-    '''A widget representing something in the file tree.'''
-    def __init__(self, node):
-        self._node = node
-        self._innerwidget = None
-        self.selected = False
-
-        widget = self.get_indented_widget()
-
-        w = u.AttrWrap(widget, None)
-        self.__super.__init__(w)
-        # Compatibility fix for 0.9.9+
-        if not hasattr(self, 'get_w'):
-            self.get_w = self._retro_get_w
-        self.update_w()
-
-    def _retro_get_w(self):
-        '''
-        Implementation of get_w() if the base urwid install doesn't support it.
-        '''
-        return self._w
-
-    def get_indented_widget(self):
-        leftmargin = u.Text('')
-        widgetlist = [self.get_inner_widget()]
-        indent_cols = self.get_indent_cols()
-        if indent_cols > 0:
-            widgetlist.insert(0, ('fixed', indent_cols, leftmargin))
-        return u.Columns(widgetlist)
-
-    def get_indent_cols(self):
-        return 3 * self.get_node().get_depth()
-
-    def get_inner_widget(self):
-        if self._innerwidget is None:
-            self._innerwidget = self.load_inner_widget()
-        return self._innerwidget
-
-    def load_inner_widget(self):
-        return u.Text(self.get_display_text())
-
-    def get_node(self):
-        return self._node
-
-    def get_display_text(self):
-        return (self.get_node().get_key() + ': ' +
-                str(self.get_node().get_value()))
-
-    def selectable(self):
-        return True
-
-    def is_selected(self):
-        return self.selected
-
-    def set_selected(self, value=True):
-        self.selected = value
-
-    def keypress(self, size, key):
-        '''allow subclasses to intercept keystrokes'''
-        w = self.get_w()
-        try:
-            key = w.keypress(size, key)
-        except AttributeError:
-            # no biggie...we'll just handle the keypress here
-            pass
-        key = self.unhandled_keys(size, key)
-        return key
-
-    def unhandled_keys(self, size, key):
-        '''
-        Override this method to intercept keystrokes in subclasses.
-        Default behavior: Toggle selected on space, ignore other keys.
-        '''
-        if key == ' ':
-            self.selected = not self.selected
-            self.update_w()
-        else:
-            return key
-
-    def update_w(self):
-        '''Update the attributes of self.widget based on self.selected.
-        '''
-        if self.selected:
-            self.get_w().attr = 'selected'
-            self.get_w().focus_attr = 'selected focus'
-        else:
-            self.get_w().attr = 'body'
-            self.get_w().focus_attr = 'focus'
-
-    def next_inorder(self):
-        '''Return the next TreeWidget depth first from this one.'''
-        # first check if there's a child widget
-        firstchild = self.first_child()
-        if firstchild is not None:
-            return firstchild
-
-        # now we need to hunt for the next sibling
-        thisnode = self.get_node()
-        nextnode = thisnode.next_sibling()
-        depth = thisnode.get_depth()
-        while nextnode is None and depth > 0:
-            # keep going up the tree until we find an ancestor next sibling
-            thisnode = thisnode.get_parent()
-            nextnode = thisnode.next_sibling()
-            depth -= 1
-            assert depth == thisnode.get_depth()
-        if nextnode is None:
-            # we're at the end of the tree
-            return None
-        else:
-            return nextnode.get_widget()
-
-    def prev_inorder(self):
-        '''Return the previous TreeWidget depth first from this one.'''
-        thisnode = self._node
-        prevnode = thisnode.prev_sibling()
-        if prevnode is not None:
-            # we need to find the last child of the previous widget if its
-            # expanded
-            prevwidget = prevnode.get_widget()
-            lastchild = prevwidget.last_child()
-            if lastchild is None:
-                return prevwidget
-            else:
-                return lastchild
-        else:
-            # need to hunt for the parent
-            depth = thisnode.get_depth()
-            if prevnode is None and depth == 0:
-                return None
-            elif prevnode is None:
-                prevnode = thisnode.get_parent()
-            return prevnode.get_widget()
-
-    def first_child(self):
-        '''Default to have no children.'''
-        return None
-
-    def last_child(self):
-        '''Default to have no children.'''
-        return None
-
-
-class ParentWidget(TreeWidget):
-    '''Widget for an interior tree node.'''
-
-    def __init__(self, node, expanded=True):
-        self.__super.__init__(node)
-        self.expanded = expanded
-
-        self.update_widget()
-
-    def update_widget(self, focused=False):
-        '''Update display widget text.'''
-
-        if self.expanded:
-            if unichar_avail:   mark = u'\u25BC'
-            else:               mark = '-'
-        else:
-            if unichar_avail:   mark = u'\u25B6'
-            else:               mark = '+'
-
-        self._innerwidget.set_text(
-            [mark, ' ', self.get_display_text()] )
-
-    def keypress(self, size, key):
-        '''Handle expand & collapse requests.'''
-        if key in ('+', 'right'):
-            self.expanded = True
-            self.update_widget()
-        elif key in ('-', 'left'):
-            self.expanded = False
-            self.update_widget()
-        else:
-            self.update_widget()
-            return self.__super.keypress(size, key)
-
-    def mouse_event(self, size, event, button, col, row, focus):
-        if event != 'mouse press' or button!=1:
-            return False
-
-        if row == 0 and col == self.get_indent_cols():
-            self.expanded = not self.expanded
-            self.update_widget()
-            return True
-
-        return False
-
-    def first_child(self):
-        '''Return first child if expanded.'''
-        if not self.expanded:
-            return None
-        else:
-            if self._node.has_children():
-                firstnode = self._node.get_first_child()
-                return firstnode.get_widget()
-            else:
-                return None
-
-    def last_child(self):
-        '''Return last child if expanded.'''
-        if not self.expanded:
-            return None
-        else:
-            if self._node.has_children():
-                lastchild = self._node.get_last_child().get_widget()
-            else:
-                return None
-            # recursively search down for the last descendant
-            lastdescendant = lastchild.last_child()
-            if lastdescendant is None:
-                return lastchild
-            else:
-                return lastdescendant
-
-
-class TreeNode(object):
-    '''
-    Store tree contents and cache TreeWidget objects.
-    A TreeNode consists of the following elements:
-    *  key: accessor token for parent nodes
-    *  value: subclass-specific data
-    *  parent: a TreeNode which contains a pointer back to this object
-    *  widget: The widget used to render the object
-    '''
-    def __init__(self, value, parent=None, key=None, depth=None):
-        self._key = key
-        self._parent = parent
-        self._value = value
-        self._depth = depth
-        self._widget = None
-
-    def get_widget(self, reload=False):
-        ''' Return the widget for this node.'''
-        if self._widget is None or reload == True:
-            self._widget = self.load_widget()
-        return self._widget
-
-    def load_widget(self):
-        return TreeWidget(self)
-
-    def get_depth(self):
-        if self._depth is None and self._parent is None:
-            self._depth = 0
-        elif self._depth is None:
-            self._depth = self._parent.get_depth() + 1
-        return self._depth
-
-    def get_index(self):
-        if self.get_depth() == 0:
-            return None
-        else:
-            key = self.get_key()
-            parent = self.get_parent()
-            return parent.get_child_index(key)
-
-    def get_key(self):
-        return self._key
-
-    def set_key(self, key):
-        self._key = key
-
-    def change_key(self, key):
-        self.get_parent().change_child_key(self._key, key)
-
-    def get_parent(self):
-        if self._parent == None and self.get_depth() > 0:
-            self._parent = self.load_parent()
-        return self._parent
-
-    def load_parent(self):
-        '''Provide TreeNode with a parent for the current node.  This function
-        is only required if the tree was instantiated from a child node
-        (virtual function)'''
-        raise TreeWidgetError('virtual function.  Implement in subclass')
-
-    def get_value(self):
-        return self._value
-
-    def is_root(self):
-        return self.get_depth() == 0
-
-    def next_sibling(self):
-        if self.get_depth() > 0:
-            return self.get_parent().next_child(self.get_key())
-        else:
-            return None
-