Source

ml-class / ex6.py

#!/usr/bin/env python

from scipy.io import loadmat
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm
import re
from Stemmer import Stemmer

# Module-level shortcut to the ``stemWord`` method of an English ('en')
# Stemmer instance; tokenize() calls it once per word.
stem = Stemmer('en').stemWord

def plot(xs, ys, clf, title=None):
    """Scatter-plot labelled 2-D samples together with a classifier's contour.

    Parameters
    ----------
    xs : array with two feature columns, one row per sample.
    ys : labels for the rows of ``xs``; rows labelled 1 are drawn as black
        '+' markers, rows labelled 0 as yellow circles.  May be a column
        vector or a flat array.
    clf : fitted classifier exposing ``predict``; it is evaluated on a
        100 x 100 grid spanning the range of the samples.
    title : optional title for the axes.

    The figure is shown with ``fig.show()``.  If drawing raises
    ``ValueError`` the figure is closed and a short message is printed
    instead of propagating the error.
    """
    # Features in columns 0-1, label in column 2.
    data = np.column_stack((xs, ys))
    pos = data[data[:, 2] == 1]
    neg = data[data[:, 2] == 0]

    grid_x = np.linspace(data[:, 0].min(), data[:, 0].max(), 100)
    grid_y = np.linspace(data[:, 1].min(), data[:, 1].max(), 100)

    # meshgrid yields arrays of shape (len(grid_y), len(grid_x)), which is the
    # row = y / column = x layout that contour() expects for Z.  Predicting all
    # grid points in one 2-D batch also avoids 10,000 single-sample calls.
    mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
    points = np.column_stack((mesh_x.ravel(), mesh_y.ravel()))
    z = np.asarray(clf.predict(points), dtype=float).reshape(mesh_x.shape)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    if title:
        ax.set_title(title)
    try:
        ax.contour(grid_x, grid_y, z)
        ax.scatter(pos[:, 0], pos[:, 1], marker='+', color='black')
        ax.scatter(neg[:, 0], neg[:, 1], marker='o', facecolor='yellow')

        fig.show()
    except ValueError:
        # Best effort: drop the half-built figure rather than abort the run.
        print('BUMMER')
        plt.close(fig)



def solve(datafile, clf=None):
    """Fit a classifier on the data in a .mat file and plot the result.

    Parameters
    ----------
    datafile : path to a MATLAB file containing arrays 'X' and 'y'.
    clf : optional unfitted classifier; a default ``svm.SVC()`` is created
        when omitted.

    Returns
    -------
    The fitted classifier.
    """
    raw = loadmat(datafile)
    # Compare against None explicitly: truth-testing an estimator object can
    # invoke __len__ on estimators that define it.
    if clf is None:
        clf = svm.SVC()
    # Labels are binarised to booleans (y > 0) for training.
    clf.fit(raw['X'], raw['y'].ravel() > 0)
    plot(raw['X'], raw['y'], clf)
    return clf


def find_best(raw, **kw):
    """Grid-search C and gamma for an SVC and return the best classifier.

    Parameters
    ----------
    raw : mapping with training arrays 'X' and 'y' and validation arrays
        'Xval' and 'yval'.
    **kw : extra keyword arguments forwarded to ``svm.SVC``.

    Every (C, gamma) pair from a fixed list of values is fitted on the
    training data and scored on the validation data; each score is printed.
    The classifier with the highest validation score is returned (the first
    one wins on ties).  When the data has exactly two feature columns the
    winner is also plotted with its score and parameters as the title.
    """
    values = [.01, .03, .1, .3, 1, 3, 10, 30]

    def p(score, C, gamma):
        return 'score={} C={} gamma={}'.format(score, C, gamma)

    X = raw['X']
    y = raw['y'].ravel() > 0
    Xval = raw['Xval']
    # Binarise validation labels the same way as the training labels so that
    # score() compares like with like.
    yval = raw['yval'].ravel() > 0

    best = None  # (score, C, gamma, clf) of the best candidate so far
    for C in values:
        for gamma in values:
            clf = svm.SVC(C=C, gamma=gamma, **kw)
            clf.fit(X, y)
            score = clf.score(Xval, yval)
            print(p(score, C, gamma))
            if best is None or score > best[0]:
                best = (score, C, gamma, clf)

    best_score, best_C, best_gamma, best_clf = best

    # plot() treats column 2 of [X | y] as the label and predicts on
    # two-feature points, so it is only meaningful for 2-D data.
    if X.shape[1] == 2:
        plot(raw['X'], raw['y'], best_clf,
             title=p(best_score, best_C, best_gamma))
    return best_clf


