gpalign-cpp / src / tfIdfAligner.cpp

//--------------------------------------------------------------------------//
// tfIdfAligner.cpp
// Lars Yencken <lars.yencken@gmail.com>
// vim: ts=4 sw=4 sts=4 expandtab:
// Sat Oct  6 14:05:27 EST 2007
//--------------------------------------------------------------------------//

#include "tfIdfAligner.hpp"
#include "generator.hpp"
#include "scripts.hpp"
#include "progressBar.hpp"

#include <math.h>
#include <algorithm>
#include <iostream>
#include <sstream>
#include <fstream>
#include <stdexcept>

//--------------------------------------------------------------------------//

// Construct an aligner.
//   alpha     - smoothing constant added to the tf and idf terms when a
//               segment is scored.
//   blockSize - how many ambiguous clouds are resolved per rescoring pass.
// The known and unknown distributions are constructed with 0.8 and 0.2
// respectively (presumably their mixing weights, read back via `.weight`
// when scoring — NOTE(review): confirm against the distribution class).
TfIdfAligner::TfIdfAligner(double alpha, int blockSize)
    : m_alpha(alpha)
    , m_blockSize(blockSize)
    , m_knownDist(0.8)
    , m_unknownDist(0.2)
{
    // All state is set up in the initializer list.
}

//--------------------------------------------------------------------------//

void TfIdfAligner::alignSegments(
            const vector<Segment>& entries,
            vector<BaseAlignment>& alignments
        )
{
    wcout << L"Building alignment model" << endl;
    alignments.clear();
    vector<AlignmentCloudPtr> alignmentClouds;
    _generateAlignments(entries, alignments, alignmentClouds);
    _disambiguateClouds(alignmentClouds, alignments);
    return;
}

//--------------------------------------------------------------------------//

void TfIdfAligner::_generateAlignments(
            const vector<Segment>& entries,
            vector<BaseAlignment>& alignments,
            vector<AlignmentCloudPtr>& alignmentClouds
        )
{
    const int nEntries = entries.size();
    int nSolved = 0;
    int nFailed = 0;
    wofstream errorLog;
    errorLog.open("errors.log");
#ifndef DARWIN
    errorLog.imbue(locale(g_locale));
#endif

    if (!errorLog.good()) {
        throw runtime_error("Can't open errors.log for writing");
    }
    for (int i = 0; i < nEntries; i++) {
        const Segment& entry = entries[i];
        if (kanjiLen(entry.g) > 5) {
            errorLog << entry << L" (too long)" << endl;
            nFailed++;
            continue;
        }
        AlignmentCloudPtr cloud = _allAlignments(entry);
        if (cloud->isValid()) {
            if (cloud->isResolved()) {
                Alignment a = cloud->candidates[0];
                _addKnownCounts(a);
                alignments.push_back(make_pair(entry, a));
                nSolved++;
            } else {
                _addUnknownCounts(*cloud);
                alignmentClouds.push_back(cloud);
            }
        } else {
            errorLog << entry << L" (no valid alignments)" << endl;
            nFailed++;
        }
    }
    errorLog.close();
    wcout << L"--> " << nSolved << L" uniquely determined" << endl;
    wcout << L"--> " << nFailed << L" overconstrained, written to errors.log"
            << endl;
    return;
}

//--------------------------------------------------------------------------//

// Second pass: repeatedly take the best-scoring clouds (chosen by _popBestN)
// and commit each one's best candidate. For every committed cloud, all of
// its candidates are removed from the unknown distribution and the chosen
// alignment is added to the known distribution, so the next rescoring pass
// sees the updated evidence.
//   alignmentClouds - unresolved clouds; emptied by this call.
//   alignments      - output; each (base, best alignment) pair is appended.
// Throws logic_error if a pass removes no clouds (e.g. a non-positive block
// size), instead of looping forever.
void TfIdfAligner::_disambiguateClouds(
            vector<AlignmentCloudPtr>& alignmentClouds,
            vector<BaseAlignment>& alignments
        )
{
    wcout << L"Disambiguating clouds" << endl;
    ProgressBar progress;

    // Disambiguate the alignments N at a time.
    progress.start(alignmentClouds.size());
    int count = 0;
    vector<AlignmentCloudPtr> bestClouds;
    while (!alignmentClouds.empty()) {
        progress.update(count);
        _popBestN(alignmentClouds, bestClouds);
        if (bestClouds.empty()) {
            // No progress possible: bail out rather than spin forever.
            throw logic_error("_popBestN selected no clouds; check block size");
        }
        for (vector<AlignmentCloudPtr>::iterator iter = bestClouds.begin();
                iter != bestClouds.end(); ++iter) {
            _delUnknownCounts(**iter);
            Alignment bestAlignment = (*iter)->getBest();
            _addKnownCounts(bestAlignment);
            alignments.push_back(make_pair((*iter)->base, bestAlignment));
        }
        // Advance by the number actually resolved (the last block may be
        // smaller than m_blockSize).
        count += static_cast<int>(bestClouds.size());
    }
    progress.finish();
}

