Commits

Vladislav Bortnikov committed 2eaccb2

More pythonic and faster code

Comments (0)

Files changed (2)

+def comparator(x, y):
+    mn = min(x, y)
+    mx = max(x, y)
+    return mn, mx
+
+def separate(seq):
+    n = len(seq) // 2
+    return seq[:n], seq[n:]
+
+def half_cleaner(seq, merge = True):
+    if len(seq) == 1:
+        return seq
+    left, right = separate(seq)
+    result = [comparator(*x) for x in zip(left, right)]
+    left, right = [list(x) for x in zip(*result)]
+    if merge:
+        return left + right
+    else:
+        return left, right
+
+def bitonic_sorter(seq):
+    if len(seq) <= 1:
+        return seq
+    left, right = half_cleaner(seq, False)
+    return bitonic_sorter(left) + bitonic_sorter(right)
+
+def merger(seq):
+    left, right = separate(seq)
+    zipped = [comparator(*x) for x in zip(left, reversed(right))]
+    left, right = zip(*zipped)
+    return bitonic_sorter(left) + bitonic_sorter(right)
+
+def sorter(seq):
+    if len(seq) == 1:
+        return seq
+    if len(seq) == 2:
+        return list(comparator(seq[0], seq[1]))
+    left, right = separate(seq)
+    return merger(sorter(left) + sorter(right))
+from sorting_net import *
+
+class Test_mergesort(object):
+
+    def test_half_cleaner(self):
+        l = half_cleaner([0, 0, 1, 1, 1, 0, 0, 0])
+        assert l == [0, 0, 0, 0, 1, 0, 1, 1]
+
+    def test_bitonic_sorter(self):
+        l = bitonic_sorter([0, 0, 1, 1, 1, 0, 0, 0])
+        assert l == [0, 0, 0, 0, 0, 1, 1, 1]
+
+    def test_merger(self):
+        l = merger([0, 0, 1, 1, 0, 1, 1, 1])
+        assert l == [0, 0, 0, 1, 1, 1, 1, 1]
+
+    def test_comparator1(self):
+        a = -1; b = 1
+        assert comparator(a, b)[0] == -1
+
+    def test_sorter(self):
+        seq = [5, 9, 2, 1, 2, 3, 4, 5, 1, 0, 2, 4, 1, 2, 3, 4]
+        l = sorter(seq)
+        assert l == sorted(seq)
+
+    def test_big_seq(self):
+        import random
+        r = random.Random()
+        seq = []
+        add = seq.append
+        for i in xrange(2 ** 10):
+            add(r.randint(0, 1000))
+        l = sorter(seq)
+        assert l == sorted(seq)