Source

mycloud / tests / test_mapreduce.py

Full commit
#!/usr/bin/env python

import logging
import mycloud
import sys
import unittest

class MapReduceTestCase(unittest.TestCase):
  """End-to-end checks of mycloud.mapreduce.MapReduce.

  Each test feeds ten SequenceFile(range(100)) inputs through
  identity_mapper and sum_reducer, then asserts that every key j read back
  from the outputs carries the value j * 10.
  """

  def _run_identity_sum(self, num_outputs):
    """Run the job shared by all tests and return whatever mr.run() returns.

    Args:
      num_outputs: how many MemoryFile output descriptors to give the job.

    Returns:
      The value of MapReduce.run(); the tests index it per output and call
      .reader() on each element.
    """
    cluster = mycloud.Cluster([('localhost', 4)], tmp_prefix='/tmp')
    input_desc = [mycloud.resource.SequenceFile(range(100)) for _ in range(10)]
    output_desc = [mycloud.resource.MemoryFile() for _ in range(num_outputs)]

    mr = mycloud.mapreduce.MapReduce(cluster,
                                     mycloud.mapreduce.identity_mapper,
                                     mycloud.mapreduce.sum_reducer,
                                     input_desc,
                                     output_desc)
    return mr.run()

  def testSimpleMapper(self):
    """With one output, the first 100 records are (j, j * 10) for j in 0..99."""
    result = self._run_identity_sum(1)

    # Use the next() builtin instead of the Python 2-only .next() method so
    # the test runs on Python 2.6+ and Python 3 alike.  iter() is a no-op on
    # an iterator and lets a plain iterable reader work here too, matching
    # how testShardedOutput consumes readers with a for loop.
    oiter = iter(result[0].reader())
    for j in range(100):
      k, v = next(oiter)
      self.assertEqual(k, j)
      self.assertEqual(v, j * 10)

  def testShardedOutput(self):
    """With five outputs, output i holds keys i, i + 5, ... (20 records each)."""
    num_shards = 5
    result = self._run_identity_sum(num_shards)

    logging.info('Result %s %s', result[0], result[0].__class__)
    for shard in range(num_shards):
      count = 0
      for k, v in result[shard].reader():
        # The n-th record of a shard is expected to carry key shard + n * 5.
        expected_key = shard + count * num_shards
        self.assertEqual(k, expected_key)
        self.assertEqual(v, expected_key * 10)
        count += 1

      self.assertEqual(count, 20)


if __name__ == "__main__":
  # When run as a script: send INFO-level log records to stderr, each tagged
  # with a timestamp and the emitting function's name, then run the tests.
  log_format = '%(asctime)s %(funcName)s %(message)s'
  logging.basicConfig(level=logging.INFO, format=log_format, stream=sys.stderr)
  unittest.main()