//--------------------------------------------------------------------------//

// Score an alignment as the mean of its per-segment scores.
// Each segment is scored in a context built from the last grapheme and
// phoneme character of the preceding segment (when there is one), followed
// by the first grapheme and phoneme character of the following segment
// (when there is one). An empty alignment scores 0.
// NOTE(review): assumes every segment's g and p are non-empty; indexing an
// empty one would be out of range.
double TfIdfAligner::_scoreAlignment(const Alignment& a) PLATFORM_CONST
{
    const int nSegs = a.size();
    const int lastSeg = nSegs - 1;
    double total = 0.0;

    for (int seg = 0; seg < nSegs; seg++) {
        Context ctx;

        // Left neighbour contributes its rightmost characters.
        if (seg != 0) {
            const Segment& left = a[seg - 1];
            const Grapheme& leftG = left.g;
            const Phoneme& leftP = left.p;
            ctx.g += leftG[leftG.size() - 1];
            ctx.p += leftP[leftP.size() - 1];
        }

        // Right neighbour contributes its leftmost characters.
        if (seg != lastSeg) {
            const Segment& right = a[seg + 1];
            ctx.g += right.g[0];
            ctx.p += right.p[0];
        }

        total += _scoreSegment(a[seg], ctx);
    }

    // Guard the division for an empty alignment.
    const int divisor = (nSegs > 0) ? nSegs : 1;
    return total / divisor;
}

//--------------------------------------------------------------------------//

// TF-IDF style score for segment s in the given neighbouring context.
// Each frequency is a weighted mix of the known distribution (alignments
// already committed) and the unknown distribution (all candidates of
// still-ambiguous clouds), using each distribution's `weight`:
//   gFreq   - grapheme string s.g on its own
//   gpFreq  - grapheme/phoneme segment s
//   gpcFreq - segment s together with its context
// tf  = (gpFreq - w_unknown + alpha) / gFreq
// idf = log(gpFreq / (gpcFreq - w_unknown + alpha))
// Subtracting w_unknown presumably discounts the candidate's own
// contribution to the unknown counts — NOTE(review): confirm against mle().
// There is no guard: gFreq == 0 or a non-positive log argument yields
// inf/NaN.
double TfIdfAligner::_scoreSegment(const Segment& s,
        const Context& context) PLATFORM_CONST
{
    const double wKnown = m_knownDist.weight;
    const double wUnknown = m_unknownDist.weight;

    const double graphemeFreq =
            wKnown * m_knownDist.graphemeDist.mle(s.g)
            + wUnknown * m_unknownDist.graphemeDist.mle(s.g);
    const double segmentFreq =
            wKnown * m_knownDist.segmentDist.mle(s)
            + wUnknown * m_unknownDist.segmentDist.mle(s);

    SegmentContext withContext(s, context);
    const double contextFreq =
            wKnown * m_knownDist.contextDist.mle(withContext)
            + wUnknown * m_unknownDist.contextDist.mle(withContext);

    const double termFreq =
            (segmentFreq - wUnknown + m_alpha) / graphemeFreq;
    const double inverseDocFreq =
            log(segmentFreq / (contextFreq - wUnknown + m_alpha));

    return termFreq * inverseDocFreq;
}

//--------------------------------------------------------------------------//

// Build a new cloud holding every candidate alignment that
// potentialAlignments() enumerates for s, with s recorded as the cloud's
// base entry.
AlignmentCloudPtr TfIdfAligner::_allAlignments(const Segment& s)
{
    vector<Alignment> alignments;
    potentialAlignments(s, alignments);

    AlignmentCloudPtr cloud(new AlignmentCloud());
    vector<Alignment>& candidates = cloud->candidates;
    if (candidates.empty()) {
        // Usual case: take over the local list in O(1) instead of copying
        // every alignment (the local is discarded on return anyway).
        candidates.swap(alignments);
    } else {
        // Preserve anything the cloud already holds; append in one call.
        candidates.insert(candidates.end(), alignments.begin(),
                alignments.end());
    }
    cloud->base = s;
    return cloud;
}

//--------------------------------------------------------------------------//

void TfIdfAligner::_addUnknownCounts(const AlignmentCloud& cloud)
{
    const vector<Alignment>& alignments = cloud.candidates;
    for (vector<Alignment>::const_iterator iter = alignments.begin();
            iter != alignments.end(); iter++) {
        m_unknownDist.addCounts(*iter);
    }
    return;
}

//--------------------------------------------------------------------------//

