Source

dawg / dawg.py

#!/usr/bin/env python

import collections

def read_words(path="words.lst"):
    """Yield every word from the word-list file at *path*.

    Words are separated by any run of whitespace (spaces, tabs, newlines), so
    both space-separated and one-word-per-line files work.  Empty tokens that
    repeated or trailing separators would otherwise produce are skipped, and
    no word carries a trailing newline.

    The file is opened when iteration starts and closed by the ``with`` block
    once the generator is exhausted (or closed / garbage-collected early).
    """
    with open(path) as words:
        # Iterate line by line instead of read(1) per character: the file
        # object buffers for us and str.split() handles all whitespace.
        for line in words:
            for word in line.split():
                yield word

class Node(object):
    """One letter of the word trie / DAWG.

    Attributes:
        letter: character on the edge leading into this node (None for the
            head node created by build_trie).
        nodes: dict mapping the next letter to the child Node.
        endflag: True when the path from the head to this node spells a
            complete word.
        depth: length of the longest suffix starting at this node (set by
            build_trie; stays 0 on the head).
    """
    endflag = False
    depth = 0

    # Equality is structural (see __eq__) but hashing stays identity-based,
    # which is what Python 2 already did for this class, so sets/dicts of
    # nodes keep distinguishing separate node objects.  Declaring it
    # explicitly also keeps Node hashable on Python 3, where defining
    # __eq__ alone sets __hash__ to None.
    __hash__ = object.__hash__

    def __init__(self, letter):
        self.letter = letter
        self.nodes = {}

    def get(self, word):
        """Return the node reached by following *word* from this node.

        Returns None as soon as a letter has no matching child, and this node
        itself for an empty *word*.  Letters are matched as given (no case
        folding).
        """
        node = self
        for letter in word:
            node = node.nodes.get(letter)
            if node is None:
                return None
        return node

    @property
    def hash_str(self):
        """Structural key: this node's depth plus a canonical subtree dump.

        Nodes share a key when they have the same depth, letter, end-of-word
        flag and (recursively) the same children.  The key is computed once
        and cached in ``_hash_str`` -- a property cannot be shadowed by an
        instance ``__dict__`` entry, so the cache must live under another
        name.  Because of the cache, only read it once the subtree is final,
        or when later edits merely swap children for equal-keyed nodes.
        """
        try:
            return self._hash_str
        except AttributeError:
            self._hash_str = "%d %s" % (self.depth, self._dump())
            return self._hash_str

    def __eq__(self, other):
        """Structural equality: both nodes have the same hash_str."""
        if not isinstance(other, Node):
            return NotImplemented
        return self.hash_str == other.hash_str

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; do it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def _dump(self):
        """Return an unambiguous string describing this node's subtree.

        Format: repr() of the letter ("(head)" for the head node), a "$" when
        the node ends a word, then the children's dumps in letter order,
        comma-separated inside parentheses.  repr() delimits each letter so
        letters such as "(" or "," cannot make two different subtrees dump
        identically, and sorting makes the result independent of dict order.
        The end-of-word flag must be part of the dump: otherwise subtrees
        that differ only in which prefixes are words would compare equal.
        """
        label = "(head)" if self.letter is None else repr(self.letter)
        if self.endflag:
            label += "$"
        children = sorted(self.nodes.values(), key=lambda n: n.letter)
        return "%s(%s)" % (label, ",".join(child._dump() for child in children))

    def visit(self, fn):
        """Call *fn* on every node reachable from this one.

        Traversal is depth-first pre-order with siblings in ascending letter
        order.  No visited set is kept, so a node shared by several parents
        (after DAWG compression) is passed to *fn* once per path reaching it.
        """
        stack = collections.deque([self])
        while stack:
            node = stack.popleft()
            fn(node)
            # extendleft() prepends items one at a time, reversing them, so
            # feed it the children in descending order to pop them ascending.
            stack.extendleft(sorted(node.nodes.values(),
                                    key=lambda n: n.letter, reverse=True))

