#!/usr/bin/env python
# mycloud/src/mycloud/cluster.py

'''Cluster management for mycloud: connects to remote worker processes
and distributes serialized tasks across them.'''

from cloud.serialization import cloudpickle
import Queue
import cPickle
import logging
import logging.handlers
import mycloud.connections
import mycloud.resource
import mycloud.thread
import mycloud.util
import random
import socket
import sys
import traceback
import xmlrpclib

# Set up the threading layer before any servers or log listeners are spawned.
mycloud.thread.init()

class ClusterException(Exception):
  pass

def arg_name(args):
  '''Returns a short string representation of an argument list for a task.'''
  a = args[0]
  # isinstance(a, object) is true for every value, so test for simple types
  # explicitly and fall back to the class name for everything else.
  if isinstance(a, (int, long, float, str, unicode)):
    return repr(a)
  return a.__class__.__name__

class Task(object):
  '''A piece of work to be executed.'''
  def __init__(self, name, index, function, args, kw):
    self.name = name
    self.idx = index
    # Serialize the call up front so it can be shipped over XML-RPC.
    self.pickle = xmlrpclib.Binary(cloudpickle.dumps((function, args, kw)))
    self.result = None
    self.done = False

  def run(self, client):
    logging.info('Starting task %s on %s', self.idx, client)
    self.client = client
    # execute_task blocks until the remote worker finishes, returning the
    # pickled result.
    result_data = self.client.execute_task(self.pickle)
    self.result = cPickle.loads(result_data.data)
    self.done = True
    logging.info('Task %d finished', self.idx)
    return self.result

  def wait(self):
    # Poll the done flag rather than the result itself; a task may
    # legitimately return a falsy value such as None.
    while not self.done:
      mycloud.thread.sleep(1)
    return self.result


class Server(object):
  '''Handles connections to remote machines and execution of tasks.

  A Server is created for each core on a machine, and executes tasks as
  machine resources become available.'''
  def __init__(self, cluster, host, index):
    self.cluster = cluster
    self.host = host
    self.index = index

  def connect(self):
    ssh = mycloud.connections.SSH.connect(self.host)
    self.stdin, self.stdout, self.stderr = ssh.invoke(
      sys.executable,
      '-m', 'mycloud.worker',
      '--index %s' % self.index,
      '--logger_host %s' % socket.gethostname(),
      '--logger_port %s' % logging.handlers.DEFAULT_UDP_LOGGING_PORT)

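    # The worker prints the port it is listening on as its first line of
    # output; block here until the process comes up.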
    self.port = int(self.stdout.readline().strip())

    self.ready = True
    self.thread = None

  def start_task(self, task):
    self.ready = False
    self.thread = mycloud.thread.spawn(self._run_task, task)
    mycloud.thread.sleep(0)
    return self.thread

  def client(self):
    logging.info('Created proxy to %s:%d', self.host, self.port)
    return xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))

  def _run_task(self, task):
    try:
      task.run(self.client())
    except:
      # Report the failure to the cluster; map() will notice the recorded
      # exception on its next poll and raise a ClusterException.
      self.cluster.report_exception(sys.exc_info())
      self.stdin.close()
      self.stdout.close()
    finally:
      self.ready = True


class Cluster(object):
  def __init__(self, machines=None, tmp_prefix=None):
    assert machines, 'Cluster requires a list of (host, cores) pairs.'
    assert tmp_prefix, 'Cluster requires a temporary directory prefix.'

    self.machines = machines
    self.tmp_prefix = tmp_prefix
    self.servers = None
    self.exceptions = []

    self.start()

  def __del__(self):
    logging.info('Goodbye!')

  def report_exception(self, exc):
    self.exceptions.append(exc)

  def log_exceptions(self):
    for e in self.exceptions:
      exc_dump = ['remote exception: ' + line
                  for line in traceback.format_exception(*e)]
      logging.error('\n'.join(exc_dump))

  def start(self):
    self.log_server = mycloud.util.LoggingServer(self)
    mycloud.thread.spawn(self.log_server.serve_forever)

    servers = []
    index = 0
    for host, cores in self.machines:
      for i in xrange(cores):
        s = Server(self, host, index)
        servers.append(s)
        index += 1

    # Connect to all workers in parallel, then wait for each to come up.
    connections = [mycloud.thread.spawn(s.connect) for s in servers]
    for c in connections:
      c.wait()

    self.servers = servers
    # Shuffle so repeated maps spread work across machines instead of always
    # filling the first hosts' cores first.
    random.shuffle(self.servers)
    logging.info('Started %d servers...', len(servers))

  def input(self, type, pattern):
    '''Return the set of cluster inputs matching the given pattern.'''
    return mycloud.resource.input(type, pattern)

  def output(self, type, name, shards):
    '''Return a sharded cluster output with the given name.'''
    return mycloud.resource.output(type, name, shards)

  def show_status(self):
    # Status reporting is disabled for now; remove this return to re-enable.
    return
    for (host, port), rlog in self.log_server.message_map.items():
      print >> sys.stderr, '(%s, %s) -- %s' % (host, port, rlog.msg)

  def map(self, f, arglist, name='worker'):
    assert len(arglist) > 0

    arglist = mycloud.util.to_tuple(arglist)

    task_queue = Queue.Queue()

    tasks = [Task(name, i, f, args, {})
             for i, args in enumerate(arglist)]

    for t in tasks:
      task_queue.put(t)

    logging.info('Mapping %d tasks against %d servers',
                 len(tasks), len(self.servers))

    # Instead of blocking on the task queue, poll the server threads so we can
    # stop early if a worker reports an exception.
    try:
      while 1:
        for s in self.servers:
          if not s.ready:
            continue

          # get_nowait() raises Queue.Empty once every task has been handed
          # out; that exception exits the dispatch loop below.
          t = task_queue.get_nowait()
          s.start_task(t)

        if self.exceptions:
          self.log_exceptions()
          raise ClusterException

        mycloud.thread.sleep(0.1)
        self.show_status()

    except Queue.Empty:
      pass

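    # All tasks have been handed out; wait for the stragglers to finish.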
    for t in tasks:
      while not t.done:
        if self.exceptions:
          self.log_exceptions()
          raise ClusterException
        mycloud.thread.sleep(0.1)

    logging.info('Done.')
    return [t.result for t in tasks]
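

# A minimal usage sketch (not part of the original module). The host names,
# core counts, and tmp_prefix below are hypothetical; mycloud.worker must be
# installed and reachable over SSH on the listed machines.
if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)

  def add(a, b):
    return a + b

  cluster = Cluster(machines=[('worker1', 4), ('worker2', 4)],
                    tmp_prefix='/tmp/mycloud')

  # map() pickles each (function, args) pair, ships it to an idle worker, and
  # returns the results in task order.
  print cluster.map(add, [(i, i) for i in xrange(16)])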