Commits

Chad Dombrova committed 1a00ca0

Initial implementation of a worker pool for monitoring a queue using the STOMP protocol.

  • Parent commits eb6cc73
  • Branches external_triggers


Files changed (1)

File denormalize/integration/stomp.py

+from __future__ import absolute_import
+import inspect
+import time
+import traceback
+from pydoc import locate
+import multiprocessing as mp
+from xml.etree import ElementTree as ET
+import stomp
+import stomp.exception
+
+from ..orms.base import CollectionListener
+
+import logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+def _indent(elem, level=0):
+    i = "\n" + level*"  "
+    if len(elem):
+        if not elem.text or not elem.text.strip():
+            elem.text = i + "  "
+        if not elem.tail or not elem.tail.strip():
+            elem.tail = i
+        for child in elem:
+            _indent(child, level+1)
+        # 'child' is the last child here; fix up its tail
+        if not child.tail or not child.tail.strip():
+            child.tail = i
+    else:
+        if level and (not elem.tail or not elem.tail.strip()):
+            elem.tail = i
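+
+# Usage sketch (hypothetical input): _indent() mutates an ElementTree
+# element in place so that ET.tostring() yields indented, readable XML:
+#
+#     change = ET.fromstring('<changes count="1"><change type="U"/></changes>')
+#     _indent(change)
+#     logger.debug(ET.tostring(change))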
+
+
+class StompChangeObserver(stomp.ConnectionListener):
+    MESSAGE_ID_KEY = 'message-id'
+
+    def __init__(self, queue_name, host, port, username=None, password=None,
+                 lock=None):
+        if not queue_name:
+            raise ValueError('queue_name is required')
+        if not host:
+            raise ValueError('host is required')
+        if port is None or port <= 0:
+            raise ValueError('port >0 is required')
+        if (not username) and password:
+            raise ValueError('username is required when password is specified')
+
+        # always keep a reference (possibly None): close() checks self._lock
+        self._lock = lock
+
+        self.queue_name = queue_name
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+
+        self._on_message = None
+        self._subscription_id = 1
+        self._conn = None
+        self._subscribed = False
+
+    # -- context manager methods
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self.close()
+
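+    # Intended usage sketch (host and queue name are placeholders):
+    #
+    #     observer = StompChangeObserver('queue.TLOG', 'localhost', 61613)
+    #     observer.set_message_listener(handler)
+    #     with observer:      # connect() on enter, close() on exit
+    #         ...             # do work until the observer is closed
+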
+    # -- ConnectionListener callbacks:
+
+    def on_connected(self, headers, body):
+        logger.info('Connected!')
+
+    def on_disconnected(self):
+        self.close()
+
+    def on_message(self, headers, message):
+        message_id = headers[self.MESSAGE_ID_KEY]
+        if self._on_message is None:
+            # Message arrived before set_message_listener() called
+            self._nack(message_id)
+        else:
+            try:
+                if self._on_message(self, headers, message):
+                    self._ack(message_id)
+                else:
+                    self._nack(message_id)
+            except:
+                # Error occurred processing message
+                self._nack(message_id)
+                raise
+
+    def on_error(self, headers, item):
+        raise RuntimeError('queue error: %s' % item)
+
+    # --
+
+    def set_message_listener(self, fn):
+        if not callable(fn):
+            raise ValueError('fn must be callable')
+
+        self._on_message = fn
+
+    def connect(self):
+        if self.is_connected():
+            logger.error('connect() called when already connected')
+            raise ValueError('already connected')
+
+        logger.info('Connecting to queue %s:%s/%s' % (self.host, self.port,
+                                                      self.queue_name))
+        self._conn = stomp.connect.StompConnection11([(self.host, self.port)])
+        # register the listener before connecting so that on_connected() and
+        # any early messages are not missed
+        self._conn.set_listener('', self)
+        self._conn.start()
+        # wait for the CONNECTED frame before subscribing
+        self._conn.connect(self.username, self.password, wait=True)
+        self._conn.subscribe(destination=self.queue_name,
+                             id=self._subscription_id,
+                             ack='client')
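+        # ack='client' defers acknowledgement to on_message(); note that a
+        # STOMP 'client'-mode ACK is cumulative for the subscription
+        # ('client-individual' would acknowledge messages one at a time), and
+        # un-ACKed messages are redelivered by the broker after a disconnect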
+
+        self._subscribed = True
+        logger.info('Observer started!')
+
+    def is_connected(self):
+        return self._conn is not None
+
+    def close(self):
+        logger.info('Disconnecting...')
+
+        if self._subscribed:
+            try:
+                self._conn.unsubscribe(self._subscription_id)
+            except Exception:
+                pass  # best-effort: the broker may already be gone
+            self._subscribed = False
+
+        if self._conn is not None:
+            try:
+                self._conn.disconnect()
+            except Exception:
+                pass  # best-effort: the connection may already be dead
+            self._conn = None
+
+        if self._lock is not None:
+            self._lock.release()
+            self._lock = None
+
+    def _ack(self, message_id):
+        if message_id is None:
+            raise ValueError('no %s' % self.MESSAGE_ID_KEY)
+
+        # stomp.py's built-in ack() has proven unreliable here, so as a
+        # workaround send the ACK frame manually:
+        self._send_frame('ACK', {
+            'subscription': str(self._subscription_id),
+            self.MESSAGE_ID_KEY: message_id
+        })
+
+    def _nack(self, message_id):
+        if message_id is None:
+            raise ValueError('no %s' % self.MESSAGE_ID_KEY)
+
+        # ActiveMQ 5.6+ (using STOMP 1.1+) is required for NACK support
+        self._send_frame('NACK', {
+            'subscription': str(self._subscription_id),
+            self.MESSAGE_ID_KEY: message_id
+        })
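+
+        # for reference, the frame that reaches the broker looks roughly
+        # like this (headers, blank line, body, then a NUL terminator):
+        #
+        #     NACK
+        #     subscription:1
+        #     message-id:<id from the MESSAGE frame headers>
+        #
+        #     ^@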
+
+    def _send_frame(self, cmd, headers, body=''):
+        if not cmd:
+            raise ValueError('command is required')
+        if not self.is_connected():
+            raise IOError('cannot _send_frame() when not connected to queue')
+
+        # FIXME: why not import this from stomp.utils?
+        class Frame:  # minimal stand-in for stomp.utils.Frame
+            def __init__(self, cmd=None, headers=None, body=None):
+                self.cmd = cmd
+                self.headers = headers if headers is not None else {}
+                self.body = body
+
+        frame = Frame(cmd, headers, body)
+        try:
+            self._conn.transport.send_frame(frame)
+        except Exception, ex:
+            logger.error('%s error sending STOMP frame "%s": %s' % (
+                         ex.__class__.__name__, cmd, ex))
+            logger.error(traceback.format_exc())
+            self.close()
+
+
+class PooledQueueWorker(object):
+    def __init__(self, queue_name, host, port, callback):
+        self.queue_name = queue_name
+        self.host = host
+        self.port = port
+        self.callback = callback
+
+    def __call__(self, worker_id):
+        lock = mp.Lock()
+        lock.acquire()
+
+        while True:  # keep connection alive
+            logger.info('Connecting to message queue...')
+
+            try:
+                observer = StompChangeObserver(self.queue_name, host=self.host,
+                                               port=self.port, lock=lock)
+                # set the listener *before* connecting, so that the first
+                # message cannot arrive while no handler is registered
+                observer.set_message_listener(self._change_listener)
+                with observer:
+                    logger.info('Worker is monitoring changes')
+
+                    # block until the observer releases the lock in close()
+                    lock.acquire()
+                    logger.info('Work is done: disconnected from queue')
+
+            except stomp.exception.ConnectFailedException:
+                logger.error('Connection to queue failed')
+
+            except Exception, ex:
+                logger.error('Connection to queue terminated due to '
+                             'error %s: %s' % (ex.__class__.__name__, ex))
+                # multiprocessing.Lock has no locked(); a non-blocking
+                # acquire ensures the lock is held before the next iteration
+                lock.acquire(False)
+            time.sleep(5)
+
+    def _change_listener(self, observer, headers, message):
+        return self.callback(headers, message)
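+
+# A minimal usage sketch for a single worker outside the pool (queue
+# coordinates are placeholders):
+#
+#     worker = PooledQueueWorker('queue.TLOG', 'localhost', 61613,
+#                                symmetric_ds_change_listener)
+#     worker(0)   # blocks forever, reconnecting every 5 seconds on failure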
+
+
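+# The parser below expects SymmetricDS change messages shaped roughly like
+# this (hypothetical values, inferred from the lookups performed below):
+#
+#     <changes count="1">
+#       <change type="U">
+#         <table>book</table>
+#         <key><column>42</column></key>
+#         <old><column key="title">Old Title</column></old>
+#       </change>
+#     </changes>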
+def symmetric_ds_change_listener(headers, message):
+    message_id = headers['message-id']
+
+    # Parse message contents
+    changes = ET.fromstring(message)
+    if changes.tag != 'changes':
+        logger.error('Invalid change message of type "%s" was found '
+                     'in queue. Rejecting message!' % changes.tag)
+        return False
+
+    logger.info('Synchronizing %s models...' % changes.attrib['count'])
+
+    # FIXME: should we put a transaction or lock around this whole loop?
+    for change in changes:
+        change_type = change.attrib['type']
+        table_name = change.find('table').text
+
+        keys = change.find('key')
+        if keys is None:
+            logger.warn('Change message %s for table %s contains no '
+                        'primary key. Ignoring message!' % (message_id,
+                                                            table_name))
+            continue
+
+        # TODO: confirm that symmetricDS returns composite keys in the right order
+        key = tuple(column.text for column in keys)
+
+        orig = change.find('old')
+        if orig is not None:
+            orig = dict((column.attrib['key'], column.text) for column in orig)
+            # normally, there is a pre-update / pre-delete phase for models
+            # with foreign keys.  when a foreign key to a root model changes
+            # (ex. ExtraBookInfo with a one-to-one relationship to Book),
+            # both the original document referenced by the foreign key *prior*
+            # to the change and the new document referenced after the change
+            # must be updated.  symmetricDS does not provide a pre-update /
+            # pre-delete callback, but it does provide the original values of
+            # the changed row.  we will use this to manually find the affected
+            # root model instances
+
+            # loop through the watched model's filter paths, checking if any
+            # of them have foreign keys on the watched model side of the
+            # relationship (the right side) that match ours
+
+            # if the root model is directly referenced by the foreign key,
+            # we've found an affected root instance: get its primary key
+
+            # otherwise, get the primary key of the directly related model and
+            # use it to query affected root model instances (and hope that the
+            # model instances in between still exist).
+
+        logger.debug('%s %s %s', change_type, table_name, key)
+        #_indent(change); logger.debug(ET.tostring(change))
+        listener = ExternalCollectionListener.get_listener(table_name)
+        if not listener:
+            continue
+        # TODO: try/except here:
+        if change_type == 'U':
+            listener.queue_update(key, orig)
+        elif change_type == 'I':
+            listener.queue_insert(key)
+        elif change_type == 'D':
+            listener.queue_delete(key, orig)
+        else:
+            raise ValueError("Unknown change type: %r" % change_type)
+
+    print "flushing!"
+    ExternalCollectionListener.flush()
+
+    return True
+
+
+class ExternalCollectionListener(CollectionListener):
+    _registry = {}
+
+    def __init__(self, *args, **kwargs):
+        super(ExternalCollectionListener, self).__init__(*args, **kwargs)
+        # FIXME: does order matter?
+        self._inserted = set()
+        self._updated = set()
+        self._deleted = set()
+
+    @classmethod
+    def get_listener(cls, table_name):
+        return cls._registry.get(table_name)
+
+    # def on_update(self, table_name, primary_key):
+
+    # def on_insert(self, table_name, primary_key):
+
+    # def on_delete(self, table_name, primary_key):
+
+    def queue_insert(self, primary_key):
+        self._inserted.add(primary_key)
+
+    def queue_update(self, primary_key, orig):
+        self._updated.add(primary_key)
+
+    def queue_delete(self, primary_key, orig):
+        self._deleted.add(primary_key)
+
+    @classmethod
+    def flush(cls):
+        for table_name, listener in cls._registry.iteritems():
+            if listener._inserted:
+                listener.post_insert(listener._inserted)
+                listener._inserted.clear()
+            if listener._updated:
+                listener.post_update(listener._updated)
+                listener._updated.clear()
+            if listener._deleted:
+                listener.post_delete(listener._deleted)
+                listener._deleted.clear()
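+        # release the scoped sqlalchemy session used during the flush.
+        # NOTE: SSession is only imported in the __main__ block below, so
+        # flush() relies on that setup having run in this process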
+        SSession.remove()
+
+    def connect(self):
+        # FIXME: the plan is to eventually have only one listener per table,
+        # so we're designing around that
+        assert self.model.table_name not in self._registry
+        self._registry[self.model.table_name] = self
+
+    def disconnect(self):
+        self._registry.pop(self.model.table_name)
+
+
+def serve(collections, queue_name, queue_host, queue_port, backend,
+          change_callback, num_workers=1):
+    """Start a server to listen for changes to an ActiveMQ queue.
+
+    :param collections: collections to publish to the backend. a collection can
+        be a DocumentCollection class or instance, or the name of the class as
+        a dotted python module path: e.g. 'module.submodule.MyClassName'
+    :type collections: list of str, `DocumentCollection` classes, or
+        `DocumentCollection` instances
+    :param queue_name: name of the queue to listen to for changes
+    :type queue_name: str
+    :type queue_host: str
+    :type queue_port: int
+    :type backend: `denormalize.backend.BackendBase`
+    :param change_callback: function called for each message received from
+        the queue; returning True ACKs the message, False NACKs it
+    :type change_callback: callable, with signature: `func(headers, message)`
+    """
+
+    if num_workers is None or num_workers < 1:
+        raise ValueError('num_workers must be an integer >0')
+
+    # read in models
+    for collection in collections:
+        if isinstance(collection, basestring):
+            cls = locate(collection)
+            if cls is None:
+                raise ValueError("Could not find collection model %r" %
+                                 collection)
+            collection = cls
+
+        if inspect.isclass(collection):
+            collection = collection()
+        backend.register(collection)
+
+    pool = mp.Pool(processes=num_workers)
+    logger.info('Starting %d workers in pool...' % num_workers)
+    worker = PooledQueueWorker(queue_name, queue_host, queue_port,
+                               change_callback)
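+    # NOTE: map_async(...).get(<huge timeout>) instead of a bare map() is a
+    # known CPython 2 workaround: without a timeout, KeyboardInterrupt is
+    # never delivered to the parent while it blocks on the pool result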
+    try:
+        pool.map_async(worker, range(num_workers)).get(timeout=9999999)
+    except KeyboardInterrupt:
+        pool.terminate()
+
+
+# class Sync(Command):
+#     PAGE_SIZE = 100
+
+#     def get_options(self):
+#         return [
+#             Option('--workers', '-w', dest='workers',
+#                    help='Number of worker processes to use for monitoring', default=1),
+#             Option('--rebuild', '-r', dest='rebuild',
+#                    help='Rebuild the Mongo cache', action='store_true', default=False),
+#             Option('--depth', '-d', dest='depth',
+#                    help='Model depth to use', default=0),
+#             Option('--skip-hidden', '-s', dest='skip_hidden',
+#                    help='Skip hidden fields?', default=True),
+#             Option('--models', '-m', dest='model_names',
+#                    help='Comma-separated list of models to sync (default is all)', default=None)
+#         ]
+
+#     def run(self, model_names=None, rebuild=None, depth=None, skip_hidden=None, workers=None):
+#         logging.getLogger('stomp.py').setLevel(logging.WARN)
+
+#         depth = int(depth)
+
+#         skip_hidden = bool(skip_hidden)
+#         if not skip_hidden:
+#             logger.warn('Not skipping hidden fields! This will slow down Mongo updates!')
+
+#         if rebuild:
+#             if model_names is not None:
+#                 model_names = [ name.strip() for name in model_names.split(',') ]
+#             self.rebuild(depth, skip_hidden, model_names)
+#         else:
+#             self.monitor_changes(depth, skip_hidden, int(workers))
+
+#     def monitor_changes(self, depth, skip_hidden, workers):
+#         if depth is None or depth < 0:
+#             raise ValueError('depth must be an integer >=0')
+#         if workers is None or workers < 1:
+#             raise ValueError('workers must be an integer >0')
+
+#         config = utils.getConfig()
+
+#         pool = mp.Pool(processes=workers)
+#         logger.info('Starting %d workers in pool...' % workers)
+#         try:
+#             pool.map_async(PooledQueueWorker('queue.TLOG',
+#                                              config['QUEUE_STOMP_HOST'],
+#                                              config['QUEUE_STOMP_PORT'],
+#                                              depth,
+#                                              skip_hidden),
+#                            range(workers)).get(timeout=9999999)
+#         except KeyboardInterrupt:
+#             pool.terminate()
+
+#     # TODO: figure out what is causing slow memory leak (pymongo?)
+#     def rebuild(self, depth, skip_hidden, model_names=None):
+#         if depth is None or depth < 0:
+#             raise ValueError('depth must be an integer >=0')
+
+#         available_models = set(filter(lambda klass: inspect.isclass(klass) and
+#                                                 issubclass(klass, models.BaseModel) and
+#                                                 not inspect.isabstract(klass),
+#                                       map(lambda attr: getattr(models, attr), dir(models))))
+#         if model_names is None or len(model_names) == 0:
+#             model_classes = available_models
+#             model_names = [ klass.__name__ for klass in available_models ]
+#         else:
+#             model_names = set([ klass.__name__ for klass in available_models ]).intersection(model_names)
+#             model_classes = set()
+#             for model_name in model_names:
+#                 model_classes.add(getattr(models, model_name))
+
+#         synced_models, total_models = 0, len(model_names)
+#         logger.info('Rebuilding %d models [ %s ] with depth %d...' % (total_models, ', '.join(model_names), depth))
+
+#         # Run in web context (since some model properties may depend on it)
+#         # TODO: pin to a single session (so that redis doesn't create so many redis entries)
+#         with app.test_request_context():
+
+#             for klass in model_classes:
+#                 model_name = klass.__name__
+#                 progress = synced_models / (total_models / 100.0)
+
+#                 # Drop all models in collection and start fresh
+#                 collection = data.mongo.models[model_name]
+#                 collection.drop()
+
+#                 def does_doc_exist(doc):
+#                     """Does :doc: already exist in :collection:?"""
+#                     docId = doc['_id']
+#                     return collection.find_one(docId) is not None
+
+#                 # TODO: recover from "mysql has gone away" errors
+#                 def paged_reader(query, page_size):
+#                     """Pages over and yields each result from sqlalchemy :query: object"""
+#                     if hasattr(klass, 'id'):
+#                         query = query.order_by(klass.id.desc())
+
+#                     total_pages = max(1, int(math.ceil(query.count() / page_size)))
+#                     for page in range(0, total_pages + 1):
+#                         logger.info('    %s: Page %d/%d (%.2f%% in table)' % (
+#                             model_name, page + 1, total_pages + 1, page / (total_pages / 100.0)))
+#                         for result in query.limit(page_size).offset(page * page_size).all():
+#                             yield result
+
+#                 # TODO: is there a better way to bulk insert?
+
+#                 collection_length = klass.query.count()
+#                 logger.info('Synchronizing %d models in %s collection (%.2f%% overall)...' % (
+#                     collection_length, model_name, progress))
+
+#                 # Page over results, so that they're not all loaded into memory at once.
+#                 for model in paged_reader(klass.query, page_size=self.PAGE_SIZE):
+#                     key = model.primaryKey()
+#                     try:
+#                         # Serialize Mongo document
+#                         doc = models.serialize(model, depth=depth, skipHidden=skip_hidden)
+#                         doc['_id'] = model.primaryKeyStr()
+
+#                         # Because of occasional missing indices, our database tables may contain duplicates.
+#                         if does_doc_exist(doc):
+#                             logger.warn('Skipping record for duplicate key: %s' % doc['_id'])
+#                             continue
+
+#                         collection.insert(doc)
+#                     except Exception, ex:
+#                         logger.error('%s error while syncing model %s.%s' % (
+#                             ex.__class__.__name__, model_name, key))
+#                         raise
+
+#                 synced_models += 1
+
+#         logger.info('Finished synchronizing!')
+
+if __name__ == '__main__':
+
+    import os
+    os.environ['LUMA_DB_HOST'] = 'devlumadb'
+
+    queue_host = 'sv-lumaapi02.luma-pictures.com'
+    queue_port = 61613
+    queue_name = 'queue.TLOG'
+
+    collections = [
+        #'luma.models.collections.BaseShot',
+        'luma.models.collections.Shot'
+    ]
+
+    # setup the mongo backend
+    # mongo_uri = 'mongodb://localhost'
+    mongo_uri = 'mongodb://sv-mongo01'
+    mongo_db_name = 'test_denormalize'
+
+    from ..backend.mongodb import MongoBackend
+    backend = MongoBackend(db_name=mongo_db_name,
+                           connection_uri=mongo_uri,
+                           listener_class=ExternalCollectionListener)
+
+    # fix an issue with mysql ("server has gone away"): ping each pooled
+    # connection on checkout and force a reconnect if it is dead
+    import sqlalchemy.event
+    import sqlalchemy.exc
+    import sqlalchemy.pool
+    @sqlalchemy.event.listens_for(sqlalchemy.pool.Pool, 'checkout')
+    def connectionPinger(dbapi_con, con_record, con_proxy):
+        cur = dbapi_con.cursor()
+        try:
+            cur.execute('select 1')
+        except Exception:
+            raise sqlalchemy.exc.DisconnectionError()
+        cur.close()
+
+    # FIXME: this should be provided via the DocumentCollection!!!
+    from ..orms.sqlalchemy import SqlAlchemyModelInspector
+    from luma.models.db import _Session, SSession
+    SqlAlchemyModelInspector._session = SSession
+
+    # HACK!
+    from denormalize.backend.mongodb import log
+    log.setLevel(logging.DEBUG)
+    from denormalize.orms.base import log
+    log.setLevel(logging.DEBUG)
+    # from luma.models.db import _Session, SSession
+    # SqlAlchemyModelInspector._session = SSession()
+    serve(collections, queue_name, queue_host, queue_port, backend,
+          symmetric_ds_change_listener)