Commits

Jure Žbontar committed f62d13c

Add libsvm and liblinear.

Comments (0)

Files changed (4)

 from ml.linear_regression import LinearRegression
 from ml.logistic_regression import LogisticRegression
 from ml.mlp import MLPClassifier
+from ml.libsvm import Libsvm
+import tempfile
+import os
+import shutil
+from subprocess import check_call, PIPE
+
+import numpy as np
+
+def dump_svm(X, y, file):
+    X.sort_indices()
+    for i in range(X.shape[0]):
+        l, r = X.indptr[i], X.indptr[i + 1]
+        row = ' '.join('%d:%.18e' % t for t in zip(X.indices[l:r] + 1, X.data[l:r]))
+        file.write('%d %s\n' % (y[i], row))
+
+class Libsvm:
+    def __init__(self, program, args):
+        self.program = program
+        self.args = args
+
+    def fit_predict(self, X, y, X_test):
+        dir = tempfile.mkdtemp()
+        tr = os.path.join(dir, 'tr')
+        te = os.path.join(dir, 'te')
+        pred = os.path.join(dir, 'pred')
+        model = os.path.join(dir, 'model')
+
+        dump_svm(X, y, open(tr, 'w'))
+        dump_svm(X_test, np.zeros(X_test.shape[0]), open(te, 'w'))
+
+        cmd = '%s-train %s -q %s %s' % (self.program, self.args, tr, model)
+        check_call(cmd.split())
+        cmd = '%s-predict -b 1 %s %s %s' % (self.program, te, model, pred)
+        check_call(cmd.split(), stdout=PIPE)
+
+        col = int(open(pred).readline().split()[-1]) + 1
+        p = np.loadtxt(pred, skiprows=1)[:,col]
+
+        shutil.rmtree(dir)
+
+        return p

ml/linear_regression.py

         self.theta, cost, ret = fmin_l_bfgs_b(
             self.cost_grad, theta, args=(X, y), **self.fmin_args)
         if ret['warnflag'] != 0:
-            logging.warning('L-BFGS failed to converge')
+            logging.info('L-BFGS failed to converge')
 
     def predict(self, X):
         return X.dot(self.theta)

ml/logistic_regression.py

         self.theta, cost, ret = fmin_l_bfgs_b(
             self.cost_grad, theta, args=(X, y), **self.fmin_args)
         if ret['warnflag'] != 0:
-            logging.warning('L-BFGS failed to converge')
+            logging.info('L-BFGS failed to converge')
 
     def predict(self, X):
         return sigmoid(X.dot(self.theta))
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.