Commits

Russell Power committed 3316774

Better logging; rather than sucking down stderr and stdout, workers
now send log messages via UDP to the controller and also log to local
temp files.
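
In rough terms, the worker-side wiring looks like the sketch below (the actual setup lives in src/mycloud/worker.py, shown in the diff further down); controller_host and log_prefix are placeholder names here, not part of the API:

import logging
import logging.handlers

def init_worker_logging(controller_host, log_prefix):
  # Keep a complete local log in a per-worker temp file...
  logging.basicConfig(stream=open(log_prefix + '.log', 'w'),
                      format='%(asctime)s %(funcName)s %(message)s',
                      level=logging.INFO)
  # ...and also forward pickled LogRecords to the controller over UDP.
  logging.getLogger().addHandler(
    logging.handlers.DatagramHandler(
      controller_host, logging.handlers.DEFAULT_UDP_LOGGING_PORT))

On the controller side, mycloud.util.LoggingServer (in the diff below) receives each datagram, rebuilds the record with logging.makeLogRecord, and reports any attached exception info back to the Cluster.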

The worker watchdog now runs in a separate thread; this allows workers
to be shut down in the middle of long-running map tasks.
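
The watchdog itself is only a few lines; a minimal sketch, assuming plain threading (the committed version uses mycloud.thread.spawn, see src/mycloud/worker.py below):

import os
import select
import sys
import threading

def watchdog():
  # Poll stdin once a second; when the controller closes the SSH channel,
  # stdin becomes readable/EOF and the worker exits immediately, even if a
  # long-running map task is still executing on the main thread.
  while 1:
    r, _, x = select.select([sys.stdin], [], [sys.stdin], 1)
    if r or x:
      os._exit(1)

t = threading.Thread(target=watchdog)
t.setDaemon(True)
t.start()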


Files changed (8)

setup.py

     author="Russell Power",
     author_email="power@cs.nyu.edu",
     license="BSD",
-    version="0.22",
+    version="0.23",
     url="http://rjpower.org/mycloud",
     package_dir={ '' : 'src' },
     packages=[ 'mycloud' ],

src/mycloud/__init__.py

+#!/usr/bin/env python
+
 import mycloud.cluster
-import mycloud.connections
-import mycloud.mapreduce
-import mycloud.merge
-import mycloud.resource
-import mycloud.util
-
-resource = mycloud.resource
-mapreduce = mycloud.mapreduce
 
 Cluster = mycloud.cluster.Cluster

src/mycloud/cluster.py

 import mycloud.thread
 import mycloud.util
 import random
+import socket
 import sys
 import traceback
 import xmlrpclib
 
 mycloud.thread.init()
 
+class ClusterException(Exception):
+  pass
+
 def arg_name(args):
   '''Returns a short string representation of an argument list for a task.'''
   a = args[0]
     return a.__class__.__name__
   return repr(a)
 
-
 class Task(object):
   '''A piece of work to be executed.'''
-  def __init__(self, idx, function, args, kw):
-    self.idx = idx
-    self.pickle = cloudpickle.dumps((function, args, kw))
+  def __init__(self, name, index, function, args, kw):
+    self.idx = index
+    self.pickle = xmlrpclib.Binary(cloudpickle.dumps((function, args, kw)))
     self.result = None
     self.done = False
 
   def run(self, client):
     logging.info('Starting task %s on %s', self.idx, client)
     self.client = client
-    result_data = self.client.execute_task(xmlrpclib.Binary(self.pickle))
+    result_data = self.client.execute_task(self.pickle)
     self.result = cPickle.loads(result_data.data)
     self.done = True
     logging.info('Task %d finished', self.idx)
   
 A Server is created for each core on a machine, and executes tasks as
 machine resources become available.'''
-  def __init__(self, cluster, host):
+  def __init__(self, cluster, host, index):
     self.cluster = cluster
     self.host = host
-    ssh = mycloud.connections.SSHConnection.connect(host)
+    self.index = index
+
+  def connect(self):
+    ssh = mycloud.connections.SSH.connect(self.host)
     self.stdin, self.stdout, self.stderr = ssh.invoke(
-      sys.executable, '-m', 'mycloud.worker')
+      sys.executable,
+      '-m', 'mycloud.worker',
+      '--index %s' % self.index,
+      '--logger_host %s' % socket.gethostname(),
+      '--logger_port %s' % logging.handlers.DEFAULT_UDP_LOGGING_PORT)
 
     self.port = int(self.stdout.readline().strip())
 
-    self.stderr_logger = (
-      mycloud.util.StreamLogger(
-         'Remote(%s, %d)' % (self.host, self.port), buffer=False))
-    self.stderr_logger.start(self.stderr)
-
     self.ready = True
     self.thread = None
 
-#    assert self.client().healthcheck() == 'alive'
-
   def start_task(self, task):
     self.ready = False
     self.thread = mycloud.thread.spawn(self._run_task, task)
     try:
       task.run(self.client())
     except:
-#      logging.info('Exception!', exc_info=1)
+      #logging.exception('Failed to run task')
       self.cluster.report_exception(sys.exc_info())
+      self.stdin.close()
+      self.stdout.close()
     finally:
       self.ready = True
 
 
 class Cluster(object):
-  def __init__(self, machines=None, fs_prefix='/gfs'):
+  def __init__(self, machines=None, tmp_prefix=None):
     self.machines = machines
-    self.fs_prefix = fs_prefix
+    self.tmp_prefix = tmp_prefix
     self.servers = None
     self.exceptions = []
 
+    assert self.machines
+    assert self.tmp_prefix
+
     self.start()
 
   def __del__(self):
   def report_exception(self, exc):
     self.exceptions.append(exc)
 
+  def log_exceptions(self):
+    for e in self.exceptions:
+      exc_dump = ['remote exception: ' + line
+               for line in traceback.format_exception(*e)]
+      logging.info('\n'.join(exc_dump))
+
   def start(self):
+    self.log_server = mycloud.util.LoggingServer(self)
+    mycloud.thread.spawn(self.log_server.serve_forever)
+
     servers = []
+    index = 0
     for host, cores in self.machines:
       for i in xrange(cores):
-        servers.append(
-#                       Server(self, host))
-          mycloud.thread.spawn(lambda c, h: Server(c, h), self, host))
+        s = Server(self, host, index)
+        servers.append(s)
+        index += 1
 
-    servers = [s.wait() for s in servers]
+    connections = [mycloud.thread.spawn(s.connect) for s in servers]
+    [c.wait() for c in connections]
+
     self.servers = servers
+    random.shuffle(self.servers)
     logging.info('Started %d servers...', len(servers))
 
-
   def input(self, type, pattern):
    '''Return a set of cluster inputs for the given pattern.'''
     return mycloud.resource.input(type, pattern)
   def output(self, type, name, shards):
     return mycloud.resource.output(type, name, shards)
 
-  def map(self, f, arglist):
+  def show_status(self):
+    return
+    for (host, port), rlog in self.log_server.message_map.items():
+      print >> sys.stderr, '(%s, %s) -- %s' % (host, port, rlog.msg)
+
+  def map(self, f, arglist, name='worker'):
     assert len(arglist) > 0
     idx = 0
 
 
     task_queue = Queue.Queue()
 
-    tasks = [Task(i, f, args, {})
+    tasks = [Task(name, i, f, args, {})
              for i, args in enumerate(arglist)]
 
     for t in tasks:
           s.start_task(t)
 
         if self.exceptions:
-          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))
+          self.log_exceptions()
+          raise ClusterException
 
         mycloud.thread.sleep(0.1)
+        self.show_status()
+
     except Queue.Empty:
       pass
 
     for t in tasks:
       while not t.done:
         if self.exceptions:
-          raise Exception, '\n'.join(traceback.format_exception(*self.exceptions[0]))
+          self.log_exceptions()
+          raise ClusterException
         mycloud.thread.sleep(0.1)
 
     logging.info('Done.')

src/mycloud/connections.py

 import logging
 import ssh
 import subprocess
+import threading
 
-class SSHConnection(object):
-  connections = []
+class SSH(object):
+  connections = {}
+  connections_lock = threading.Lock()
 
   def __init__(self, host):
     self.host = host
+    self.lock = threading.Lock()
     self.client = ssh.SSHClient()
     self.client.set_missing_host_key_policy(ssh.AutoAddPolicy())
-    self.client.connect(host)
+    self._connected = False
+
+  def _connect(self):
+    self.client.connect(self.host)
+    self._connected = True
+
+  def close(self):
+    self.client.close()
 
   @staticmethod
   def connect(host):
-    c = SSHConnection(host)
-    SSHConnection.connections.append(c)
-    return c
+    with SSH.connections_lock:
+      if not host in SSH.connections:
+        SSH.connections[host] = SSH(host)
+
+    return SSH.connections[host]
 
   def invoke(self, command, *args):
+    with self.lock:
+      if not self._connected:
+        self._connect()
+
     logging.info('Invoking %s %s', command, args)
     chan = self.client._transport.open_session()
     stdin = chan.makefile('wb', 64)
 
   @staticmethod
   def shutdown():
-    for connection in SSHConnection.connections:
-      logging.info('Closing SSH connection to %s', connection.host)
-      connection.client.close()
+    logging.info('Closing all SSH connections')
+    for connection in SSH.connections.values():
+      connection.close()
 
 
-class LocalConnection(object):
+class Local(object):
   @staticmethod
   def connect(host):
-    return LocalConnection()
+    return Local()
 
   def invoke(self, command, *args):
     p = subprocess.Popen([command] + list(args),
     return (p.stdin, p.stdout, p.stderr)
 
 
-atexit.register(SSHConnection.shutdown)
+atexit.register(SSH.shutdown)

src/mycloud/mapreduce.py

 import mycloud.merge
 import mycloud.thread
 import mycloud.util
-import tempfile
 import types
 import xmlrpclib
 
   return r
 
 
-def identity_mapper(k, v):
-  yield k, v
+def identity_mapper(k, v, output):
+  output(k, v)
 
-def identity_reducer(k, values):
+def identity_reducer(k, values, output):
   for v in values:
-    yield k, v
+    output(k, v)
 
-def sum_reducer(k, values):
-  yield k, sum(values)
+def sum_reducer(k, values, output):
+  output(k, sum(values))
 
 class MRHelper(object):
   def __init__(self,
                tmp_prefix,
                num_mappers,
                num_reducers,
-               map_buffer_size=1000,
+               map_buffer_size=100,
                reduce_buffer_size=100e6):
     self.mapper = mapper
     self.reducer = reducer
       self.flush()
 
   def flush(self, final=False):
-    logging.info('Flushing map %d', self.index)
     for shard in range(self.num_reducers):
       shard_output = self.output_tmp[shard]
-      logging.info('Writing to reducer')
+      if not final and not shard_output:
+        continue
+
       self.reducers[shard].invoke('write_map_output',
                                   self.index,
                                   json.dumps(shard_output),
 
     self.output_tmp.clear()
     self.buffer_size = 0
-    logging.info('Flush finished.')
+    logging.info('Flushed map %d', self.index)
 
   def run(self):
     logging.info('Reading from: %s', self.input)
     else:
       mapper = self.mapper
 
-    for k, v in self.input.reader():
-#      logging.info('Reading %s', k)
-      for mk, mv in mapper(k, v):
-#        logging.info('Writing %s', k)
-        self.output(mk, mv)
+    reader = self.input.reader()
+    for k, v in reader:
+      #logging.info('Read %s', k)
+      mapper(k, v, self.output)
+      #logging.info('Mapped %s', k)
     self.flush(final=True)
+    logging.info('Map of %s finished.', self.input)
 
 
 class ReduceHelper(MRHelper):
     self.thread = None
 
   def write_map_output(self, mapper, block, is_finished):
-    logging.info('Reading from mapper %d - done? %d', mapper, is_finished)
     if is_finished:
       self.maps_finished[mapper] = 1
 
       self.flush()
 
   def flush(self):
-    logging.info('Flushing...')
+    logging.info('Reducer flushing - %s', self.buffer_size)
 
-    tf = tempfile.NamedTemporaryFile(suffix='reducer-tmp')
+    tf = mycloud.util.create_tempfile(dir=self.tmp_prefix,
+                                      suffix='reducer-tmp')
     bt = blocked_table.TableBuilder(tf.name)
     self.buffer.sort()
     for k, v in self.buffer:
     del bt
 
     self.map_tmp.append(tf)
-
+    self.buffer_size = 0
     logging.info('Flush finished to %s', tf.name)
 
   def start_server(self):
+    logging.info('Starting server...')
     self.proxy_server = mycloud.util.ProxyServer()
     self.serving_thread = mycloud.thread.spawn(self.proxy_server.serve_forever)
 
     logging.info('Reducing over %s temporary map inputs.', len(inputs))
     for k, v in mycloud.merge.Merger(inputs):
 #      logging.info('REDUCE: %s %s', k, v)
-      for rk, rv in reducer(k, v):
-        out.add(rk, rv)
+      reducer(k, v, out.add)
 
     logging.info('Returning output: %s', self.output)
 
                                reducer=self.reducer,
                                num_mappers=len(self.input),
                                num_reducers=len(self.output),
-                               tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
+                               tmp_prefix=self.cluster.tmp_prefix)
                   for i in range(len(self.output)) ]
 
       reduce_tasks = self.cluster.map(lambda r: r.start_server(), reducers)
                            reducer=self.reducer,
                            num_mappers=len(self.input),
                            num_reducers=len(self.output),
-                           tmp_prefix=self.cluster.fs_prefix + '/tmp/mr')
+                           tmp_prefix=self.cluster.tmp_prefix)
                  for i in range(len(self.input)) ]
 
       self.cluster.map(lambda m: m.run(), mappers)

src/mycloud/util.py

 #!/usr/bin/env python
 
+from SimpleXMLRPCServer import SimpleXMLRPCServer
+from SocketServer import UDPServer
 from cloud.serialization import cloudpickle
-from SocketServer import ThreadingMixIn
-from SimpleXMLRPCServer import SimpleXMLRPCServer
 import cPickle
 import logging
-import mycloud.thread
+import os
 import socket
-import sys
+import struct
+import tempfile
+import time
 import traceback
 import types
 import xmlrpclib
 
-class StreamLogger(object):
-  '''Read lines from a file object in a separate thread.
-  
-  These are then logged on the local host with a given prefix.'''
-  def __init__(self, prefix, buffer=True):
-    self.prefix = prefix
-    self.buffer = buffer
-    self.lines = []
+def create_tempfile(dir, suffix):
+  os.system("mkdir -p '%s'" % dir)
+  return tempfile.NamedTemporaryFile(dir=dir, suffix=suffix)
 
-  def start(self, stream):
-    self.thread = mycloud.thread.spawn(self.run, stream)
+class LoggingServer(UDPServer):
+  log_output = ""
 
-  def run(self, stream):
-    while 1:
-      line = stream.readline()
-      if not line:
-        break
+  def __init__(self, cluster):
+    host = '0.0.0.0'
+    port = logging.handlers.DEFAULT_UDP_LOGGING_PORT
 
-      if not self.buffer:
-        print >> sys.stderr, self.prefix + ' --- ' + line.strip()
+    UDPServer.__init__(self, (host, port), None)
+    self.timeout = 0.1
+    self.cluster = cluster
 
-      self.lines.append(line.strip())
+    # for each distinct host, keep track of the last message sent
+    self.message_map = {}
 
-  def dump(self):
-    return (self.prefix + ' --- ' +
-            ('\n' + self.prefix + ' --- ').join(self.lines))
+  def server_bind(self):
+    logging.info('LoggingServer binding to address %s', self.server_address)
+    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+    UDPServer.server_bind(self)
 
-  def join(self):
-    self.thread.wait()
+  def finish_request(self, request, client_address):
+    packet, socket = request
+
+    rlen = struct.unpack('>L', packet[:4])[0]
+
+    if len(packet) != rlen + 4:
+      logging.error('Received invalid logging packet. %s %s',
+                    len(packet), rlen)
+
+    record = logging.makeLogRecord(cPickle.loads(packet[4:]))
+    srchost = client_address[0]
+
+    self.message_map[client_address] = record
+
+    if record.exc_info:
+      self.cluster.report_exception(record.exc_info)
+#      logging.info('Exception from %s.', srchost)
+    else:
+      record.msg = 'Remote(%s) -- ' % srchost + record.msg
+#      logging.getLogger().handle(record)
 
 
 def to_tuple(arglist):
     self.value = value
     self.tb = traceback.format_exc(tb)
 
-class XMLServer(ThreadingMixIn, SimpleXMLRPCServer):
+class XMLServer(SimpleXMLRPCServer):
   def __init__(self, *args, **kw):
     SimpleXMLRPCServer.__init__(self, *args, **kw)
 
+  def _dispatch(self, method, params):
+    try:
+      return getattr(self, method)(*params)
+    except:
+      logging.exception('Error during dispatch!')
+      return xmlrpclib.Fault('Error while invoking method.',
+                             '\n'.join(traceback.format_exc()))
+
   def server_bind(self):
     logging.info('Binding to address %s', self.server_address)
     self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
     SimpleXMLRPCServer.server_bind(self)
 
-  def handle_request(self):
-    try:
-      SimpleXMLRPCServer.handle_request(self)
-    except:
-      logging.exception('Failed to handle request.')
 
-
-class ProxyServer(SimpleXMLRPCServer):
+class ProxyServer(XMLServer):
   def __init__(self):
     self.wrapped_objects = {}
-    SimpleXMLRPCServer.__init__(self, ('0.0.0.0', find_open_port()))
-
-  def _dispatch(self, method, params):
-    return getattr(self, method)(*params)
+    XMLServer.__init__(self, ('0.0.0.0', find_open_port()))
 
   def wrap(self, obj):
     self.wrapped_objects[id(obj)] = obj
     logging.info('Wrapped id %s', id(obj))
-    return ProxyObject(self.server_address[0],
-                       self.server_address[1],
-                       id(obj))
+    return ProxyObject(socket.gethostname(), self.server_address[1], id(obj))
 
   def invoke(self, objid, method, *args, **kw):
-    #logging.info('Invoking %s %s %s %s',
-    #             self.wrapped_objects[objid], method, args, kw)
-    return xmlrpclib.Binary(
-             cloudpickle.dumps(
-               getattr(self.wrapped_objects[objid], method)(*args, **kw)))
+    try:
+      logging.debug('Invoking object method...')
+      result = getattr(self.wrapped_objects[objid], method)(*args, **kw)
+      logging.debug('Success.')
+      return xmlrpclib.Binary(cloudpickle.dumps(result))
+    except:
+      logging.exception('Error during invocation!')
+      return xmlrpclib.Fault('Error while invoking method.',
+                             '\n'.join(traceback.format_exc()))
 
 
 class ProxyObject(object):
     self.server = None
 
   def get_server(self):
-#    logging.info('Connecting to %s %d', self.host, self.port)
-    if not self.server:
-      self.server = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port),
-                                          allow_none=True)
-#    logging.info('Connection established to %s %d', self.host, self.port)
+    if self.server is None:
+#      logging.info('Connecting to %s %d', self.host, self.port)
+      self.server = xmlrpclib.ServerProxy('http://%s:%d' % (self.host, self.port))
+#      logging.info('Connection established to %s %d', self.host, self.port)
     return self.server
 
   def invoke(self, method, *args, **kw):
-    return cPickle.loads(
-             self.get_server().invoke(self.objid, method, *args, **kw).data)
+    for i in range(10):
+      try:
+        result = self.get_server().invoke(self.objid, method, *args, **kw)
+        return cPickle.loads(result.data)
+      except:
+        logging.exception('Failed to invoke remote method %s; trying again.' % method)
+        time.sleep(5)
+    raise Exception('Failed to invoke remote method %s on %s' % (method, self.host))

src/mycloud/worker.py

 #!/usr/bin/env python
+
 from cloud.serialization import cloudpickle
+from mycloud.util import XMLServer
+import argparse
 import cPickle
-import cStringIO
 import logging
 import mycloud.thread
 import mycloud.util
+import os
 import select
 import socket
 import sys
-import threading
 import time
 import xmlrpclib
 
 
 __doc__ = '''Worker for executing cluster tasks.'''
 
+def watchdog(worker):
+  while 1:
+    r, w, x = select.select([sys.stdin], [], [sys.stdin], 1)
+    if r or x:
+      logging.info('Lost controller.  Exiting.')
+      os._exit(1)
 
-class Worker(object):
-  def __init__(self, host, port):
-    self.host = host
-    self.port = port
+class Worker(XMLServer):
+  def __init__(self, *args, **kw):
+    XMLServer.__init__(self, *args, **kw)
+
+    self.host = socket.gethostname()
+    self.port = self.server_address[1]
     self.last_keepalive = time.time()
 
+    logging.info('Worker starting on %s:%s', self.host, self.port)
+
   def execute_task(self, pickled):
-    f, args, kw = cPickle.loads(pickled.data)
-    logging.info('Executing task %s %s %s', f, args, kw)
-    result = f(*args, **kw)
-    dump = cloudpickle.dumps(result)
-#    logging.info('Got result!')
-    return xmlrpclib.Binary(dump)
+    try:
+      f, args, kw = cPickle.loads(pickled.data)
+      logging.info('Executing task %s %s %s', f, args, kw)
+      result = f(*args, **kw)
+      dump = cloudpickle.dumps(result)
+      logging.info('Got result!')
+      return xmlrpclib.Binary(dump)
+    except:
+      logging.info('Failed to execute task.', exc_info=1)
+      raise
 
   def healthcheck(self):
     self.last_keepalive = time.time()
     return 'alive'
 
-def dump_stderr(src, dst):
-  while 1:
-    data = src.get_value()
-    src.truncate()
-    dst.write(data)
-    mycloud.thread.sleep(1)
+if __name__ == '__main__':
+  p = argparse.ArgumentParser()
+  p.add_argument('--index', type=int)
+  p.add_argument('--logger_host', type=str)
+  p.add_argument('--logger_port', type=int)
+  p.add_argument('--worker_name', type=str, default='worker')
 
+  opts = p.parse_args()
 
-if __name__ == '__main__':
+  index = opts.index
   myport = mycloud.util.find_open_port()
 
-  logging.basicConfig(stream=sys.stderr,
-                      #filename='/tmp/worker.%d.log' % myport,
+  log_prefix = '/tmp/%s-worker-%03d' % (socket.gethostname(), index)
+
+  logging.basicConfig(stream=open(log_prefix + '.log', 'w'),
                       format='%(asctime)s %(funcName)s %(message)s',
                       level=logging.INFO)
 
-  # Open a server on an open port, and inform our caller
-  old_stderr = sys.stderr
-  sys.stderr = cStringIO.StringIO()
+  if opts.logger_host:
+    logging.info('Additionally logging to %s:%s',
+                 opts.logger_host, opts.logger_port)
 
-  stderr_log = threading.Thread(target=dump_stderr, args=(sys.stderr, old_stderr))
-  stderr_log.setDaemon(True)
-  stderr_log.start()
+    logging.getLogger().addHandler(
+      logging.handlers.DatagramHandler(opts.logger_host, opts.logger_port))
 
-  xmlserver = mycloud.util.XMLServer(('0.0.0.0', myport), allow_none=True)
-  xmlserver.timeout = 1
-
-  worker = Worker(socket.gethostname(), myport)
-
-  xmlserver.register_function(worker.execute_task, 'execute_task')
-  xmlserver.register_function(worker.healthcheck, 'healthcheck')
+  worker = Worker(('0.0.0.0', myport))
+  worker.timeout = 1
 
   print myport
   sys.stdout.flush()
 
+  # redirect stdout and stderr to local files to avoid pipe/buffering issues
+  # with the controller
+  sys.stdout = open(log_prefix + '.out', 'w')
+  sys.stderr = open(log_prefix + '.err', 'w')
+
+  mycloud.thread.spawn(watchdog, worker)
+
  # handle requests until we lose our stdin connection to the controller
   try:
     while 1:
-      xmlserver.handle_request()
-
-      r, w, x = select.select([sys.stdin], [], [sys.stdin], 0)
-      if r or x:
-        break
+      worker.handle_request()
   except:
     logging.info('Error while serving.', exc_info=1)
 
+
   logging.info('Shutting down.')

tests/test_mapreduce.py

 import sys
 import unittest
 
-def map_identity(k, v):
-  yield (k, v)
-
-def reduce_sum(k, values):
-  #logging.info('%s %s', k, values)
-  yield (k, sum(values))
-
 class MapReduceTestCase(unittest.TestCase):
   def testSimpleMapper(self):
-    cluster = mycloud.Cluster([('localhost', 4)])
+    cluster = mycloud.Cluster([('localhost', 4)], tmp_prefix='/tmp')
     input_desc = [mycloud.resource.SequenceFile(range(100)) for i in range(10)]
     output_desc = [mycloud.resource.MemoryFile() for i in range(1)]
 
     mr = mycloud.mapreduce.MapReduce(cluster,
-                                     map_identity,
-                                     reduce_sum,
+                                     mycloud.mapreduce.identity_mapper,
+                                     mycloud.mapreduce.sum_reducer,
                                      input_desc,
                                      output_desc)
     result = mr.run()
       self.assertEqual(v, j * 10)
 
   def testShardedOutput(self):
-    cluster = mycloud.Cluster([('localhost', 4)])
+    cluster = mycloud.Cluster([('localhost', 4)], tmp_prefix='/tmp')
     input_desc = [mycloud.resource.SequenceFile(range(100)) for i in range(10)]
     output_desc = [mycloud.resource.MemoryFile() for i in range(5)]
 
     mr = mycloud.mapreduce.MapReduce(cluster,
-                                     map_identity,
-                                     reduce_sum,
+                                     mycloud.mapreduce.identity_mapper,
+                                     mycloud.mapreduce.sum_reducer,
                                      input_desc,
                                      output_desc)
     result = mr.run()