Source

mdhub / mdhub / location.py

# -*- coding: utf-8 -*-

"""
Note to self: Read again:
The Python yield keyword explained
http://stackoverflow.com/q/231767/89391
"""

import pymarc
import tarfile
import zipfile
import os
import hashlib
import glob
import logging

# Root logging is configured at import time: INFO and above, each line
# prefixed with a short timestamp, the logger name and the level.
logging.basicConfig(
    format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
    datefmt='%m-%d %H:%M',
    level=logging.INFO,
)

# Logger shared by every function in this module.
log = logging.getLogger('mdhub.location')

# TODO: move these into some config file as well
# Per-user working directory and the scratch area below it; the scratch area
# is where ``get_cache_dir`` creates its per-archive directories.
MDHUB_DIR = os.path.expanduser('~/.mdhub')
MDHUB_TEMP = os.path.expanduser(os.path.join(MDHUB_DIR, 'tmp'))

def get_cache_dir(filename):
    """
    Return the cache directory for ``filename``, creating it if necessary.

    The directory lives below ``MDHUB_TEMP`` and is named after the SHA-1
    hex digest of ``filename`` (the name itself, not the file's content), so
    the same name always maps to the same directory.

    ``filename`` may be a byte string or a text string; text is encoded as
    UTF-8 before hashing, because the hash only accepts bytes.

    Raises ``OSError`` if the directory does not exist and cannot be created.
    """
    if not isinstance(filename, bytes):
        filename = filename.encode('utf-8')
    cache_dir = os.path.join(MDHUB_TEMP, hashlib.sha1(filename).hexdigest())
    try:
        os.makedirs(cache_dir)
    except OSError:
        # Either the directory is already there (possibly created by a
        # concurrent caller between a check and the mkdir), which is fine,
        # or creation really failed, which the caller needs to know about.
        if not os.path.isdir(cache_dir):
            raise
    return cache_dir

class EmptyInterator(object):
    """
    An iterator that is exhausted from the start.

    Iterating over an instance produces no items at all. It is returned in
    place of a real record reader when there is nothing usable to read.
    """
    def __iter__(self):
        """
        Iterator protocol: the object is its own iterator.
        """
        return self

    def next(self):
        """
        Signal exhaustion immediately (Python 2 spelling of the protocol).
        """
        raise StopIteration

    # Python 3 looks for ``__next__`` rather than ``next``; without this
    # alias, ``iter()`` rejects the object as a non-iterator there.
    __next__ = next

def marc_iterator(filename):
    """
    Return an iterator over the MARC records stored in ``filename``.

    A non-empty file is probed first by reading one record with a
    throw-away reader. If that raises ``ValueError`` the file is treated as
    "not MARC": the problem is logged as an error and an ``EmptyInterator``
    is returned, so the caller simply sees no records.

    Otherwise a fresh ``pymarc.MARCReader`` positioned at the start of the
    file is returned. The reader is handed its own open file object, which
    stays open for as long as the reader is in use.
    """
    log.debug("==> MARC iterator")
    # An empty file needs no probing; the reader below just yields nothing.
    if os.path.getsize(filename) > 0:
        # The probe gets its own handle, closed as soon as the check is
        # done. MARC is a binary format, hence the 'rb' mode.
        with open(filename, 'rb') as probe:
            try:
                next(pymarc.MARCReader(probe))
            except ValueError as value_error:
                log.error("Not in MARC format: {0} (Error was: {1})".format(
                    filename, value_error))
                # force return an empty iterator
                return EmptyInterator()
    return pymarc.MARCReader(open(filename, 'rb'),
                             to_unicode=True, force_utf8=True)

def zip_iterator(filename):
    """
    Iterate over the contents of a zip archive.

    Every member is extracted into the archive's cache directory (see:
    ``get_cache_dir``); then each entry found directly in that directory is
    passed to ``record_iterator`` and whatever it yields is passed on.

    The archive is closed as soon as extraction has finished, so no handle
    on it is held while the consumer works through the results.

    NOTE(review): the cache directory is keyed on the archive's name and is
    not emptied here, so files left by an earlier extraction under the same
    name would be picked up as well -- confirm this is intended.
    """
    log.debug("==> zip iterator")
    with zipfile.ZipFile(filename, 'r') as archive:
        member_cache_dir = get_cache_dir(filename)
        archive.extractall(member_cache_dir)
    # From here on only the extracted files are needed.
    for fname in os.listdir(member_cache_dir):
        fpath = os.path.join(member_cache_dir, fname)
        for iterator in record_iterator(fpath):
            yield iterator