def build_trie(words):
    """Build an uncompressed trie from an iterable of words.

    Each word is lowercased before insertion.  Every node's ``depth`` is
    raised to the length of the longest inserted suffix starting at it (1 for
    a word's final letter), which build_dawg uses to order nodes.  Empty words
    are skipped.  Progress is printed every 1000 words.

    Returns the head Node (letter None).
    """
    head = Node(None)
    for i, word in enumerate(words):
        if i % 1000 == 0:
            print("read %d words, %s" % (i, word))
        # Measure the lowercased string, not the original: lower() can change
        # the length of some Unicode strings, and depths must match the
        # letters actually inserted.
        letters = word.lower()
        if not letters:
            # No letters to insert and no final node to flag.
            continue
        remaining = len(letters)  # suffix length starting at the next letter
        node = head
        for letter in letters:
            child = node.nodes.get(letter)
            if child is None:
                child = node.nodes[letter] = Node(letter)
            child.depth = max(child.depth, remaining)
            remaining -= 1
            node = child
        node.endflag = True
    return head

def build_dawg(head):
    """Compress the trie rooted at *head* into a DAWG, in place.

    Nodes are merged when they carry the same letter and end-of-word flag and
    their children are -- after merging -- the very same node objects under
    the same letters, i.e. they root identical suffix subtrees.  Nodes are
    processed in ascending ``depth`` (as assigned by build_trie, a child's
    depth is always smaller than its parent's), so when a node is examined
    every child already points at its canonical representative and children
    can be compared by identity.

    Every parent is rewired, not only parents exactly one depth level above
    the duplicates, so all duplicate suffixes end up shared.  The head node
    (depth 0) only has its children rewired.  A line is printed for each
    group of duplicates found.  Returns *head* (callers that ignore the
    result are unaffected).
    """
    by_depth = collections.defaultdict(list)
    head.visit(lambda n: by_depth[n.depth].append(n))

    canonical = {}      # signature -> representative node
    signature_of = {}   # id(node) -> signature of that node

    def adopt_canonical_children(node):
        # Point every outgoing edge at the representative of the child's class.
        for key, child in list(node.nodes.items()):
            node.nodes[key] = canonical[signature_of[id(child)]]

    for depth in sorted(d for d in by_depth if d > 0):
        groups = collections.defaultdict(list)
        for node in by_depth[depth]:
            adopt_canonical_children(node)
            # The end-of-word flag is part of the signature: merging nodes
            # that differ in it would add or drop words.
            signature = (node.letter, node.endflag, tuple(sorted(
                (key, id(child)) for key, child in node.nodes.items())))
            signature_of[id(node)] = signature
            groups[signature].append(node)
        for signature, group in groups.items():
            canonical[signature] = group[0]
            if len(group) > 1:
                print("compressing %d duplicate suffixes with hash %s" % (
                    len(group),
                    group[0].hash_str
                ))
    adopt_canonical_children(head)
    return head

def count_nodes(head):
    """Return the number of distinct node objects reachable from *head*.

    Nodes are tracked by id() rather than placed in a set, because Node
    defines structural __eq__ (which on Python 3 makes it unhashable unless
    __hash__ is also defined).  Each shared node is expanded only once, so
    DAWG subtrees reachable by many paths are not re-walked.
    """
    seen = set()
    stack = [head]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.extend(node.nodes.values())
    return len(seen)

def find_word(word, root=None):
    """Return True if *word* (compared lowercased) is stored in the trie/DAWG.

    *root* defaults to the module-level ``head`` built at import time; pass a
    node explicitly to search a different structure.  The empty word is found
    only when the root node itself is flagged as a word end.
    """
    node = head if root is None else root
    for letter in word.lower():
        node = node.nodes.get(letter)
        if node is None:
            return False
    return node.endflag

# Build the word trie from words.lst at import time, then compress it in place
# into a DAWG; find_word() looks words up in this module-level head.
head = build_trie(words=read_words())
build_dawg(head)