django-ztask / django_ztask / management / commands / ztaskd.py

import sys
import time
import pickle
import datetime
from optparse import make_option

try:
    from zmq import PULL
except:
    from zmq import UPSTREAM as PULL
from zmq.eventloop import ioloop
from django.utils.log import getLogger
from django.utils import autoreload, importlib
from django.core.management.base import BaseCommand

from django_ztask.models import Task
from django_ztask.conf import settings
from django_ztask.context import shared_context as context

# Module-wide logger used by the ztaskd command; its name is taken from the
# ZTASKD_LOGGER setting.
log = getLogger(settings.ZTASKD_LOGGER)

 
class Command(BaseCommand):
    """Run the ztaskd server.

    The server binds a ZeroMQ PULL socket at ``settings.ZTASKD_URL``, stores
    every received call as a ``Task`` row, and executes the named callable on
    a pyzmq IOLoop. A failed call is rescheduled every
    ``settings.ZTASKD_RETRY_AFTER`` seconds while its ``retry_count`` is
    positive; a successful call deletes its ``Task`` row.
    """

    option_list = BaseCommand.option_list + (
        make_option(
            '--noreload',
            action='store_false',
            dest='use_reloader',
            default=True,
            help='Tells Django to NOT use the auto-reloader',
        ),
        make_option(
            '--replayfailed',
            action='store_true',
            dest='replay_failed',
            default=False,
            help='Replays all failed calls in the DB',
        ),
    )
    args = ''
    help = 'Start the ztaskd server'
    # Resolved callables keyed by dotted name (class-level, so shared by all
    # instances of the command).
    func_cache = {}
    # IOLoop the server runs on; assigned in _handle before anything is
    # scheduled on it.
    io_loop = None
    # Seconds to wait before replaying a stored task whose next_attempt is
    # already in the past.
    overdue_replay_delay = 5

    def handle(self, *args, **options):
        """Entry point: start the server, optionally under Django's autoreloader."""
        use_reloader = options.get('use_reloader', True)
        replay_failed = options.get('replay_failed', False)
        if use_reloader:
            autoreload.main(lambda: self._handle(use_reloader, replay_failed))
        else:
            self._handle(use_reloader, replay_failed)

    def _handle(self, use_reloader, replay_failed):
        """Bind the socket, replay stored tasks and run the IOLoop forever.

        :param use_reloader: only affects the startup log message here.
        :param replay_failed: if true, replay every stored task, including
            those whose retries are exhausted.
        """
        log.info("%sServer starting on %s.",
                 'Development ' if use_reloader else '', settings.ZTASKD_URL)
        self._on_load()
        # Bind the loop first so every DelayedCallback scheduled below is
        # attached to the same loop that is started at the end.
        self.io_loop = ioloop.IOLoop.instance()
        socket = context.socket(PULL)
        socket.bind(settings.ZTASKD_URL)
        self._replay_tasks(replay_failed)
        self.io_loop.add_handler(socket, self._queue_handler, self.io_loop.READ)
        self.io_loop.start()

    def _schedule(self, callback, delay_ms):
        """Run ``callback()`` once on the server's IOLoop after ``delay_ms`` ms."""
        ioloop.DelayedCallback(callback, delay_ms, io_loop=self.io_loop).start()

    def _queue_handler(self, sock, *unused):
        """Receive one call from ``sock``, persist it and run or schedule it.

        Messages are ``(function_name, args, kwargs, after)`` tuples, where
        ``after`` is a delay in seconds. The special name ``'ztask_log'`` is
        logged as a warning instead of being stored.
        """
        try:
            # NOTE(review): recv_pyobj unpickles whatever arrives on the
            # socket; ZTASKD_URL must not be reachable by untrusted peers.
            function_name, args, kw, after = sock.recv_pyobj()
            if function_name == 'ztask_log':
                log.warning('%s: %s', args[0], args[1])
                return
            task = Task.objects.create(
                function_name=function_name,
                args=pickle.dumps(args),
                kwargs=pickle.dumps(kw),
                retry_count=settings.ZTASKD_RETRY_COUNT,
                next_attempt=time.time() + after,
            )
            if after:
                # Bind the values now; _call_function's keyword is ``kw``.
                self._schedule(
                    lambda task_id=task.pk, name=function_name, a=args, k=kw:
                        self._call_function(task_id, function_name=name,
                                            args=a, kw=k),
                    after * 1000,
                )
            else:
                self._call_function(
                    task.pk,
                    function_name=function_name,
                    args=args,
                    kw=kw,
                )
        except Exception:
            log.exception('Error setting up function')

    def _replay_tasks(self, replay_failed):
        """Schedule stored tasks, oldest first, according to next_attempt.

        Overdue tasks run after ``overdue_replay_delay`` seconds; the rest
        run when their ``next_attempt`` time is reached.
        """
        if replay_failed:
            replay_tasks = Task.objects.all().order_by('created')
        else:
            replay_tasks = Task.objects.filter(
                retry_count__gt=0
            ).order_by('created')
        for task in replay_tasks:
            remaining = task.next_attempt - time.time()
            delay = self.overdue_replay_delay if remaining < 0 else remaining
            # Default argument binds this task's pk; a plain closure over the
            # loop variable would make every callback run the last task.
            self._schedule(
                lambda task_id=task.pk: self._call_function(task_id),
                delay * 1000,
            )

    def p(self, txt):
        """Print ``txt`` to stdout."""
        print(txt)

    @staticmethod
    def _import_callable(dotted_name):
        """Import the module part of ``dotted_name`` and return its last member."""
        module_name, _, member_name = dotted_name.rpartition('.')
        if module_name not in sys.modules:
            importlib.import_module(module_name)
        return getattr(sys.modules[module_name], member_name)

    def _get_function(self, function_name):
        """Return the callable for ``function_name``, resolving it at most once."""
        try:
            return self.func_cache[function_name]
        except KeyError:
            function = self._import_callable(function_name)
            self.func_cache[function_name] = function
            return function

    def _call_function(self, task_id, function_name=None, args=None, kw=None):
        """Execute a task and update its ``Task`` row.

        If ``function_name`` is not given, the name and pickled args/kwargs
        are loaded from the ``Task`` row ``task_id``. On success the row is
        deleted; on failure it is updated (and possibly rescheduled) by
        ``_record_failure``. Always returns None.
        """
        if not function_name:
            try:
                task = Task.objects.get(pk=task_id)
                function_name = task.function_name
                # NOTE(review): unpickling stored data; only safe while the
                # Task table is written exclusively by this server.
                args = pickle.loads(str(task.args))
                kw = pickle.loads(str(task.kwargs))
            except Exception:
                log.exception('Could not get task id %s', task_id)
                return None
        log.info('Calling %s', function_name)
        try:
            function = self._get_function(function_name)
            function(*(args or ()), **(kw or {}))
        except Exception as e:
            log.exception('Error calling %s', function_name)
            self._record_failure(task_id, e)
            return None
        log.info('Called %s successfully', function_name)
        # Kept separate from the call so a failed delete is not mistaken for
        # a failed call (which would re-run a function that already ran).
        try:
            Task.objects.get(pk=task_id).delete()
        except Exception:
            log.exception('Could not delete completed task id %s', task_id)
        return None

    def _record_failure(self, task_id, exc):
        """Store ``exc`` on the task and reschedule it if retries remain."""
        try:
            task = Task.objects.get(pk=task_id)
            if task.retry_count > 0:
                task.retry_count = task.retry_count - 1
                task.next_attempt = time.time() + settings.ZTASKD_RETRY_AFTER
                self._schedule(
                    lambda task_id=task.pk: self._call_function(task_id),
                    settings.ZTASKD_RETRY_AFTER * 1000,
                )
            task.failed = datetime.datetime.utcnow()
            task.last_exception = '%s' % exc
            task.save()
        except Exception:
            log.exception('Error capturing exception in _call_function')

    def _on_load(self):
        """Call every dotted-path callable listed in settings.ZTASKD_ON_LOAD."""
        for callable_name in settings.ZTASKD_ON_LOAD:
            log.info("ON_LOAD calling %s", callable_name)
            self._import_callable(callable_name)()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.