Source

mycloud / src / mycloud / cluster.py

Full commit
#!/usr/bin/env python

try:
  import Tkinter
except:
  Tkinter = None

from cloud.serialization import pickledebug as cloudpickle
import Queue
import cPickle
import collections
import logging
import mycloud.connections
import mycloud.thread
import mycloud.util
import random
import socket
import sys
import traceback
import xmlrpclib

# Initialise mycloud's threading layer at import time, before any class below
# calls mycloud.thread.spawn/sleep.  (Presumably required by mycloud.thread
# before first use -- NOTE(review): confirm against that module.)
mycloud.thread.init()

class ClusterException(Exception):
  '''Raised on the controller once failures reported by tasks have been logged.'''

def arg_name(args):
  '''Returns a short string representation of an argument list for a task.

  Only the first argument is described.  Simple scalar values (numbers,
  strings, None) are shown by their repr; any other object is shown by its
  class name.  An empty argument list is shown as '()'.
  '''
  import numbers

  if not args:
    return '()'
  a = args[0]
  # The previous check, isinstance(a, object), is true for every value, so
  # the repr() branch could never run.  Restrict repr() to scalar values.
  if a is None or isinstance(a, (numbers.Number, str, bytes, type(u''))):
    return repr(a)
  return a.__class__.__name__

class Task(object):
  '''A piece of work to be executed on a remote worker.

  The function and its arguments are pickled once, at construction time.
  run() ships that pickle to a worker via ``client.execute_task`` and
  unpickles the returned result.
  '''
  def __init__(self, name, index, function, args, kw):
    self.name = name
    self.idx = index
    logging.info('Serializing %s %s %s', function, args, kw)
    self.pickle = xmlrpclib.Binary(
      cloudpickle.dumps((function, args, kw)))
    self.client = None  # set by run()
    self.result = None
    self.done = False

  def run(self, client):
    '''Execute this task through ``client`` and return its result.

    Sets ``result`` and marks the task ``done`` on success.

    Raises:
      Exception: with the remote fault string when the XML-RPC call fails
        with ``xmlrpclib.Fault``.
    '''
    logging.info('Starting task %s on %s', self.idx, client)
    self.client = client
    try:
      result_data = self.client.execute_task(self.pickle)
    except xmlrpclib.Fault as e:
      raise Exception(e.faultString)
    # NOTE(review): unpickling is only as safe as the worker that produced
    # this payload; workers are processes this controller started itself.
    self.result = cPickle.loads(result_data.data)
    self.done = True
    logging.info('Task %d finished', self.idx)
    return self.result

  def wait(self):
    '''Block (polling once per second) until the task is done; return its result.

    Polls ``done`` rather than ``result`` so that falsy results (0, None,
    [], '') do not cause an endless wait.  Note that a task whose run()
    raised never becomes done, so this still blocks in that case.
    '''
    while not self.done:
      mycloud.thread.sleep(1)
    return self.result


class Server(object):
  '''Handles connections to remote machines and execution of tasks.

A Server is created for each core on a machine, and executes tasks as
machine resources become available.'''
  def __init__(self, controller, host, index):
    self.controller = controller
    self.host = host
    self.index = index
    # Not usable until connect() has started the remote worker.
    self.ready = False
    self.thread = None

  def connect(self):
    '''Start ``mycloud.worker`` on ``host`` over SSH and record its port.

    The port is read as an integer from the first line of the worker's
    stdout.  The server is marked ready afterwards.
    '''
    # ``import logging`` alone does not load the logging.handlers submodule.
    from logging.handlers import DEFAULT_UDP_LOGGING_PORT

    ssh = mycloud.connections.SSH.connect(self.host)
    self.stdin, self.stdout, self.stderr = ssh.invoke(
      sys.executable,
      '-m', 'mycloud.worker',
      '--index %s' % self.index,
      '--logger_host %s' % socket.gethostname(),
      '--logger_port %s' % DEFAULT_UDP_LOGGING_PORT)

    self.port = int(self.stdout.readline().strip())

    self.ready = True
    self.thread = None

  def start_task(self, task):
    '''Mark this server busy and run ``task`` on a newly spawned thread.'''
    self.ready = False
    self.thread = mycloud.thread.spawn(self._run_task, task)
    # Presumably yields so the spawned thread can start -- see mycloud.thread.
    mycloud.thread.sleep(0)
    return self.thread

  def client(self):
    '''Return a new XML-RPC proxy for this server's worker.'''
    proxy = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))
    logging.debug('Created proxy to %s:%d', self.host, self.port)
    return proxy

  def _run_task(self, task):
    '''Run ``task`` against this server's worker.

    Any failure is passed to ``controller.report_exception`` as a
    ``sys.exc_info()`` triple, and the worker's stdin/stdout are closed.
    The server is marked ready again in every case.
    '''
    try:
      task.run(self.client())
    except:
      # Deliberately broad: every failure must reach the controller so the
      # polling loop (Cluster.check_exceptions) can stop the job.
      self.controller.report_exception(sys.exc_info())
      self.stdin.close()
      self.stdout.close()
    finally:
      self.ready = True