void TfIdfAligner::_delUnknownCounts(const AlignmentCloud& cloud)
{
    const vector<Alignment>& alignments = cloud.candidates;
    for (vector<Alignment>::const_iterator iter = alignments.begin();
            iter != alignments.end(); iter++) {
        m_unknownDist.delCounts(*iter);
    }
    return;
}

//--------------------------------------------------------------------------//

void TfIdfAligner::_addKnownCounts(const Alignment& alignment)
{
    m_knownDist.addCounts(alignment);
}

//--------------------------------------------------------------------------//

void TfIdfAligner::_delKnownCounts(const Alignment& alignment)
{
    m_knownDist.delCounts(alignment);
}

//--------------------------------------------------------------------------//

// Score every candidate in the cloud, store the winner in
// cloud.bestScore / cloud.bestIndex, and return the best score.
// A candidate must score strictly above kScoreFloor to win; ties keep the
// earliest candidate, and NaN scores never win. If no candidate beats the
// floor (all NaN, -inf, or extremely negative), the first candidate is
// chosen and the score is left at the floor. This ranks the cloud last
// while still giving it a valid index. An empty cloud gets index -1.
double TfIdfAligner::_scoreCloud(AlignmentCloud& cloud) PLATFORM_CONST
{
    const double kScoreFloor = -1e10;
    const vector<Alignment>& candidates = cloud.candidates;
    const int nCandidates = candidates.size();

    double bestScore = kScoreFloor;
    int bestIndex = -1;
    for (int j = 0; j < nCandidates; j++) {
        const double score = _scoreAlignment(candidates[j]);
        if (bestScore < score) {
            bestScore = score;
            bestIndex = j;
        }
    }

    if (bestIndex < 0 && nCandidates > 0) {
        // Never leave a non-empty cloud with index -1 (presumably the index
        // is later used to pick the best candidate — NOTE(review): confirm
        // against AlignmentCloud::getBest).
        bestIndex = 0;
    }

    cloud.bestScore = bestScore;
    cloud.bestIndex = bestIndex;
    return bestScore;
}

//--------------------------------------------------------------------------//

// Rescore every cloud, then remove and return the single highest-scoring
// one. The winner is swapped to the back before popping, so the order of
// the remaining clouds may change. If no cloud scores above -1e10, the
// cloud at index 0 is returned.
// Throws runtime_error if clouds is empty, consistent with the rest of this
// file, instead of terminating the process from library code.
AlignmentCloudPtr TfIdfAligner::_popBest(vector<AlignmentCloudPtr>& clouds)
{
    const int nClouds = clouds.size();
    if (nClouds == 0) {
        throw runtime_error("expected non-empty vector of clouds");
    }

    int bestIndex = 0;
    double bestScore = -1e10;
    for (int i = 0; i < nClouds; i++) {
        const double score = _scoreCloud(*clouds[i]);
        if (score > bestScore) {
            bestIndex = i;
            bestScore = score;
        }
    }

    if (nClouds > 1) {
        swap(clouds[bestIndex], clouds[nClouds - 1]);
    }
    AlignmentCloudPtr result = clouds.back();
    clouds.pop_back();
    return result;
}

//--------------------------------------------------------------------------//

// Rescore every remaining cloud, then move the highest-scoring ones from
// clouds into results, best first. At most m_blockSize clouds are moved,
// but always at least one while any remain. Each selection swaps the winner
// to the back and pops it, so the order of the remaining clouds changes
// between selections.
//   clouds  - candidate clouds; the selected ones are removed.
//   results - output; cleared on entry.
void TfIdfAligner::_popBestN(vector<AlignmentCloudPtr>& clouds,
        vector<AlignmentCloudPtr>& results)
{
    results.clear();

    // Rescore all the clouds once; selection below only reads bestScore.
    for (vector<AlignmentCloudPtr>::iterator iter = clouds.begin();
            iter != clouds.end(); ++iter) {
        (void) _scoreCloud(**iter);
    }

    // A non-positive block size would select nothing, so a caller looping
    // until clouds is empty would never finish. Always take at least one.
    const int nWanted = (m_blockSize > 0) ? m_blockSize : 1;

    for (int i = 0; i < nWanted && !clouds.empty(); i++) {
        double bestScore = -1e10;
        int bestIndex = 0;
        const int nClouds = clouds.size();
        for (int j = 0; j < nClouds; j++) {
            const AlignmentCloud& cloud = *clouds[j];
            if (cloud.bestScore > bestScore) {
                bestScore = cloud.bestScore;
                bestIndex = j;
            }
        }
        if (nClouds > 1) {
            swap(clouds[bestIndex], clouds[nClouds - 1]);
        }
        results.push_back(clouds.back());
        clouds.pop_back();
    }
}

//--------------------------------------------------------------------------//
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.