1. Russell Power
  2. mycloud

Source

mycloud / src / mycloud / cluster.py

#!/usr/bin/env python

try:
  import Tkinter
except:
  Tkinter = None

from cloud.serialization import cloudpickle, pickledebug
import Queue
import cPickle
import collections
import json
import logging
import mycloud.connections
import mycloud.thread
import mycloud.util
import random
import socket
import sys
import traceback
import xmlrpclib
import yaml

# Initialize mycloud's thread layer at import time, before any code below
# calls mycloud.thread.spawn() / mycloud.thread.sleep().  What init() actually
# sets up lives in mycloud.thread and is not visible from this module.
mycloud.thread.init()

class ClusterException(Exception):
  '''Raised by Cluster.check_exceptions() once one or more remote tasks failed.'''

def arg_name(args):
  '''Returns a short string representation of an argument list for a task.

  Only the first argument is considered.  Simple scalar values (numbers,
  strings, None) are rendered with repr(); any other object is rendered as its
  class name, which keeps the label short for large or complex arguments.
  An empty argument list yields the empty string.'''
  import numbers

  if not args:
    return ''
  first = args[0]
  # The previous check, isinstance(a, object), is true for every value, so the
  # repr() branch was unreachable.  Test for simple scalar types instead.
  # bytes/type(u'') cover str and unicode on Python 2 and bytes/str on Python 3.
  if isinstance(first, (numbers.Number, str, bytes, type(u''), type(None))):
    return repr(first)
  return first.__class__.__name__

class Task(object):
  '''A piece of work to be executed: one call of function(*args, **kw).

  The call is pickled once, at construction time, so the same payload can be
  sent to whichever worker ends up running it.'''
  def __init__(self, name, index, function, args, kw):
    self.name = name
    self.idx = index
    logging.info('Serializing %s %s %s', function, args, kw)
    # Wrapped in xmlrpclib.Binary so the raw pickle bytes can travel over
    # XML-RPC as a binary value rather than as an XML string.
    self.pickle = xmlrpclib.Binary(
      pickledebug.dumps((function, args, kw)))
    self.result = None
    self.done = False
    # Set by run() to the proxy the task was executed through.
    self.client = None

  def run(self, client):
    '''Execute the task synchronously through `client` and return its result.

    `client` must expose execute_task(binary) returning an object whose
    `.data` holds the pickled result.

    Raises:
      Exception: carrying the remote fault string if the worker reports an
        XML-RPC fault.'''
    logging.info('Starting task %s on %s', self.idx, client)
    self.client = client
    try:
      result_data = self.client.execute_task(self.pickle)
    except xmlrpclib.Fault as e:
      raise Exception(e.faultString)
    self.result = cPickle.loads(result_data.data)
    self.done = True
    logging.info('Task %d finished', self.idx)
    return self.result

  def wait(self):
    '''Block, polling once per second, until the task finishes; return its result.'''
    # Poll on `done`, not on `result`: a task may legitimately return a falsy
    # value (None, 0, [], ...), which previously made this loop spin forever.
    while not self.done:
      mycloud.thread.sleep(1)
    return self.result


class Server(object):
  '''Handles connections to remote machines and execution of tasks.

  A Server is created for each core on a machine, and executes tasks as
  machine resources become available.'''
  def __init__(self, controller, host, index):
    # `controller` is notified of failures via report_exception(); `index` is
    # forwarded to the remote worker as --index.
    self.controller = controller
    self.host = host
    self.index = index

  def connect(self):
    '''Launch a `mycloud.worker` process on self.host over SSH.

    Reads the worker's port from the first line of its stdout, then marks this
    server ready to accept a task.'''
    # `import logging` alone does not load the logging.handlers submodule, so
    # import the constant explicitly instead of relying on some other module
    # having imported logging.handlers first.
    from logging.handlers import DEFAULT_UDP_LOGGING_PORT

    ssh = mycloud.connections.SSH.connect(self.host)
    self.stdin, self.stdout, self.stderr = ssh.invoke(
      sys.executable,
      '-m', 'mycloud.worker',
      '--index %s' % self.index,
      '--logger_host %s' % socket.gethostname(),
      '--logger_port %s' % DEFAULT_UDP_LOGGING_PORT)

    # The worker announces its listening port as the first line of stdout.
    self.port = int(self.stdout.readline().strip())

    self.ready = True
    self.thread = None

  def start_task(self, task):
    '''Run `task` on a newly spawned thread and return that thread.

    The server is marked busy until the task completes (successfully or not).'''
    self.ready = False
    self.thread = mycloud.thread.spawn(self._run_task, task)
    # sleep(0) presumably yields so the new thread can start promptly.
    mycloud.thread.sleep(0)
    return self.thread

  def client(self):
    '''Return a new XML-RPC proxy for this server's worker.'''
    proxy = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))
    logging.debug('Created proxy to %s:%d', self.host, self.port)
    return proxy

  def _run_task(self, task):
    '''Thread body for start_task(): run `task` and report any failure.'''
    try:
      task.run(self.client())
    except:
      # Deliberately broad: every failure is handed to the controller, whose
      # check_exceptions() aborts the map; the worker's pipes are then closed.
      self.controller.report_exception(sys.exc_info())
      self.stdin.close()
      self.stdout.close()
    finally:
      self.ready = True


