Commits

keba  committed 54b9988

add Prim algorithm

  • Participants
  • Parent commits 83ce25b

Comments (0)

Files changed (4)

File pyreia/algorithms/__init__.py

 
 ## imports for easier importing
 from pyreia.algorithms.astar_al import astar
+from pyreia.algorithms.prim_al import prim

File pyreia/algorithms/prim_al.py

+"""
+    prim_al
+    =======
+
+    Prim algorithm to find the minumum spanning tree.
+
+    :copyright: 2010 by the pyreia Team, see AUTHORS for more details.
+    :license: GNU GPL, see LICENSE for more details.
+"""
+from pyreia.core import Heap
+from pyreia.algorithms.utils import PathAlgorithmNode
+from pyreia.nodes import Node
+
+class PNode(PathAlgorithmNode):
+
+    def __init__(self, node):
+        PathAlgorithmNode.__init__(self, node)
+
+        self.cost = float('inf')
+        self.parent = None
+
+    def __cmp__(self, other):
+        return cmp(self.cost, other.cost)
+
+    def _create_graph(self, in_place=False, reverse=False):
+        """
+        Create Nodes from the PNodes structure.
+
+        See func:~`pyreia.algorithms.prim_al.prim` for a more useful docstring.
+        """
+        node_map = {}
+        if in_place:
+            self.node.clean_tree()
+            for node in self.tree:
+                node_map[node] = node.node
+        else:
+            for node in self.tree:
+                node_map[node] = Node()
+
+        for pnode, node in node_map.iteritems():
+            if pnode.parent:
+                node_map[pnode.parent].add_child(node, pnode.cost, reverse=reverse)
+
+        return node_map[self].tree
+
+    def prim(self, *args, **kwargs):
+        self.cost = 0
+        heap = Heap(list(self.tree))
+
+        while heap:
+            node = heap.heappop()
+            for child, cost in node.children.iteritems():
+                if child in heap and cost < child.cost:
+                    child.parent = node
+                    child.cost = cost
+
+        ## all trees are equal, it does not matter which node is chosen
+        return self._create_graph(*args, **kwargs)
+
+
+def prim(node, *args, **kwargs):
+    """
+    Prim Algorithms to find a MST.
+
+    :param node: Any Node from the graph
+    :type node: :class:`~pyreia.nodes.Node`
+    :param in_place: if True, `Node` and belonging Nodes are changed (ie.
+                     children removed, defaults to False.
+    :type in_place: boolean
+    :param reverse: if True, a Node's parent is reachable by that Node, should
+                    only be used for connected weighted undirected graphs and
+                    defaults to False.
+    :type reverse: boolean
+    :return: MST
+    :rtype: :class:`~pyreia.nodes.Graph`
+    """
+    pnode = PNode(node)
+    pnode.copy_graph()
+    return pnode.prim(*args, **kwargs)

File tests/test_al_utils.py

 """
-    test_astar
-    ==========
+    test_al_utils
+    =============
 
     testing algorithm utils
 

File tests/test_prim.py

+"""
+    test_prim
+    =========
+
+    test algorithm of Prim
+
+    :copyright: 2010 by the pyreia Team, see AUTHORS for more details.
+    :license: GNU GPL, see LICENSE for more details.
+"""
+import unittest
+
+from pyreia.algorithms.prim_al import prim, PNode
+from pyreia.nodes import Node
+
+class TestPrim(unittest.TestCase):
+
+    def setUp(self):
+        nodes = {}
+        for i in range(7):
+            char = chr(i + 97)
+            nodes[char] = Node(char)
+
+        #: TODO: wow, this should be easier :/
+        nodes['a'].add_child(nodes['b'], 7, reverse=True)
+        nodes['a'].add_child(nodes['d'], 5, reverse=True)
+        nodes['b'].add_child(nodes['c'], 8, reverse=True)
+        nodes['b'].add_child(nodes['d'], 9, reverse=True)
+        nodes['b'].add_child(nodes['e'], 7, reverse=True)
+        nodes['c'].add_child(nodes['e'], 5, reverse=True)
+        nodes['d'].add_child(nodes['e'], 15, reverse=True)
+        nodes['d'].add_child(nodes['f'], 6, reverse=True)
+        nodes['e'].add_child(nodes['f'], 8, reverse=True)
+        nodes['e'].add_child(nodes['g'], 9, reverse=True)
+        nodes['f'].add_child(nodes['g'], 11, reverse=True)
+        self.nodes = nodes
+
+    def test_create_graph(self):
+        nodes = self.nodes
+        ## do prim() manually
+        node_map = {}
+        for node in nodes.itervalues():
+            node_map[node] = PNode(node)
+
+        node_map[nodes['b']].parent = nodes['a']
+        node_map[nodes['b']].cost = 7
+
+        node_map[nodes['d']].parent = nodes['a']
+        node_map[nodes['d']].cost = 5
+
+        node_map[nodes['e']].parent = nodes['b']
+        node_map[nodes['e']].cost = 7
+
+        node_map[nodes['c']].parent = nodes['e']
+        node_map[nodes['c']].cost = 5
+
+        node_map[nodes['f']].parent = nodes['d']
+        node_map[nodes['f']].cost = 6
+
+        node_map[nodes['g']].parent = nodes['e']
+        node_map[nodes['g']].cost = 9
+
+        pnode = node_map[nodes['a']]
+        graph = pnode._create_graph()
+
+    def test_prim(self):
+        nodes = self.nodes
+        mst = prim(nodes['d'], in_place=True)
+        self.assertFalse(nodes['d'] in nodes['b'].children)
+        self.assertFalse(nodes['c'] in nodes['b'].children)
+        self.assertTrue(nodes['e'] in nodes['b'].children)
+
+
+    def test_equality(self):
+        nodes = self.nodes
+        msts = []
+        for node in nodes.itervalues():
+            msts.append(set([n.name for n in prim(node)]))
+        ## all msts should be equal
+        self.assertTrue(msts[0] == msts[1] == msts[2] == msts[3])
+        self.assertTrue(msts[3] == msts[4] == msts[5] == msts[6])