mycloud/src/mycloud/cluster.py
#!/usr/bin/env python

from cloud.serialization import cloudpickle
import Queue
import cPickle
import logging
import mycloud.connections
import mycloud.resource
import mycloud.thread
import mycloud.util
import sys
import traceback
import xmlrpclib

mycloud.thread.init()

def arg_name(args):
  '''Returns a short string representation of an argument list for a task.'''
  a = args[0]
  # Summarize instances of user-defined classes by their class name;
  # builtin values (strings, numbers, ...) fall through to repr().
  if a.__class__.__module__ != '__builtin__':
    return a.__class__.__name__
  return repr(a)
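
# Illustrative sketch of the naming behavior in an interactive session
# (CrawlJob here is a hypothetical user-defined class, not part of this
# module):
#
#   >>> arg_name(('input.txt',))
#   "'input.txt'"
#   >>> class CrawlJob(object): pass
#   >>> arg_name((CrawlJob(),))
#   'CrawlJob'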


class Task(object):
  '''A piece of work to be executed.'''
  def __init__(self, idx, function, args, kw):
    self.idx = idx
    self.pickle = cloudpickle.dumps((function, args, kw))
    self.result = None
    self.done = False

  def run(self, client):
    logging.info('Starting task %s on %s', self.idx, client)
    self.client = client
    result_data = self.client.execute_task(xmlrpclib.Binary(self.pickle))
    self.result = cPickle.loads(result_data.data)
    self.done = True
    logging.info('Task %d finished', self.idx)
    return self.result

  def wait(self):
    # Poll the done flag rather than the result itself: a legitimately
    # falsy result (None, 0, '') would otherwise block forever.
    while not self.done:
      mycloud.thread.sleep(1)
    return self.result
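
# A minimal sketch of the worker contract that Task.run() relies on.
# The real mycloud.worker is not shown in this file, so the body below
# is an assumption based on how Task packs and unpacks its data:
#
#   def execute_task(blob):
#     function, args, kw = cPickle.loads(blob.data)
#     result = function(*args, **kw)
#     return xmlrpclib.Binary(cPickle.dumps(result))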


class Server(object):
  '''Handles connections to remote machines and execution of tasks.

  A Server is created for each core on a machine, and executes tasks as
  machine resources become available.'''
  def __init__(self, cluster, host):
    self.cluster = cluster
    self.host = host
    ssh = mycloud.connections.SSHConnection.connect(host)
    self.stdin, self.stdout, self.stderr = ssh.invoke(
      sys.executable, '-m', 'mycloud.worker')

    self.port = int(self.stdout.readline().strip())

    self.stderr_logger = (
      mycloud.util.StreamLogger(
         'Remote(%s, %d)' % (self.host, self.port), buffer=False))
    self.stderr_logger.start(self.stderr)

    self.ready = True
    self.thread = None

    assert self.client().hello() == 'alive'

  def start_task(self, task):
    self.ready = False
    self.thread = mycloud.thread.spawn(self._run_task, task)
    # Yield so the spawned task thread gets a chance to start.
    mycloud.thread.sleep(0)
    return self.thread

  def client(self):
    logging.info('Created proxy to %s:%d', self.host, self.port)
    return xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))

  def _run_task(self, task):
    try:
      task.run(self.client())
    except:
      self.cluster.report_exception(sys.exc_info())
    finally:
      self.ready = True
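
# The handshake implied by Server.__init__: the spawned worker prints
# its XML-RPC port as the first line of stdout, then serves at least
# hello() and execute_task().  A hypothetical minimal worker -- this is
# an assumption, not the real mycloud.worker -- could look like:
#
#   from SimpleXMLRPCServer import SimpleXMLRPCServer
#   server = SimpleXMLRPCServer(('0.0.0.0', 0), logRequests=False)
#   print server.server_address[1]   # parsed by Server.__init__
#   sys.stdout.flush()
#   server.register_function(lambda: 'alive', 'hello')
#   server.register_function(execute_task)  # as sketched under Task
#   server.serve_forever()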


class Cluster(object):
  def __init__(self, machines=None, fs_prefix='/gfs'):
    self.machines = machines
    self.fs_prefix = fs_prefix
    self.servers = None
    self.exceptions = []

    self.start()

  def __del__(self):
    logging.info('Goodbye!')

  def report_exception(self, exc):
    self.exceptions.append(exc)

  def start(self):
    # Spawn server connections in parallel; each SSH connection is
    # independent, so there is no reason to bring them up serially.
    servers = []
    for host, cores in self.machines:
      for i in xrange(cores):
        servers.append(mycloud.thread.spawn(Server, self, host))

    servers = [s.wait() for s in servers]
    self.servers = servers
    logging.info('Started %d servers...', len(servers))

  def input(self, type, pattern):
    '''Return a set of cluster inputs for the given pattern.'''
    return mycloud.resource.input(type, pattern)

  def output(self, type, name, shards):
    '''Return a cluster output with the given name, split into shards.'''
    return mycloud.resource.output(type, name, shards)
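
# Illustrative only -- the available resource types are defined in
# mycloud.resource, not in this file, so SomeInputType/SomeOutputType
# below are placeholders:
#
#   inputs = cluster.input(SomeInputType, '/gfs/crawl/*.txt')
#   out = cluster.output(SomeOutputType, '/gfs/results', 8)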

  def map(self, f, arglist):
    assert len(arglist) > 0
    arglist = mycloud.util.to_tuple(arglist)

    task_queue = Queue.Queue()

    tasks = [Task(i, f, args, {})
             for i, args in enumerate(arglist)]

    for t in tasks:
      task_queue.put(t)

    logging.info('Mapping %d tasks against %d servers',
                 len(tasks), len(self.servers))

    # Instead of joining on the task_queue immediately, poll the server
    # threads; this lets us bail out early if a task raises an exception.
    try:
      while 1:
        for s in self.servers:
          if not s.ready:
            continue

          t = task_queue.get_nowait()
          s.start_task(t)

        if self.exceptions:
          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))

        mycloud.thread.sleep(0.1)
    except Queue.Empty:
      pass

    for t in tasks:
      while not t.done:
        if self.exceptions:
          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))
        mycloud.thread.sleep(0.1)

    logging.info('Done.')
    return [t.result for t in tasks]
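
# Example usage (host names and core counts are placeholders):
#
#   cluster = Cluster(machines=[('worker1', 4), ('worker2', 4)])
#   squares = cluster.map(lambda x: x * x, [(i,) for i in range(100)])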