class Cluster(object):
  '''Drives a pool of remote workers and maps functions across them.

  `machines` is an iterable of (host, cores) pairs; one Server, and hence one
  remote worker, is started per core.  Construction connects to every worker
  immediately (see start()).'''
  def __init__(self, machines=None, tmp_prefix=None):
    self.machines = machines
    # Required, but not used within this class.
    self.tmp_prefix = tmp_prefix
    self.servers = None
    # exc_info triples reported by worker threads; see report_exception().
    self.exceptions = []

    assert self.machines
    assert self.tmp_prefix

    self.root_window = None
    self.tk_thread = None
    self.status_labels = {}

    # The Tk status window is currently disabled, so root_window stays None
    # and show_status() is a no-op.  Only report a missing Tkinter when it
    # really is missing; there is no active exception here, so no exc_info.
    if Tkinter is None:
      logging.info('Tkinter not available for status.')

    self.start()

  def __del__(self):
    logging.info('Goodbye!')

  def report_exception(self, exc):
    '''Record an exc_info triple from a worker thread.

    It is logged and turned into a ClusterException by the next
    check_exceptions() call.'''
    self.exceptions.append(exc)

  @staticmethod
  def _group_exceptions(exc_infos):
    '''Group exc_info triples by formatted traceback.

    Returns a list of (traceback_text, count) pairs, least frequent first.'''
    counts = collections.defaultdict(int)
    for exc_info in exc_infos:
      # format_exception() returns lines that already end in '\n'; joining
      # them with '\n' (as before) doubled every line break.
      counts[''.join(traceback.format_exception(*exc_info))] += 1
    return sorted(counts.items(), key=lambda item: item[1])

  def check_exceptions(self):
    '''Check if any remote exceptions have been thrown.  Log locally and rethrow.

    If an exception is found, the SSH connections are shut down, each distinct
    traceback is logged once with its occurrence count, and a ClusterException
    is raised.'''
    if not self.exceptions:
      return

    mycloud.connections.SSH.shutdown()

    for exc_dump, count in self._group_exceptions(self.exceptions):
      logging.info('Remote exception (occurred %d times):', count)
      logging.info('%s', '\n'.join(
        'REMOTE:' + line for line in exc_dump.rstrip('\n').split('\n')))

    raise ClusterException(
      '%d remote exception(s); see log for tracebacks' % len(self.exceptions))

  def start(self):
    '''Start the log server and connect one Server per core on every machine.

    Connections are opened concurrently.  The server list is shuffled
    afterwards so that map(), which dispatches in list order, does not hand
    the first tasks to a single host.'''
    self.log_server = mycloud.util.LoggingServer(self)
    mycloud.thread.spawn(self.log_server.serve_forever)

    servers = []
    for host, cores in self.machines:
      for _ in xrange(cores):
        # Indices are assigned consecutively across all machines.
        servers.append(Server(self, host, len(servers)))

    connections = [mycloud.thread.spawn(s.connect) for s in servers]
    for conn in connections:
      conn.wait()

    random.shuffle(servers)
    self.servers = servers
    logging.info('Started %d servers...', len(servers))

  def show_status(self):
    '''Refresh the Tk status window with one row per remote log source.

    No-op when there is no status window.  Errors while building rows are
    logged and ignored so that status display never interrupts map().'''
    if not self.root_window:
      return

    formatter = logging.Formatter()
    try:
      items = sorted(self.log_server.message_map.items())
      for row, (hostport, rlog) in enumerate(items):
        if hostport not in self.status_labels:
          self.status_labels[hostport] = (
            Tkinter.Label(self.root_window),
            Tkinter.Entry(self.root_window, width=100))

        hp_label, msg_label = self.status_labels[hostport]
        hp_label['text'] = '%s:%s' % hostport

        msg = formatter.format(rlog).replace('\n', ' ')

        msg_label.delete(0, Tkinter.END)
        msg_label.insert(0, msg)

        hp_label.grid(row=row, column=0)
        msg_label.grid(row=row, column=1)
    except Exception:
      # Not a bare except: this runs inside map()'s polling loop, and a bare
      # except would swallow KeyboardInterrupt.
      logging.info('Failed to update status.', exc_info=1)

    self.root_window.update()

  def map(self, f, arglist, name='worker'):
    '''Run `f` once per entry of `arglist` on the cluster; return results in order.

    `arglist` is passed through mycloud.util.to_tuple (presumably normalizing
    each entry into an argument tuple) and each entry becomes one Task.  `name`
    is passed through to each Task.

    Raises:
      AssertionError: if `arglist` is empty.
      ClusterException: as soon as any remote task fails.'''
    assert len(arglist) > 0

    arglist = mycloud.util.to_tuple(arglist)

    tasks = [Task(name, i, f, args, {})
             for i, args in enumerate(arglist)]

    task_queue = Queue.Queue()
    for t in tasks:
      task_queue.put(t)

    logging.info('Mapping %d tasks against %d servers',
                 len(tasks), len(self.servers))

    # Instead of joining on the task_queue immediately, we poll the server
    # threads; this way we can stop early in case we encounter an exception.
    # The loop ends when get_nowait() finds the queue empty.
    try:
      while True:
        self.check_exceptions()
        self.show_status()
        mycloud.thread.sleep(1)

        for s in self.servers:
          if s.ready:
            s.start_task(task_queue.get_nowait())
    except Queue.Empty:
      pass

    # Every task has been dispatched; wait for the remaining ones to finish.
    for t in tasks:
      while not t.done:
        self.show_status()
        self.check_exceptions()
        mycloud.thread.sleep(1)

    logging.info('Done.')
    return [t.result for t in tasks]