Commits

Miki Tebeka committed 88e24c2

working regression, spam in progress

  • Participants
  • Parent commits 3885bcb

Comments (0)

Files changed (1)

 import matplotlib.pyplot as plt
 import numpy as np
 from sklearn import svm
+import re
 
-def solve(datafile, clf=None):
-    raw = loadmat(datafile)
-    data = np.append(raw['X'], raw['y'], 1)
-
+def plot(xs, ys, clf, title=None):
+    """Scatter-plot labeled samples on a new matplotlib figure.
+
+    `ys` is appended to `xs` as extra column(s); rows whose third column is 1
+    are drawn as black '+' markers and rows whose third column is 0 as yellow
+    'o' markers. A 100x100 grid spanning the range of the first two columns
+    is passed to `contour`.
+
+    xs    -- feature array (assumes two columns, since column 2 is read as
+             the label -- TODO confirm)
+    ys    -- label column appended to `xs`
+    clf   -- classifier; NOTE(review): not used anywhere in this body yet
+    title -- optional figure title
+
+    If a ValueError is raised while drawing or showing, 'BUMMER' is printed
+    and the figure is closed instead of propagating the error.
+    """
+    data = np.append(xs, ys, 1)
     pos = data[data[:,2]==1]
     neg = data[data[:,2]==0]
 
-    clf = clf or svm.SVC()
-    clf.fit(raw['X'], raw['y'].ravel()>0)
-
     xs = np.linspace(data[:,0].min(), data[:,0].max(), 100)
     ys = np.linspace(data[:,1].min(), data[:,1].max(), 100)
+    # NOTE(review): z is all zeros, so the contour carries no information
+    # from clf yet -- presumably a placeholder for the decision surface.
     z = np.zeros(shape=(len(xs), len(ys)))
 
     fig = plt.figure()
     ax = fig.add_subplot(111)
-    ax.contour(xs, ys, z)
-    ax.scatter(pos[:,0], pos[:,1], marker='+', color='black')
-    ax.scatter(neg[:,0], neg[:,1], marker='o', facecolor='yellow')
+    if title:
+        ax.set_title(title)
+    try:
+        ax.contour(xs, ys, z)
+        ax.scatter(pos[:,0], pos[:,1], marker='+', color='black')
+        ax.scatter(neg[:,0], neg[:,1], marker='o', facecolor='yellow')
 
-    fig.show()
+        fig.show()
+    except ValueError:
+        # Best-effort plotting: report the failure and discard the figure.
+        print('BUMMER')
+        plt.close(fig)
+
+
+
+def solve(datafile, clf=None):
+    """Fit a classifier on the data set in `datafile` and plot the samples.
+
+    datafile -- path handed to `loadmat`; the loaded mapping must provide
+                'X' (features) and 'y' (labels)
+    clf      -- classifier to fit; a default `svm.SVC()` is used when None
+
+    Labels are flattened and converted to booleans (`> 0`) before fitting.
+    """
+    raw = loadmat(datafile)
+    # Compare against None explicitly: `clf or ...` would evaluate the
+    # estimator's truthiness, which can be False (or raise) for estimators
+    # that define __len__.
+    if clf is None:
+        clf = svm.SVC()
+    clf.fit(raw['X'], raw['y'].ravel()>0)
+    plot(raw['X'], raw['y'], clf)
+
+
+def find_best(**kw):
+    """Grid-search C and gamma for an SVC on 'ex6/ex6data3.mat'.
+
+    Every (C, gamma) pair drawn from a fixed list of values is fitted on the
+    training set ('X', 'y') and scored on the validation set
+    ('Xval', 'yval'). Each result is printed as it is computed, and the
+    training samples are plotted with the winning result as the title.
+
+    Extra keyword arguments are forwarded to `svm.SVC`.
+
+    Returns the fitted classifier with the highest validation score (the
+    first one found when several tie).
+    """
+    raw = loadmat('ex6/ex6data3.mat')
+    values = [.01, .03, .1, .3, 1, 3, 10, 30]
+    # Convert training and validation labels the same way so fit() and
+    # score() compare like with like.
+    y = raw['y'].ravel() > 0
+    yval = raw['yval'].ravel() > 0
+
+    def p(score, C, gamma):
+        return 'score={} C={} gamma={}'.format(score, C, gamma)
+
+    best, best_score, best_label = None, None, None
+    for C in values:
+        for gamma in values:
+            clf = svm.SVC(C=C, gamma=gamma, **kw)
+            clf.fit(raw['X'], y)
+            score = clf.score(raw['Xval'], yval)
+            label = p(score, C, gamma)
+            print(label)
+            if best is None or score > best_score:
+                best, best_score, best_label = clf, score, label
+
+    # Plot and return the best classifier, not the last one fitted.
+    plot(raw['X'], raw['y'], best, title=best_label)
+    return best
+
+
+def load_voc():
+    """Load the vocabulary words from 'ex6/vocab.txt'.
+
+    Each non-blank line is split on whitespace and its second field is kept
+    (presumably the first field is the word's index -- confirm against the
+    data file). Blank lines are skipped; a non-blank line with fewer than
+    two fields still raises IndexError.
+
+    Returns a NumPy array of the words in file order.
+    """
+    with open('ex6/vocab.txt') as fo:
+        # split() already drops surrounding whitespace, so no strip() needed.
+        return np.array([line.split()[1] for line in fo if line.strip()])
+
+
+def normailze(text):
+    """Return `text` lower-cased with noisy tokens replaced by placeholders.
+
+    Substitutions are applied in this order:
+      1. markup tags (`<...>`)      -> a single space
+      2. runs of digits             -> 'number'
+      3. http:// and https:// URLs  -> 'httpaddr'
+      4. tokens containing '@'      -> 'emailaddr'
+      5. runs of '$'                -> 'dollar'
+
+    Digits are replaced before URLs and addresses are matched, so digits
+    inside them have already become 'number' by then.
+    """
+    text = text.lower()
+    # Raw strings keep the regex escapes (\s, \$) from being read as
+    # invalid string escape sequences.
+    text = re.sub(r'<[^<>]+>', ' ', text)
+    text = re.sub(r'[0-9]+', 'number', text)
+    text = re.sub(r'(http|https)://[^\s]*', 'httpaddr', text)
+    text = re.sub(r'[^\s]+@[^\s]+', 'emailaddr', text)
+    text = re.sub(r'\$+', 'dollar', text)
+
+    return text
+
 
 
 def main(argv=None):