def _safe_tar_members(archive, target_dir):
    """
    Yield the members of ``archive`` that are safe to extract below
    ``target_dir``: regular files and directories whose resolved
    destination stays inside ``target_dir``. Everything else (absolute
    paths, ``..`` escapes, links, device nodes) is skipped with a warning.
    """
    root = os.path.realpath(target_dir)
    for member in archive.getmembers():
        destination = os.path.realpath(os.path.join(root, member.name))
        inside = (destination == root or
                  destination.startswith(root + os.sep))
        if inside and (member.isfile() or member.isdir()):
            yield member
        else:
            log.warning("Skipping unsafe tar member: {0}".format(member.name))

def tar_iterator(filename, compression):
    """
    Iterate over tar.(gz|bz2) file. ``compression`` is the tarfile mode
    suffix, i.e. ``'gz'`` or ``'bz2'``.

    The archive's files are extracted into a temporary location (see:
    ``get_cache_dir``); then every entry found directly in that location is
    passed to ``record_iterator`` and whatever it yields is passed on.

    Only members that cannot end up outside the temporary location are
    extracted (see: ``_safe_tar_members``), because ``extractall`` on its
    own trusts the paths stored in the archive.
    """
    log.debug("==> tar iterator with compression: {0}".format(compression))
    member_cache_dir = get_cache_dir(filename)
    with tarfile.open(filename, 'r:{0}'.format(compression)) as archive:
        archive.extractall(
            member_cache_dir,
            members=_safe_tar_members(archive, member_cache_dir))
    for fname in os.listdir(member_cache_dir):
        fpath = os.path.join(member_cache_dir, fname)
        for iterator in record_iterator(fpath):
            yield iterator

def record_iterator(location):
    """
    Expand a (glob-like) location and yield one record iterator per MARC
    file found, i.e. the items produced here are themselves iterables of
    records, not single records.

    ``~`` in ``location`` is expanded before globbing. Matches are handled
    by file name suffix:

    * ``.mrc`` -- yields the result of ``marc_iterator`` for that file
    * ``.tar.gz`` / ``.tar.bz2`` / ``.zip`` -- the archive is handed to
      ``tar_iterator`` / ``zip_iterator`` and everything found inside is
      yielded; since those call back into this function, archives may be
      nested (e.g. a zip of tar.gz's)
    * ``.csv`` / ``.xml`` -- raises ``NotImplementedError``
    * anything else, directories included, is skipped
    """
    log.debug("Record iterator on location: {0}".format(location))
    matches = glob.glob(os.path.expanduser(location))
    match_count = len(matches)
    for position, path in enumerate(matches, 1):
        log.debug("==> Processing file {0}/{1}: {2}".format(
            position, match_count, path))
        if path.endswith('.mrc'):
            yield marc_iterator(path)
        elif path.endswith(('.csv', '.xml')):
            # Recognised formats without a reader yet.
            raise NotImplementedError
        else:
            # Archives: choose the matching unpacker, then relay its output.
            if path.endswith('.tar.gz'):
                nested = tar_iterator(path, 'gz')
            elif path.endswith('.tar.bz2'):
                nested = tar_iterator(path, 'bz2')
            elif path.endswith('.zip'):
                nested = zip_iterator(path)
            else:
                # Unknown suffix: nothing to do for this match.
                continue
            for entry in nested:
                yield entry

if __name__ == '__main__':
    # Ad-hoc smoke run against a local directory that mixes tar.gz's and
    # plain marc files: count the records of every iterator that comes back.
    # NOTE(review): results are reported via log.debug, which the INFO level
    # configured at the top of this module presumably filters out -- confirm.
    sample_location = '~/bitbucket/miku/mdhub/mdhub/tests/data/recursive_marc/*'
    counts = []
    for ri in record_iterator(sample_location):
        # Drain the iterator; only the number of records matters here.
        records = sum(1 for _ in ri)
        counts.append(records)
        log.debug('==> {0} items ({1})'.format(records, ri))
    log.debug("Total records: {0} {1}".format(sum(counts), counts))