def load_voc():
    """Load the vocabulary from 'ex6/vocab.txt' as a word -> index dict.

    Each non-blank line must hold two whitespace-separated fields: an
    integer index followed by the word.  Blank lines are skipped; any other
    malformed line raises ``ValueError``.
    """
    voc = {}
    with open('ex6/vocab.txt') as fo:
        for line in fo:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            index, word = fields
            voc[word] = int(index)
    return voc


def normailze(text):
    """Lower-case ``text`` and replace markup, numbers, URLs, e-mail
    addresses and dollar signs with placeholder words.

    The substitutions run in a fixed order, so later patterns see the output
    of earlier ones (for instance digits inside a URL are already 'number'
    by the time the URL pattern runs).

    NOTE: the function name is misspelled but is kept as-is because
    tokenize() calls it under this name.
    """
    text = text.lower()
    # Raw strings keep the regex escapes (\s, \$) out of Python's own string
    # escape processing; the patterns themselves are unchanged.
    text = re.sub(r'<[^<>]+>', ' ', text)                       # HTML-like tags
    text = re.sub(r'[0-9]+', 'number', text)                    # digit runs
    text = re.sub(r'(http|https)://[^\s]*', 'httpaddr', text)   # URLs
    text = re.sub(r'[^\s]+@[^\s]+', 'emailaddr', text)          # e-mail addresses
    text = re.sub(r'\$+', 'dollar', text)                       # dollar signs

    return text


def tokenize(text):
    """Return a lazy iterator over the stemmed word tokens of ``text``.

    The text is normalised and split on punctuation/whitespace delimiters
    immediately; stripping non-letters, dropping empty pieces and stemming
    happen as the returned iterator is consumed.
    """
    pieces = re.split(r'[ @$/#.\-:&*+=\[\]?!(){},\'">_<;%\n\r]', normailze(text))

    def _stems(chunks):
        for chunk in chunks:
            # Keep ASCII letters only; pieces left empty are discarded.
            word = re.sub('[^a-zA-Z]', '', chunk)
            if not word.strip():
                continue
            yield stem(word)

    return _stems(pieces)


def vectorize(text, voc=None):
    """Convert ``text`` into a binary bag-of-words feature vector.

    Parameters
    ----------
    text : raw text to tokenize.
    voc : optional word -> index mapping; loaded with load_voc() when
        omitted.  Passing it in avoids re-reading the vocabulary file on
        every call.

    Returns
    -------
    A 1-D float array with one slot per vocabulary word, set to 1 for each
    word that occurs among the tokens of ``text`` and 0 elsewhere.
    """
    if voc is None:
        voc = load_voc()
    vec = np.zeros(len(voc))
    if not voc:
        return vec

    # Vocabulary indices are shifted so that the smallest one maps to slot 0.
    # NOTE(review): the vocabulary file presumably numbers words from 1; used
    # unshifted, the highest index would fall outside ``vec`` -- confirm
    # against vocab.txt.  A vocabulary that already starts at 0 is unaffected.
    base = min(voc.values())
    for token in tokenize(text):
        i = voc.get(token)
        if i is not None:
            vec[i - base] = 1

    return vec


def spam_train(clf=None):
    """Grid-search an SVC on the spam data set and return the best classifier.

    Loads 'ex6/spamTrain.mat' for training and uses the 'Xtest'/'ytest'
    arrays of 'ex6/spamTest.mat' as the validation set handed to
    find_best().

    Parameters
    ----------
    clf : currently unused; accepted only so existing callers keep working.

    Returns
    -------
    Whatever find_best() returns for this data.
    """
    raw = loadmat('ex6/spamTrain.mat')
    test = loadmat('ex6/spamTest.mat')
    # find_best() looks for the validation data under 'Xval' / 'yval'.
    raw['Xval'] = test['Xtest']
    raw['yval'] = test['ytest']

    return find_best(raw)


def main(argv=None):
    """Command-line entry point: fit and plot the data file named in argv.

    Parameters
    ----------
    argv : full argument list including the program name at index 0;
        ``sys.argv`` is used when it is omitted or empty.
    """
    import sys
    from argparse import ArgumentParser

    argv = argv or sys.argv

    parser = ArgumentParser(description='')
    parser.add_argument('datafile')
    args = parser.parse_args(argv[1:])

    solve(args.datafile)

    # raw_input only exists on Python 2; fall back to input on Python 3.
    try:
        prompt = raw_input
    except NameError:
        prompt = input
    # Wait for the user before exiting (presumably so the figure opened by
    # solve() stays visible -- confirm).
    prompt('Hit Enter to Quit')

# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == '__main__':
    main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.