Russell Power committed 99a465b

Updates...

  • Parent commit e57d0f0

Files changed (7)

File src/mycloud/cluster.py

 #!/usr/bin/env python
 
 from cloud.serialization import cloudpickle
-from eventlet import sleep
-from signal import *
 import Queue
 import cPickle
 import logging
 import mycloud.connections
+import mycloud.thread
 import mycloud.util
-import os
 import random
 import sys
-import time
 import traceback
+import xmlrpclib
+
+mycloud.thread.init()
+
+def arg_name(args):
+  '''Returns a short string representation of an argument list for a task.'''
+  a = args[0]
+  if isinstance(a, object):
+    return a.__class__.__name__
+  return repr(a)
+
 
 class Task(object):
   '''A piece of work to be executed.'''
-  def __init__(self, idx, function, args, buffer=False):
+  def __init__(self, idx, function, args, kw):
     self.idx = idx
-    self.pickle = cloudpickle.dumps((function, args))
+    self.pickle = cloudpickle.dumps((function, args, kw))
     self.result = None
-    self.stderr_logger = mycloud.util.StreamLogger('Task %05d' % self.idx, buffer)
+    self.done = False
 
-  def run(self, host, client):
+  def run(self, client):
+    logging.info('Starting task %s on %s', self.idx, client)
+    self.client = client
+    result_data = self.client.execute_task(xmlrpclib.Binary(self.pickle))
+    self.result = cPickle.loads(result_data.data)
+    self.done = True
+    logging.info('Task %d finished', self.idx)
+    return self.result
+
+  def wait(self):
+    while not self.result:
+      mycloud.thread.sleep(1)
+    return self.result
+
+
+class Server(object):
+  '''Handles connections to remote machines and execution of tasks.
+  
+A Server is created for each core on a machine, and executes tasks as
+machine resources become available.'''
+  def __init__(self, cluster, host):
+    self.cluster = cluster
     self.host = host
-    self.client = client
-
-    logging.info('Starting task %d', self.idx)
-    stdin, stdout, stderr = client.invoke(
+    ssh = mycloud.connections.SSHConnection.connect(host)
+    self.stdin, self.stdout, self.stderr = ssh.invoke(
       sys.executable, '-m', 'mycloud.worker')
 
-    self.stderr_logger.start(stderr)
-    stdin.write(self.pickle)
-    self.result = cPickle.load(stdout)
-    self.stderr_logger.join()
-    if isinstance(self.result, Exception):
-      raise self.result
-    logging.info('Task %d finished.', self.idx)
+    self.port = int(self.stdout.readline().strip())
 
+    self.stderr_logger = (
+      mycloud.util.StreamLogger(
+         'Remote(%s, %d)' % (self.host, self.port), buffer=False))
+    self.stderr_logger.start(self.stderr)
 
-class ServerProxy(object):
-  '''Handles connections to remote machines and execution of tasks.
-  
-A ServerProxy is created for each core on a machine, and executes tasks as
-machine resources become available.'''
-  def __init__(self, host):
-    self.host = host
-    self.client = mycloud.connections.SSHConnection.connect(host)
-    self.exception = None
+    self.ready = True
+    self.thread = None
 
-  def start_task_queue(self, task_queue):
-    logging.info('Starting server proxy.')
-    self.thread = mycloud.thread.spawn(self._read_queue, task_queue)
-    sleep(0)
+    assert self.client().hello() == 'alive'
+
+  def start_task(self, task):
+    self.ready = False
+    self.thread = mycloud.thread.spawn(self._run_task, task)
+    mycloud.thread.sleep(0)
     return self.thread
 
-  def start_task(self, task):
-    self.thread = mycloud.thread.spawn(self._run_task, task)
-    sleep(0)
-    return self.thread
+  def client(self):
+    logging.info('Created proxy to %s:%d', self.host, self.port)
+    return xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))
 
   def _run_task(self, task):
     try:
-      task.run(self.host, self.client)
-    except Exception, e:
-      logging.info('Exception occurred during remote execution of task %d.' +
-             'stderr dump: \n\n%s\n\n',
-             task.idx,
-             task.stderr_logger.dump(),
-             exc_info=1)
-      self.exception = e
-
-  def _read_queue(self, task_queue):
-    '''Execute tasks from the task queue.
-    
-If an exception is raised by a task, halt and set self.exception.
-
-Stops when the queue becomes empty.'''
-    try:
-      while 1:
-        task = task_queue.get_nowait()
-        self._run_task(task)
-        task_queue.task_done()
-    except Queue.Empty:
-      return
-
+      task.run(self.client())
+    except:
+#      logging.info('Exception!', exc_info=1)
+      self.cluster.report_exception(sys.exc_info())
+    finally:
+      self.ready = True
 
 
 class Cluster(object):
     self.machines = machines
     self.fs_prefix = fs_prefix
     self.servers = None
+    self.exceptions = []
 
     self.start()
 
+  def __del__(self):
+    logging.info('Goodbye!')
+
+  def report_exception(self, exc):
+    self.exceptions.append(exc)
+
   def start(self):
     servers = []
     for host, cores in self.machines:
       for i in xrange(cores):
-        s = ServerProxy(host)
-        servers.append(s)
-    random.shuffle(servers)
+        servers.append(
+#                       Server(self, host))
+          mycloud.thread.spawn(lambda c, h: Server(c, h), self, host))
+
+    servers = [s.wait() for s in servers]
     self.servers = servers
     logging.info('Started %d servers...', len(servers))
 
   def output(self, type, name, shards):
     return mycloud.resource.output(type, name, shards)
 
-  def run(self, f, arglist):
-    assert len(arglist) < len(self.servers)
-    arglist = mycloud.util.to_tuple(arglist)
-
-    tasks = [Task(i, f, arg, buffer=False) for i, arg in enumerate(arglist)]
-    for s, task in zip(self.servers, tasks):
-      s.start_task(task)
-    return tasks
-
-
   def map(self, f, arglist):
     assert len(arglist) > 0
     idx = 0
 
     task_queue = Queue.Queue()
 
-    tasks = [Task(i, f, args) for i, args in enumerate(arglist)]
+    tasks = [Task(i, f, args, {})
+             for i, args in enumerate(arglist)]
 
     for t in tasks:
       task_queue.put(t)
     logging.info('Mapping %d tasks against %d servers',
                  len(tasks), len(self.servers))
 
-    threads = [s.start_task_queue(task_queue) for s in self.servers]
-
     # instead of joining on the task_queue immediately, we poll the server threads
     # this way we can stop early in case we encounter an exception
-    while not task_queue.empty():
-      for s in self.servers:
-        if s.exception:
-          raise s.exception
-      sleep(0.1)
+    try:
+      while 1:
+        for s in self.servers:
+          if not s.ready:
+            continue
 
-    task_queue.join()
-    for t in threads:
-      t.join()
+          t = task_queue.get_nowait()
+          s.start_task(t)
 
+        if self.exceptions:
+          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))
+
+        mycloud.thread.sleep(0.1)
+    except Queue.Empty:
+      pass
+
+    for t in tasks:
+      while not t.done:
+        if self.exceptions:
+          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))
+        mycloud.thread.sleep(0.1)
+
+    logging.info('Done.')
     return [t.result for t in tasks]
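
A rough usage sketch of the reworked Cluster.map() path above, where tasks now
travel to workers over XML-RPC instead of stdin/stdout pickles. The host list
and function are illustrative (compare tests/test_mapreduce.py below), and the
sketch assumes passwordless SSH to localhost with mycloud on the remote path:

  import mycloud

  def double(x):
    return x * 2

  cluster = mycloud.Cluster([('localhost', 2)])
  # map() pickles a (function, args, kw) triple per task and runs it on a
  # worker via execute_task(); argument handling mirrors the lambdas used in
  # mapreduce.py below.
  print cluster.map(double, range(4))   # -> [0, 2, 4, 6]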

File src/mycloud/connections.py

 import subprocess
 
 class SSHConnection(object):
-  cache = {}
+  connections = []
 
-  def __init__(self, client):
-    self.client = client
+  def __init__(self, host):
+    self.host = host
+    self.client = ssh.SSHClient()
+    self.client.set_missing_host_key_policy(ssh.AutoAddPolicy())
+    self.client.connect(host)
 
   @staticmethod
   def connect(host):
-    if not host in SSHConnection.cache:
-      client = ssh.SSHClient()
-      client.set_missing_host_key_policy(ssh.AutoAddPolicy())
-      client.connect(host)
-      SSHConnection.cache[host] = SSHConnection(client)
-      logging.info('Connecting to %s', host)
-
-    return SSHConnection.cache[host]
+    c = SSHConnection(host)
+    SSHConnection.connections.append(c)
+    return c
 
   def invoke(self, command, *args):
     logging.info('Invoking %s %s', command, args)
 
   @staticmethod
   def shutdown():
-    for host, connection in SSHConnection.cache.items():
-      logging.info('Closing SSH connection to %s', host)
+    for connection in SSHConnection.connections:
+      logging.info('Closing SSH connection to %s', connection.host)
       connection.client.close()
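
The per-host connection cache becomes a flat list of live connections, so every
connect() now opens a fresh SSH session and shutdown() closes them all. A
minimal sketch of the new pattern (hostname illustrative; invoke()'s return
value is assumed from its use in cluster.py above):

  import sys
  from mycloud.connections import SSHConnection

  conn = SSHConnection.connect('worker1')   # appended to SSHConnection.connections
  stdin, stdout, stderr = conn.invoke(sys.executable, '-m', 'mycloud.worker')
  # ... talk to the remote worker ...
  SSHConnection.shutdown()                  # closes every tracked connection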
 
 

File src/mycloud/mapreduce.py

 #!/usr/bin/env python
 
-from eventlet import sleep
-from mycloud.merge import Merger
-from os.path import join
-import SimpleXMLRPCServer
 import blocked_table
 import collections
 import json
 import logging
-import socket
-import sys
+import mycloud.merge
+import mycloud.thread
+import mycloud.util
 import tempfile
-import time
 import types
 import xmlrpclib
 
-REDUCER_PORT_BASE = 40000
-
 def shard_for_key(k, num_shards):
   return hash(k) % num_shards
 
 def sum_reducer(k, values):
   yield k, sum(values)
 
-class XMLServer(SimpleXMLRPCServer.SimpleXMLRPCServer):
-  def __init__(self, *args, **kw):
-    SimpleXMLRPCServer.SimpleXMLRPCServer.__init__(self, *args, **kw)
-
-  def shutdown(self):
-    self.__shutdown_request = True
-
-  def server_bind(self):
-    logging.info('Binding to address %s', self.server_address)
-    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    SimpleXMLRPCServer.SimpleXMLRPCServer.server_bind(self)
-
 class MRHelper(object):
   def __init__(self,
                mapper,
                num_mappers,
                num_reducers,
                map_buffer_size=1000,
-               reduce_buffer_size=10e6):
+               reduce_buffer_size=100e6):
     self.mapper = mapper
     self.reducer = reducer
     self.tmp_prefix = tmp_prefix
     for shard in range(self.num_reducers):
       shard_output = self.output_tmp[shard]
       logging.info('Writing to reducer')
-      self.reduce_proxy[shard].write_map_output(self.index,
-                                                json.dumps(shard_output),
-                                                final)
+      self.reducers[shard].invoke('write_map_output',
+                                  self.index,
+                                  json.dumps(shard_output),
+                                  final)
 
     self.output_tmp.clear()
     self.buffer_size = 0
     logging.info('Flush finished.')
 
   def run(self):
-    self.reduce_proxy = [
-      xmlrpclib.ServerProxy('http://%s:%d' % (self.reducers[i], REDUCER_PORT_BASE + i))
-      for i in range(len(self.reducers))]
-
-
     logging.info('Reading from: %s', self.input)
     if isinstance(self.mapper, types.ClassType):
       mapper = self.mapper(self.mrinfo, self.index, self.input)
       mapper = self.mapper
 
     for k, v in self.input.reader():
+      logging.info('Reading %s', k)
       for mk, mv in mapper(k, v):
+        logging.info('Writing %s', k)
         self.output(mk, mv)
     self.flush(final=True)
 
 
     self.index = index
     self.buffer_size = 0
+    self.buffer = []
     self.map_tmp = []
     self.maps_finished = [0] * self.num_mappers
     self.output = output
 
+    self.serving = False
+    self.thread = None
+
   def write_map_output(self, mapper, block, is_finished):
     logging.info('Reading from mapper %d %d', mapper, is_finished)
     if is_finished:
 
     self.buffer_size += len(block)
     for k, v in json.loads(block):
-      self.map_tmp.append((k, v))
+      self.buffer.append((k, v))
 
     if self.buffer_size > self.reduce_buffer_size:
-      self.flush_map()
+      self.flush()
 
-    if sum(self.maps_finished) == self.num_mappers:
-      logging.info('Shutting down reducer.')
-      self.server.shutdown()
-
-  def flush_map(self):
+  def flush(self):
     tf = tempfile.NamedTemporaryFile(suffix='reducer-tmp')
-    bt = blocked_table.BlockedTableBuilder(tf.name)
-    self.map_tmp.sort()
-    for k, v in self.map_tmp:
+    bt = blocked_table.TableBuilder(tf.name)
+    self.buffer.sort()
+    for k, v in self.buffer:
       bt.add(k, v)
     del bt
 
     self.map_tmp.append(tf)
 
+  def start_server(self):
+    self.thread = mycloud.thread.spawn(self._run)
+    mycloud.thread.sleep(0)
+    logging.info('Returning proxy to self')
+    return mycloud.util.Proxy(self)
 
-  def run(self):
+  def _run(self):
     # Read map outputs until all mappers have finished executing.
-    self.server = XMLServer(('0.0.0.0', REDUCER_PORT_BASE + self.index),
-                            allow_none=True)
-    self.server.register_function(self.write_map_output)
-    self.server.serve_forever()
+    while sum(self.maps_finished) != self.num_mappers:
+      mycloud.thread.sleep(1)
+    self.flush()
 
-    inputs = [blocked_table.BlockedTable(tf.name) for tf in self.map_tmp]
+    logging.info('Finished reading map data, beginning merge.')
+
+    inputs = [blocked_table.Table(tf.name).iteritems() for tf in self.map_tmp]
     out = self.output.writer()
 
     if isinstance(self.reducer, types.ClassType):
       reducer = self.reducer
 
     logging.info('Reducing over %s temporary map inputs.', len(inputs))
-    for k, v in Merger(inputs):
+    for k, v in mycloud.merge.Merger(inputs):
 #      logging.info('REDUCE: %s %s', k, v)
       for rk, rv in reducer(k, v):
         out.add(rk, rv)
 
-    logging.info('Returning map output: %s', self.output)
+    logging.info('Returning output: %s', self.output)
+
+  def wait(self):
+    self.thread.wait()
     return self.output
 
 
     logging.info('Inputs: %s...', self.input[:10])
     logging.info('Outputs: %s...', self.output[:10])
 
-    reducers = [ReduceHelper(index=i,
-                             output=self.output[i],
-                             mapper=self.mapper,
-                             reducer=self.reducer,
-                             num_mappers=len(self.input),
-                             num_reducers=len(self.output),
-                             tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
-                for i in range(len(self.output)) ]
+    try:
+      reducers = [ReduceHelper(index=i,
+                               output=self.output[i],
+                               mapper=self.mapper,
+                               reducer=self.reducer,
+                               num_mappers=len(self.input),
+                               num_reducers=len(self.output),
+                               tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
+                  for i in range(len(self.output)) ]
 
+      reduce_tasks = self.cluster.map(lambda r: r.start_server(), reducers)
 
-    reduce_tasks = self.cluster.run(lambda r: r.run(), reducers)
-    sleep(5)
+      mappers = [MapHelper(index=i,
+                           input=self.input[i],
+                           reducers=reduce_tasks,
+                           mapper=self.mapper,
+                           reducer=self.reducer,
+                           num_mappers=len(self.input),
+                           num_reducers=len(self.output),
+                           tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
+                 for i in range(len(self.input)) ]
 
-    mappers = [MapHelper(index=i,
-                         input=self.input[i],
-                         reducers=[t.host for t in reduce_tasks],
-                         mapper=self.mapper,
-                         reducer=self.reducer,
-                         num_mappers=len(self.input),
-                         num_reducers=len(self.output),
-                         tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
-               for i in range(len(self.input)) ]
+      self.cluster.map(lambda m: m.run(), mappers)
 
-    self.cluster.map(lambda m: m.run(), mappers)
+      return [r.invoke('wait') for r in reduce_tasks]
+    except:
+      logging.info('MapReduce failed.', exc_info=1)
+      raise
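
For reference, the two top-level helpers in this file are pure functions and
can be sanity-checked in isolation (values illustrative):

  from mycloud.mapreduce import shard_for_key, sum_reducer

  print shard_for_key('apple', 4)               # hash('apple') % 4, a shard in [0, 4)
  print list(sum_reducer('apple', [1, 2, 3]))   # -> [('apple', 6)]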

File src/mycloud/thread.py

 #!/usr/bin/env python
 
-import eventlet
-import eventlet.debug
+import threading
+import time
 
-eventlet.debug.hub_listener_stacks(True)
-eventlet.debug.hub_exceptions(True)
-eventlet.debug.hub_blocking_detection(True)
+class HelperThread(threading.Thread):
+  def __init__(self, f, args):
+    self.f = f
+    self.args = args
+    self.result = None
+    threading.Thread.__init__(self)
+
+  def run(self):
+    self.result = self.f(*self.args)
+
+  def wait(self):
+    self.join()
+    return self.result
+
+def init():
+  pass
 
 def spawn(f, *args):
-  return eventlet.spawn(f, *args)
+  t = HelperThread(f, args)
+  t.setDaemon(True)
+  t.start()
+  return t
+
+def sleep(timeout):
+  time.sleep(timeout)
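
The eventlet green-thread layer is replaced with plain threading. A minimal
usage sketch of the new helpers:

  import mycloud.thread

  def add(a, b):
    return a + b

  t = mycloud.thread.spawn(add, 1, 2)   # daemonized HelperThread, already started
  print t.wait()                        # join and return the result -> 3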

File src/mycloud/util.py

 #!/usr/bin/env python
 
+import SimpleXMLRPCServer
+import cPickle
 import logging
+import mycloud.thread
+import socket
+import sys
 import traceback
 import types
-import mycloud.thread
+import xmlrpclib
+from SocketServer import ThreadingMixIn
 
 class StreamLogger(object):
   '''Read lines from a file object in a separate thread.
         break
 
       if not self.buffer:
-        logging.info(self.prefix + ' --- ' + line.strip())
+        print >> sys.stderr, self.prefix + ' --- ' + line.strip()
 
       self.lines.append(line.strip())
 
 
   return arglist
 
+def find_open_port():
+  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+  s.bind(("", 0))
+  s.listen(1)
+  port = s.getsockname()[1]
+  s.close()
+  return port
+
 class RemoteException(object):
   def __init__(self, type, value, tb):
     self.type = type
     self.value = value
     self.tb = traceback.format_exc(tb)
+
+class XMLServer(ThreadingMixIn, SimpleXMLRPCServer.SimpleXMLRPCServer):
+  def __init__(self, *args, **kw):
+    SimpleXMLRPCServer.SimpleXMLRPCServer.__init__(self, *args, **kw)
+
+  def server_bind(self):
+    logging.info('Binding to address %s', self.server_address)
+    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    SimpleXMLRPCServer.SimpleXMLRPCServer.server_bind(self)
+
+# reference to the worker being used
+WORKER = None
+
+class ClientProxy(object):
+  def __init__(self, host, port, objid):
+    self.host = host
+    self.port = port
+    self.objid = objid
+    self.server = None
+
+  def get_server(self):
+    logging.info('Connecting to %s %d', self.host, self.port)
+    if not self.server:
+      self.server = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port),
+                                          allow_none=True)
+    logging.info('Connection established to %s %d', self.host, self.port)
+    return self.server
+
+  def invoke(self, method, *args, **kw):
+    return cPickle.loads(
+             self.get_server().invoke(self.objid, method, *args, **kw).data)
+
+def Proxy(obj):
+  key = WORKER.wrap(obj)
+  logging.info('Wrapped id %s', key)
+  return ClientProxy(WORKER.host, WORKER.port, key)
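
The Proxy()/ClientProxy pair lets a remote caller hold a handle to an object
living inside a worker: wrapping registers the object with the module-level
WORKER, and the handle calls back into that worker over XML-RPC. The sketch
below is only meaningful inside a worker process (mycloud.worker sets
mycloud.util.WORKER), and the Counter class is hypothetical:

  import mycloud.util

  class Counter(object):
    def __init__(self):
      self.n = 0
    def increment(self):
      self.n += 1
      return self.n

  handle = mycloud.util.Proxy(Counter())   # registers the object with WORKER
  print handle.invoke('increment')         # XML-RPC round trip -> 1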

File src/mycloud/worker.py

 #!/usr/bin/env python
 from cloud.serialization import cloudpickle
 import cPickle
+import cStringIO
 import logging
-import os
+import mycloud.thread
+import mycloud.util
+import select
+import socket
 import sys
-import traceback
+import threading
+import xmlrpclib
+
+mycloud.thread.init()
 
 __doc__ = '''Worker for executing cluster tasks.'''
 
                       format='%(asctime)s %(funcName)s %(message)s',
                       level=logging.INFO)
 
-def execute_task():
-  '''Execute a function and it's arguments, as read from stdin.'''
+class Worker(object):
+  def __init__(self, host, port):
+    self.host = host
+    self.port = port
+    self.wrapped_objects = {}
 
-  # capture stdin and stdout in case user code attempts to access them
-  input = sys.stdin
-  output = sys.stdout
+  def execute_task(self, pickled):
+    f, args, kw = cPickle.loads(pickled.data)
+    logging.info('Executing task %s %s %s', f, args, kw)
+    result = f(*args, **kw)
+    dump = cloudpickle.dumps(result)
+#    logging.info('Got result!')
+    return xmlrpclib.Binary(dump)
 
-  sys.stdin = open('/dev/null', 'r')
-  sys.stdout = open('/dev/null', 'w')
+  def hello(self):
+    return 'alive'
 
-  logging.info('Task execution start.')
-  try:
-    logging.info('Loading function and arguments.')
-    f, args = cPickle.load(input)
-    logging.info('Done.  Executing function %s(%s)', f, args)
-    result = f(*args)
-    logging.info('Done. Returning result.')
-    cloudpickle.dump(result, output)
-    logging.info('Done!')
-  except Exception, e:
-    logging.info('Exception! %s', traceback.print_exc())
-    cloudpickle.dump(Exception(traceback.format_exc()), output)
-  logging.info('Task execution finished.')
+  def wrap(self, obj):
+    self.wrapped_objects[id(obj)] = obj
+    return id(obj)
+
+  def invoke(self, objid, method, *args, **kw):
+    #logging.info('Invoking %s %s %s %s',
+    #             self.wrapped_objects[objid], method, args, kw)
+    return xmlrpclib.Binary(
+             cloudpickle.dumps(
+               getattr(self.wrapped_objects[objid], method)(*args, **kw)))
+
+def dump_stderr(src, dst):
+  while 1:
+    data = src.get_value()
+    src.truncate()
+    dst.write(data)
+    mycloud.thread.sleep(1)
 
 
 if __name__ == '__main__':
-  execute_task()
+  # Open a server on an open port, and inform our caller
+  old_stderr = sys.stderr
+  sys.stderr = cStringIO.StringIO()
+
+  stderr_log = threading.Thread(target=dump_stderr, args=(sys.stderr, old_stderr))
+  stderr_log.setDaemon(True)
+  stderr_log.start()
+
+  myport = mycloud.util.find_open_port()
+  xmlserver = mycloud.util.XMLServer(('0.0.0.0', myport), allow_none=True)
+  xmlserver.timeout = 1
+
+  worker = Worker(socket.gethostname(), myport)
+  mycloud.util.WORKER = worker
+
+  print myport
+  sys.stdout.flush()
+
+  xmlserver.register_function(worker.execute_task, 'execute_task')
+  xmlserver.register_function(worker.invoke, 'invoke')
+  xmlserver.register_function(worker.hello, 'hello')
+
+  # handle requests until our stdout is closed - (our controller shutdown or crashed)
+  while 1:
+    try:
+      xmlserver.handle_request()
+    except:
+      logging.info('Error handling request!!!', exc_info=1)
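
From the controller's side, a worker is now just an XML-RPC endpoint whose port
is read from the child's stdout; this mirrors Server/Task in cluster.py above.
Host and port here are illustrative:

  import cPickle
  import xmlrpclib
  from cloud.serialization import cloudpickle

  proxy = xmlrpclib.ServerProxy('http://workerhost:9999')
  assert proxy.hello() == 'alive'
  payload = cloudpickle.dumps((max, (3, 7), {}))   # (function, args, kwargs)
  result = cPickle.loads(proxy.execute_task(xmlrpclib.Binary(payload)).data)
  print result   # -> 7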

File tests/test_mapreduce.py

   def testSimpleMapper(self):
     cluster = mycloud.Cluster([('localhost', 4)])
     input_desc = [mycloud.resource.SequenceFile(range(100)) for i in range(10)]
-    output_desc = [mycloud.resource.MemoryFile() for i in range(1)]
+    output_desc = [mycloud.resource.MemoryFile() for i in range(5)]
 
     def map_identity(k, v):
       yield (k, v)
                                      output_desc)
     result = mr.run()
 
+    logging.info('Result %s %s', result[0], result[0].__class__)
     for k, v in result[0].reader():
       logging.info('Result: %s %s', k, v)