
Russell Power committed ba610e4

Have reducers start a separate server for handling map input data;
return a client proxy object to pass around.

Add test for running with multiple output shards.
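
This commit replaces the global WORKER registry in mycloud.util with a per-reducer XML-RPC server: each reducer wraps itself and returns a small (host, port, objid) handle that other processes use to push map output to it directly. A minimal standalone sketch of the pattern (the class and port here are illustrative, not mycloud's actual API):

    import threading
    import xmlrpclib
    from SimpleXMLRPCServer import SimpleXMLRPCServer

    class Reducer(object):
        def write_map_output(self, mapper, block, is_finished):
            # Receive one block of map output pushed to us over RPC.
            print 'mapper %d sent %d records (done=%s)' % (
                mapper, len(block), is_finished)
            return 0

    server = SimpleXMLRPCServer(('localhost', 9999), allow_none=True)
    server.register_instance(Reducer())
    t = threading.Thread(target=server.serve_forever)
    t.setDaemon(True)
    t.start()

    # Any process that knows (host, port) can now feed the reducer directly.
    handle = xmlrpclib.ServerProxy('http://localhost:9999', allow_none=True)
    handle.write_map_output(0, [['k', 'v']], True)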

Files changed (7)

setup.py

       'eventlet',
       'pycrypto',
       'ssh',
-      'blocked-table',
+      'blocked-table>=1.05',
       'cloud',
     ],
 )

src/mycloud/cluster.py

     self.ready = True
     self.thread = None
 
-    assert self.client().hello() == 'alive'
+#    assert self.client().healthcheck() == 'alive'
 
   def start_task(self, task):
     self.ready = False
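
The hello() probe becomes healthcheck() (registered in worker.py below), and the controller-side startup assert is disabled rather than updated. A hypothetical controller-side probe, assuming an illustrative port 9999 (the real port is whatever the worker prints on stdout at startup):

    import xmlrpclib

    worker = xmlrpclib.ServerProxy('http://localhost:9999', allow_none=True)
    assert worker.healthcheck() == 'alive'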

src/mycloud/mapreduce.py

       mapper = self.mapper
 
     for k, v in self.input.reader():
-      logging.info('Reading %s', k)
+#      logging.info('Reading %s', k)
       for mk, mv in mapper(k, v):
-        logging.info('Writing %s', k)
+#        logging.info('Writing %s', k)
         self.output(mk, mv)
     self.flush(final=True)
 
     self.thread = None
 
   def write_map_output(self, mapper, block, is_finished):
-    logging.info('Reading from mapper %d %d', mapper, is_finished)
+    logging.info('Reading from mapper %d - done? %d', mapper, is_finished)
     if is_finished:
       self.maps_finished[mapper] = 1
 
       self.flush()
 
   def flush(self):
+    logging.info('Flushing...')
+
     tf = tempfile.NamedTemporaryFile(suffix='reducer-tmp')
     bt = blocked_table.TableBuilder(tf.name)
     self.buffer.sort()
 
     self.map_tmp.append(tf)
 
+    logging.info('Flush finished to %s', tf.name)
+
   def start_server(self):
-    self.thread = mycloud.thread.spawn(self._run)
-    mycloud.thread.sleep(0)
+    self.proxy_server = mycloud.util.ProxyServer()
+    self.serving_thread = mycloud.thread.spawn(self.proxy_server.serve_forever)
+
     logging.info('Returning proxy to self')
-    return mycloud.util.Proxy(self)
+    return self.proxy_server.wrap(self)
 
   def _run(self):
-    # Read map outputs until all mappers have finished executing.
     while sum(self.maps_finished) != self.num_mappers:
-      mycloud.thread.sleep(1)
+      logging.info('Waiting for map data %d/%d',
+                   sum(self.maps_finished), self.num_mappers)
+      mycloud.thread.sleep(0.01)
+
     self.flush()
 
     logging.info('Finished reading map data, beginning merge.')
     logging.info('Returning output: %s', self.output)
 
   def wait(self):
-    self.thread.wait()
+    logging.info('Waiting for reducer to finish...')
+    self._run()
     return self.output
 
 
 
       self.cluster.map(lambda m: m.run(), mappers)
 
+      mycloud.thread.sleep(1)
+
       return [r.invoke('wait') for r in reduce_tasks]
     except:
       logging.info('MapReduce failed.', exc_info=1)
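
The flush()/_run() pair above is the spill half of an external sort: each flush writes the sorted in-memory buffer to a fresh blocked_table run, and the reducer merges the runs once every mapper reports done. The same technique sketched with plain pickle files standing in for blocked_table:

    import cPickle
    import heapq
    import tempfile

    def spill(buffer):
        # Sort the in-memory buffer and write it out as one sorted run.
        buffer.sort()
        run = tempfile.TemporaryFile()
        for kv in buffer:
            cPickle.dump(kv, run, 2)
        run.seek(0)
        return run

    def read_run(run):
        while 1:
            try:
                yield cPickle.load(run)
            except EOFError:
                return

    runs = [spill([(3, 'c'), (1, 'a')]), spill([(2, 'b'), (4, 'd')])]
    for k, v in heapq.merge(*[read_run(r) for r in runs]):
        print k, v  # 1 a / 2 b / 3 c / 4 d -- globally sorted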

src/mycloud/thread.py

 def init():
   pass
 
-def spawn(f, *args):
+def spawn(f, *args, **kw):
   t = HelperThread(f, args)
-  t.setDaemon(True)
+  t.setDaemon(kw.get('daemon', True))
   t.start()
   return t
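
The new daemon keyword lets a caller keep a spawned thread from being killed at interpreter exit; the default stays True, preserving the old fire-and-forget behavior. Hypothetical usage:

    import mycloud.thread

    def serve():
        pass  # stand-in for a long-lived loop such as ProxyServer.serve_forever

    # daemon=False keeps the process alive until serve() returns.
    t = mycloud.thread.spawn(serve, daemon=False)
    t.join()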
 

src/mycloud/util.py

 #!/usr/bin/env python
 
-import SimpleXMLRPCServer
+from cloud.serialization import cloudpickle
+from SocketServer import ThreadingMixIn
+from SimpleXMLRPCServer import SimpleXMLRPCServer
 import cPickle
 import logging
 import mycloud.thread
+import socket
 import traceback
 import types
 import xmlrpclib
-from SocketServer import ThreadingMixIn
 
 class StreamLogger(object):
   '''Read lines from a file object in a separate thread.
     self.value = value
     self.tb = traceback.format_exc(tb)
 
-class XMLServer(ThreadingMixIn, SimpleXMLRPCServer.SimpleXMLRPCServer):
+class XMLServer(ThreadingMixIn, SimpleXMLRPCServer):
   def __init__(self, *args, **kw):
-    SimpleXMLRPCServer.SimpleXMLRPCServer.__init__(self, *args, **kw)
+    SimpleXMLRPCServer.__init__(self, *args, **kw)
 
   def server_bind(self):
     logging.info('Binding to address %s', self.server_address)
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    SimpleXMLRPCServer.SimpleXMLRPCServer.server_bind(self)
+    SimpleXMLRPCServer.server_bind(self)
 
-# reference to the worker being used
-WORKER = None
+  def handle_request(self):
+    try:
+      SimpleXMLRPCServer.handle_request(self)
+    except:
+      logging.exception('Failed to handle request.')
 
-class ClientProxy(object):
+
+class ProxyServer(SimpleXMLRPCServer):
+  def __init__(self):
+    self.wrapped_objects = {}
+    SimpleXMLRPCServer.__init__(self, ('0.0.0.0', find_open_port()))
+
+  def _dispatch(self, method, params):
+    return getattr(self, method)(*params)
+
+  def wrap(self, obj):
+    self.wrapped_objects[id(obj)] = obj
+    logging.info('Wrapped id %s', id(obj))
+    return ProxyObject(self.server_address[0],
+                       self.server_address[1],
+                       id(obj))
+
+  def invoke(self, objid, method, *args, **kw):
+    #logging.info('Invoking %s %s %s %s',
+    #             self.wrapped_objects[objid], method, args, kw)
+    return xmlrpclib.Binary(
+             cloudpickle.dumps(
+               getattr(self.wrapped_objects[objid], method)(*args, **kw)))
+
+
+class ProxyObject(object):
   def __init__(self, host, port, objid):
     self.host = host
     self.port = port
     self.objid = objid
     self.server = None
 
   def get_server(self):
-    logging.info('Connecting to %s %d', self.host, self.port)
+#    logging.info('Connecting to %s %d', self.host, self.port)
     if not self.server:
       self.server = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port),
                                           allow_none=True)
-    logging.info('Connection established to %s %d', self.host, self.port)
+#    logging.info('Connection established to %s %d', self.host, self.port)
     return self.server
 
   def invoke(self, method, *args, **kw):
     return cPickle.loads(
              self.get_server().invoke(self.objid, method, *args, **kw).data)
-
-def Proxy(obj):
-  key = WORKER.wrap(obj)
-  logging.info('Wrapped id %s', key)
-  return ClientProxy(WORKER.host, WORKER.port, key)
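
Putting the new pieces together: a ProxyServer exposes any local object over XML-RPC, and the ProxyObject handle it returns is cheap to pickle and ship to another process. A single-process sketch (Adder is illustrative; on one machine the 0.0.0.0 bind address also works as a loopback destination):

    import mycloud.thread
    import mycloud.util

    class Adder(object):
        def add(self, a, b):
            return a + b

    server = mycloud.util.ProxyServer()
    mycloud.thread.spawn(server.serve_forever)

    handle = server.wrap(Adder())     # picklable (host, port, objid) handle
    print handle.invoke('add', 1, 2)  # XML-RPC + pickle round trip; prints 3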

src/mycloud/worker.py

+import select
 import socket
 import sys
 import threading
+import time
 import xmlrpclib
 
 mycloud.thread.init()
 
 __doc__ = '''Worker for executing cluster tasks.'''
 
-logging.basicConfig(stream=sys.stderr,
-                      format='%(asctime)s %(funcName)s %(message)s',
-                      level=logging.INFO)
 
 class Worker(object):
   def __init__(self, host, port):
     self.host = host
     self.port = port
-    self.wrapped_objects = {}
+    self.last_keepalive = time.time()
 
   def execute_task(self, pickled):
     f, args, kw = cPickle.loads(pickled.data)
 #    logging.info('Got result!')
     return xmlrpclib.Binary(dump)
 
-  def hello(self):
+  def healthcheck(self):
+    self.last_keepalive = time.time()
     return 'alive'
 
-  def wrap(self, obj):
-    self.wrapped_objects[id(obj)] = obj
-    return id(obj)
-
-  def invoke(self, objid, method, *args, **kw):
-    #logging.info('Invoking %s %s %s %s',
-    #             self.wrapped_objects[objid], method, args, kw)
-    return xmlrpclib.Binary(
-             cloudpickle.dumps(
-               getattr(self.wrapped_objects[objid], method)(*args, **kw)))
-
 def dump_stderr(src, dst):
   while 1:
     data = src.get_value()
 
 
 if __name__ == '__main__':
+  myport = mycloud.util.find_open_port()
+
+  logging.basicConfig(stream=sys.stderr,
+                      #filename='/tmp/worker.%d.log' % myport,
+                      format='%(asctime)s %(funcName)s %(message)s',
+                      level=logging.INFO)
+
   # Open a server on an open port, and inform our caller
   old_stderr = sys.stderr
   sys.stderr = cStringIO.StringIO()
   stderr_log.setDaemon(True)
   stderr_log.start()
 
-  myport = mycloud.util.find_open_port()
   xmlserver = mycloud.util.XMLServer(('0.0.0.0', myport), allow_none=True)
   xmlserver.timeout = 1
 
   worker = Worker(socket.gethostname(), myport)
-  mycloud.util.WORKER = worker
+
+  xmlserver.register_function(worker.execute_task, 'execute_task')
+  xmlserver.register_function(worker.healthcheck, 'healthcheck')
 
   print myport
   sys.stdout.flush()
 
-  xmlserver.register_function(worker.execute_task, 'execute_task')
-  xmlserver.register_function(worker.invoke, 'invoke')
-  xmlserver.register_function(worker.hello, 'hello')
+  # handle requests until we lose our stdin connection to the controller
+  try:
+    while 1:
+      xmlserver.handle_request()
 
-  # handle requests until our stdout is closed - (our controller shutdown or crashed)
-  while 1:
-    try:
-      xmlserver.handle_request()
-    except:
-      logging.info('Error handling request!!!', exc_info=1)
+      r, w, x = select.select([sys.stdin], [], [sys.stdin], 0)
+      if r or x:
+        break
+  except:
+    logging.info('Error while serving.', exc_info=1)
+
+  logging.info('Shutting down.')
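
The shutdown check relies on select() reporting stdin as readable once the controller exits and the pipe hits EOF; the loop above treats any stdin activity as a shutdown signal. A sketch of a stricter variant that confirms EOF before giving up:

    import select
    import sys

    def controller_alive():
        # Readable-with-no-data means the parent closed its end of the pipe.
        r, w, x = select.select([sys.stdin], [], [sys.stdin], 0)
        if x:
            return False
        if r:
            return sys.stdin.read(1) != ''  # '' signals EOF
        return True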

tests/test_mapreduce.py

 import sys
 import unittest
 
+def map_identity(k, v):
+  yield (k, v)
+
+def reduce_sum(k, values):
+  #logging.info('%s %s', k, values)
+  yield (k, sum(values))
+
 class MapReduceTestCase(unittest.TestCase):
   def testSimpleMapper(self):
     cluster = mycloud.Cluster([('localhost', 4)])
     input_desc = [mycloud.resource.SequenceFile(range(100)) for i in range(10)]
+    output_desc = [mycloud.resource.MemoryFile() for i in range(1)]
+
+    mr = mycloud.mapreduce.MapReduce(cluster,
+                                     map_identity,
+                                     reduce_sum,
+                                     input_desc,
+                                     output_desc)
+    result = mr.run()
+
+    oiter = result[0].reader()
+    for j in range(100):
+      k, v = oiter.next()
+      self.assertEqual(k, j)
+      self.assertEqual(v, j * 10)
+
+  def testShardedOutput(self):
+    cluster = mycloud.Cluster([('localhost', 4)])
+    input_desc = [mycloud.resource.SequenceFile(range(100)) for i in range(10)]
     output_desc = [mycloud.resource.MemoryFile() for i in range(5)]
 
-    def map_identity(k, v):
-      yield (k, v)
-
-    def reduce_sum(k, values):
-      logging.info('%s %s', k, values)
-      yield (k, sum(values))
-
     mr = mycloud.mapreduce.MapReduce(cluster,
                                      map_identity,
                                      reduce_sum,
     result = mr.run()
 
     logging.info('Result %s %s', result[0], result[0].__class__)
-    for k, v in result[0].reader():
-      logging.info('Result: %s %s', k, v)
+    for i in range(5):
+      j = i
+      count = 0
+      for k, v in result[i].reader():
+        self.assertEqual(k, j)
+        self.assertEqual(v, j * 10)
+        j += 5
+        count += 1
 
-    oiter = result[0].reader()
-    for j in range(100):
-      k, v = oiter.next()
-      self.assertEqual(k, j)
-      self.assertEqual(v, j * 10)
+      self.assertEqual(count, 20)
 
 
 if __name__ == "__main__":
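
The expected values in both tests follow from the input shape: every key k in 0..99 appears once in each of the 10 input shards, so reduce_sum sees ten copies and yields k * 10; with 5 output shards partitioned by k mod 5 (which is what the assertions encode), shard i holds the 20 keys i, i+5, ..., i+95. A quick sanity check:

    for i in range(5):
        shard = [k for k in range(100) if k % 5 == i]
        assert len(shard) == 20
        assert shard[:2] == [i, i + 5]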