Commits

committed 54b9988

• Participants
• Parent commits 83ce25b

File pyreia/algorithms/__init__.py

` `
` ## imports for easier importing`
` from pyreia.algorithms.astar_al import astar`
`+from pyreia.algorithms.prim_al import prim`

File pyreia/algorithms/prim_al.py

`+"""`
`+    prim_al`
`+    =======`
`+`
`+    Prim algorithm to find the minumum spanning tree.`
`+`
`+    :copyright: 2010 by the pyreia Team, see AUTHORS for more details.`
`+    :license: GNU GPL, see LICENSE for more details.`
`+"""`
`+from pyreia.core import Heap`
`+from pyreia.algorithms.utils import PathAlgorithmNode`
`+from pyreia.nodes import Node`
`+`
`+class PNode(PathAlgorithmNode):`
`+`
`+    def __init__(self, node):`
`+        PathAlgorithmNode.__init__(self, node)`
`+`
`+        self.cost = float('inf')`
`+        self.parent = None`
`+`
`+    def __cmp__(self, other):`
`+        return cmp(self.cost, other.cost)`
`+`
`+    def _create_graph(self, in_place=False, reverse=False):`
`+        """`
`+        Create Nodes from the PNodes structure.`
`+`
`+        See func:~`pyreia.algorithms.prim_al.prim` for a more useful docstring.`
`+        """`
`+        node_map = {}`
`+        if in_place:`
`+            self.node.clean_tree()`
`+            for node in self.tree:`
`+                node_map[node] = node.node`
`+        else:`
`+            for node in self.tree:`
`+                node_map[node] = Node()`
`+`
`+        for pnode, node in node_map.iteritems():`
`+            if pnode.parent:`
`+                node_map[pnode.parent].add_child(node, pnode.cost, reverse=reverse)`
`+`
`+        return node_map[self].tree`
`+`
`+    def prim(self, *args, **kwargs):`
`+        self.cost = 0`
`+        heap = Heap(list(self.tree))`
`+`
`+        while heap:`
`+            node = heap.heappop()`
`+            for child, cost in node.children.iteritems():`
`+                if child in heap and cost < child.cost:`
`+                    child.parent = node`
`+                    child.cost = cost`
`+`
`+        ## all trees are equal, it does not matter which node is chosen`
`+        return self._create_graph(*args, **kwargs)`
`+`
`+`
`+def prim(node, *args, **kwargs):`
`+    """`
`+    Prim Algorithms to find a MST.`
`+`
`+    :param node: Any Node from the graph`
`+    :type node: :class:`~pyreia.nodes.Node``
`+    :param in_place: if True, `Node` and belonging Nodes are changed (ie.`
`+                     children removed, defaults to False.`
`+    :type in_place: boolean`
`+    :param reverse: if True, a Node's parent is reachable by that Node, should`
`+                    only be used for connected weighted undirected graphs and`
`+                    defaults to False.`
`+    :type reverse: boolean`
`+    :return: MST`
`+    :rtype: :class:`~pyreia.nodes.Graph``
`+    """`
`+    pnode = PNode(node)`
`+    pnode.copy_graph()`
`+    return pnode.prim(*args, **kwargs)`

File tests/test_al_utils.py

` """`
`-    test_astar`
`-    ==========`
`+    test_al_utils`
`+    =============`
` `
`     testing algorithm utils`
` `

File tests/test_prim.py

`+"""`
`+    test_prim`
`+    =========`
`+`
`+    test algorithm of Prim`
`+`
`+    :copyright: 2010 by the pyreia Team, see AUTHORS for more details.`
`+    :license: GNU GPL, see LICENSE for more details.`
`+"""`
`+import unittest`
`+`
`+from pyreia.algorithms.prim_al import prim, PNode`
`+from pyreia.nodes import Node`
`+`
`+class TestPrim(unittest.TestCase):`
`+`
`+    def setUp(self):`
`+        nodes = {}`
`+        for i in range(7):`
`+            char = chr(i + 97)`
`+            nodes[char] = Node(char)`
`+`
`+        #: TODO: wow, this should be easier :/`
`+        nodes['a'].add_child(nodes['b'], 7, reverse=True)`
`+        nodes['a'].add_child(nodes['d'], 5, reverse=True)`
`+        nodes['b'].add_child(nodes['c'], 8, reverse=True)`
`+        nodes['b'].add_child(nodes['d'], 9, reverse=True)`
`+        nodes['b'].add_child(nodes['e'], 7, reverse=True)`
`+        nodes['c'].add_child(nodes['e'], 5, reverse=True)`
`+        nodes['d'].add_child(nodes['e'], 15, reverse=True)`
`+        nodes['d'].add_child(nodes['f'], 6, reverse=True)`
`+        nodes['e'].add_child(nodes['f'], 8, reverse=True)`
`+        nodes['e'].add_child(nodes['g'], 9, reverse=True)`
`+        nodes['f'].add_child(nodes['g'], 11, reverse=True)`
`+        self.nodes = nodes`
`+`
`+    def test_create_graph(self):`
`+        nodes = self.nodes`
`+        ## do prim() manually`
`+        node_map = {}`
`+        for node in nodes.itervalues():`
`+            node_map[node] = PNode(node)`
`+`
`+        node_map[nodes['b']].parent = nodes['a']`
`+        node_map[nodes['b']].cost = 7`
`+`
`+        node_map[nodes['d']].parent = nodes['a']`
`+        node_map[nodes['d']].cost = 5`
`+`
`+        node_map[nodes['e']].parent = nodes['b']`
`+        node_map[nodes['e']].cost = 7`
`+`
`+        node_map[nodes['c']].parent = nodes['e']`
`+        node_map[nodes['c']].cost = 5`
`+`
`+        node_map[nodes['f']].parent = nodes['d']`
`+        node_map[nodes['f']].cost = 6`
`+`
`+        node_map[nodes['g']].parent = nodes['e']`
`+        node_map[nodes['g']].cost = 9`
`+`
`+        pnode = node_map[nodes['a']]`
`+        graph = pnode._create_graph()`
`+`
`+    def test_prim(self):`
`+        nodes = self.nodes`
`+        mst = prim(nodes['d'], in_place=True)`
`+        self.assertFalse(nodes['d'] in nodes['b'].children)`
`+        self.assertFalse(nodes['c'] in nodes['b'].children)`
`+        self.assertTrue(nodes['e'] in nodes['b'].children)`
`+`
`+`
`+    def test_equality(self):`
`+        nodes = self.nodes`
`+        msts = []`
`+        for node in nodes.itervalues():`
`+            msts.append(set([n.name for n in prim(node)]))`
`+        ## all msts should be equal`
`+        self.assertTrue(msts[0] == msts[1] == msts[2] == msts[3])`
`+        self.assertTrue(msts[3] == msts[4] == msts[5] == msts[6])`