class Cluster(object):
  '''Controller for a set of remote worker processes.

  ``machines`` is an iterable of ``(host, cores)`` pairs; one Server is
  created and connected per core.  ``tmp_prefix`` is required and stored,
  but not otherwise used by this class.  When Tkinter is importable and a
  Tk window can be opened, a status window shows task progress and the
  latest log record from each worker.
  '''
  def __init__(self, machines=None, tmp_prefix=None):
    self.machines = machines
    self.tmp_prefix = tmp_prefix
    self.servers = None
    self.exceptions = []

    assert self.machines
    assert self.tmp_prefix

    self.root_window = None
    self.status_labels = {}

    if Tkinter:
      try:
        self.root_window = Tkinter.Tk()
      except Tkinter.TclError:
        # Tk() raises TclError when no display is available (e.g. headless
        # hosts); fall back to running without the status window.
        logging.info('Unable to open a Tk window; status display disabled.',
                     exc_info=1)
    else:
      logging.info('Tkinter not available for status.')

    if self.root_window is not None:
      self.root_window.columnconfigure(0, minsize=200, pad=20)
      self.root_window.columnconfigure(1, minsize=800)

      self.progress_labels = (
        Tkinter.Label(self.root_window, text="map progress"),
        Tkinter.Label(self.root_window, bg="green")
      )

    self.start()

  def __del__(self):
    logging.info('Goodbye!')

  def report_exception(self, exc):
    '''Record a failure (a ``sys.exc_info()`` triple) for check_exceptions().'''
    self.exceptions.append(exc)

  @staticmethod
  def _format_remote_exception(exc):
    '''Render a ``sys.exc_info()`` triple as traceback text, one line per
    traceback line, each prefixed with ``REMOTE:``.'''
    # format_exception entries already end in '\n' (and may hold several
    # lines), so join with '' and re-split into individual lines.
    exc_dump = ''.join(traceback.format_exception(*exc))
    return '\n'.join('REMOTE:' + line for line in exc_dump.splitlines())

  def check_exceptions(self):
    '''Check if any remote exceptions have been thrown.  Log locally and rethrow.

If an exception is found, the controller is shutdown and all exceptions are reported
prior to raising a ClusterException.'''
    if not self.exceptions:
      return

    mycloud.connections.SSH.shutdown()

    # Identical tracebacks are logged once, with an occurrence count.
    counts = collections.defaultdict(int)
    for exc in self.exceptions:
      counts[self._format_remote_exception(exc)] += 1

    # Ascending by count: the most frequent failure is logged last.
    for exc_dump, count in sorted(counts.items(), key=lambda t: t[1]):
      logging.info('Remote exception (occurred %d times):', count)
      logging.info('%s', exc_dump)

    raise ClusterException('%d remote exception(s); see log for details' %
                           len(self.exceptions))

  def start(self):
    '''Start the log server, then create and connect one Server per core.

    Connections are made concurrently; this returns once all have finished.
    '''
    self.log_server = mycloud.util.LoggingServer(self)
    mycloud.thread.spawn(self.log_server.serve_forever)

    servers = []
    for host, cores in self.machines:
      for _ in range(cores):
        servers.append(Server(self, host, len(servers)))

    connections = [mycloud.thread.spawn(s.connect) for s in servers]
    for conn in connections:
      conn.wait()

    # Servers are built host by host; shuffling presumably keeps map()'s
    # in-order dispatch from filling one host before the next.
    random.shuffle(servers)
    self.servers = servers
    logging.info('Started %d servers...', len(servers))

  def show_status(self, tasks):
    '''Refresh the Tk status window (a no-op when there is none).

    ``tasks`` are objects with a boolean ``done`` attribute.  Failures
    while drawing per-worker rows are logged, not raised.
    '''
    if not self.root_window:
      return

    finished = sum(1 for t in tasks if t.done)
    self.progress_labels[1]['text'] = '%d of %d tasks finished' % (
      finished, len(tasks))

    self.progress_labels[0].grid(row=0, column=0)
    self.progress_labels[1].grid(row=0, column=1)

    try:
      formatter = logging.Formatter()
      rows = sorted(self.log_server.message_map.items())
      for row, (hostport, rlog) in enumerate(rows, 1):
        if hostport not in self.status_labels:
          self.status_labels[hostport] = (
            Tkinter.Label(self.root_window),
            Tkinter.Entry(self.root_window, width=100))

        hp_label, msg_label = self.status_labels[hostport]
        hp_label['text'] = '%s:%s' % hostport

        msg = formatter.format(rlog).replace('\n', ' ')

        msg_label.delete(0, Tkinter.END)
        msg_label.insert(0, msg)

        hp_label.grid(row=row, column=0)
        msg_label.grid(row=row, column=1)
    except Exception:
      logging.info('Failed to update status.', exc_info=1)

    self.root_window.update()

  def map_local(self, f, arglist):
    '''Invoke the given function once for each argument, returning the result
    of the invocations

    The function will be run locally on the controller, sequentially, on a
    background thread while the status window is refreshed.  If any call
    raises, it is reported like a remote failure and check_exceptions()
    raises ClusterException.'''
    class LocalTask(object):
      def __init__(self, f, args, kw):
        self.f = f
        self.args = args
        self.kw = kw
        self.result = None
        self.done = False

      def run(self):
        self.result = self.f(*self.args, **self.kw)
        self.done = True

    arglist = mycloud.util.to_tuple(arglist)
    tasks = [LocalTask(f, args, {}) for args in arglist]

    def task_runner():
      for t in tasks:
        try:
          t.run()
        except Exception:
          # Hand the failure to the controller thread instead of losing it
          # with this background thread.
          self.report_exception(sys.exc_info())
          return

    runner = mycloud.thread.spawn(task_runner)
    while runner.isAlive():
      self.show_status(tasks)
      mycloud.thread.sleep(1)

    self.show_status(tasks)
    self.check_exceptions()
    return [t.result for t in tasks]

  def map(self, f, arglist, name='generic-map'):
    '''Run ``f`` once per argument tuple on the remote servers.

    Returns the results in the same order as ``arglist``.  Raises
    ClusterException (via check_exceptions) if any task fails, or if there
    are no servers to run on.'''
    assert len(arglist) > 0

    if not self.servers:
      raise ClusterException('No servers available to run %s' % name)

    arglist = mycloud.util.to_tuple(arglist)
    tasks = [Task(name, i, f, args, {}) for i, args in enumerate(arglist)]
    pending = collections.deque(tasks)

    logging.info('Mapping %d tasks against %d servers',
                 len(tasks), len(self.servers))

    # Instead of blocking on completion, we poll the server threads so we
    # can stop early in case we encounter an exception.
    while pending:
      self.check_exceptions()
      self.show_status(tasks)
      for s in self.servers:
        if not pending:
          break
        if s.ready:
          s.start_task(pending.popleft())
      mycloud.thread.sleep(1)

    for t in tasks:
      while not t.done:
        self.show_status(tasks)
        self.check_exceptions()
        mycloud.thread.sleep(1)

    logging.info('Done.')
    return [t.result for t in tasks]