# Source: sorting_network / main.py (full commit)
def comparator(x, y):
    """Compare two wires of the sorting network.

    Each argument may be either a plain value or a zero-argument callable
    (a "thunk", such as one returned by an earlier comparator) that yields
    the value. Both inputs are resolved immediately.

    Returns:
        (mn, mx): a pair of zero-argument callables returning the smaller
        and the larger of the two resolved values.

    Raises:
        Whatever a thunk argument raises when it is called.
    """
    # Resolve each argument independently. This lets a plain value and a
    # thunk be mixed, and it lets errors raised inside a thunk propagate
    # instead of being silently swallowed.
    a = x() if callable(x) else x
    b = y() if callable(y) else y
    lo, hi = min(a, b), max(a, b)
    return (lambda: lo), (lambda: hi)

def half_cleaner(seq):
    """Apply one half-cleaner stage of a bitonic network.

    With ``n = len(seq) // 2``, element ``i`` is compared with element
    ``i + n`` for every ``i`` in the first half. The smaller value goes to
    position ``i`` and the larger to ``i + n``. Elements may be plain values
    or thunks (see ``comparator``). Every compared position comes back as a
    thunk.

    The input list is not modified; a new list is always returned. For an
    odd length, the last element is not compared and is copied through as-is.
    """
    ans = list(seq)  # copy: never mutate the caller's list
    n = len(ans) // 2
    for i in range(n):
        ans[i], ans[i + n] = comparator(ans[i], ans[i + n])
    return ans

# Demo: a single half-cleaner stage on a 6-element 0/1 input.
# print(...) with one argument behaves the same in Python 2 and 3.
l = half_cleaner([0, 0, 1, 1, 0, 0])
print([x() for x in l])

def bitonic_sorter(seq):
    """Sort a bitonic sequence with a recursive network of half-cleaners.

    Applies ``half_cleaner`` to the whole sequence, then recurses on each
    half. Elements may be plain values or thunks. A new list is returned.

    NOTE(review): a bitonic sorting network is only guaranteed to sort when
    the input is bitonic and ``len(seq)`` is a power of two. Other lengths
    are accepted, but the output is not guaranteed to be sorted.
    """
    # The base case must include the empty list. Otherwise ``[]`` splits
    # into two empty halves forever (unbounded recursion).
    if len(seq) <= 1:
        return list(seq)
    ans = half_cleaner(seq)
    n = len(ans) // 2
    return bitonic_sorter(ans[:n]) + bitonic_sorter(ans[n:])

# Demo: sort a bitonic 0/1 input (length 10, which is not a power of two).
# print(...) with one argument behaves the same in Python 2 and 3.
l = bitonic_sorter([0, 0, 0, 1, 1, 1, 0, 0, 0, 0])
print([x() for x in l])

def merger(seq):
    """Merge the two ascending halves of ``seq`` into one sorted list.

    The first stage compares element ``i`` with its mirror ``2*n - 1 - i``
    (``n = len(seq) // 2``). This turns two ascending halves into a pair of
    bitonic halves, and each half is then finished by ``bitonic_sorter``.
    Elements may be plain values or thunks.

    The input is not modified, and the result is always a list, including
    for 2-element input.
    """
    # Lengths 0 and 1 are already merged. Recursing on them would reach
    # bitonic_sorter([]).
    if len(seq) <= 1:
        return list(seq)
    ans = list(seq)  # copy: never mutate the caller's list
    n = len(ans) // 2
    for i in range(n):
        j = 2 * n - 1 - i  # mirror position in the second half
        ans[i], ans[j] = comparator(ans[i], ans[j])
    return bitonic_sorter(ans[:n]) + bitonic_sorter(ans[n:])

# Demo: merge two ascending halves [1, 2, 3, 4] and [0, 2, 3, 4].
# print(...) with one argument behaves the same in Python 2 and 3.
l = merger([1, 2, 3, 4, 0, 2, 3, 4])
print([x() for x in l])

def sorter(seq):
    """Sort ``seq`` with a bitonic merge-sort network.

    Builds sorted runs of width 2, 4, 8, ... by applying ``merger`` to each
    consecutive chunk of the current width. Elements may be plain values or
    thunks. Elements that took part in a comparison come back as thunks.
    A new list is returned, and the input is not modified.

    NOTE(review): widths only go up to the largest power of two that is
    ``<= len(seq)``. Inputs whose length is not a power of two are therefore
    not guaranteed to come back fully sorted.
    """
    n = len(seq)
    ans = list(seq)
    width = 2
    while width <= n:
        merged = []
        for start in range(0, n, width):
            merged += merger(ans[start:start + width])
        ans = merged
        width *= 2
    return ans

# Demo: full sort of an 8-element input. The result ``l`` is reused below.
# print(...) with one argument behaves the same in Python 2 and 3.
l = sorter([-1, 2, 3, 1, 9, 8, 7, 6])
print([x() for x in l])

from multiprocessing.pool import ThreadPool

# Evaluate the thunks produced by sorter() on a 4-thread pool.
t = ThreadPool(4)
try:
    result = t.map(lambda x: x(), l)
finally:
    # Always shut the pool down and wait for its workers, even if a thunk
    # raises. (Pool is not a context manager in Python 2, so use try/finally.)
    t.close()
    t.join()
print(result)