Jure Žbontar committed dd755d9

Softmax regression.
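
For reference, the objective the new SoftmaxRegressionGD class minimizes, reconstructed as a sketch from its cost_grad method (X is the m-by-n design matrix, Y the m-by-k one-hot label matrix, Theta the k-by-n weight matrix, and lambda_ the L2 penalty):

J(\Theta) = \frac{1}{m}\left(-\sum_{i=1}^{m}\sum_{c=1}^{k} Y_{ic}\log P_{ic} + \frac{\lambda}{2}\lVert\Theta\rVert_2^2\right),
\qquad
P_{ic} = \frac{\exp(x_i^\top\theta_c)}{\sum_{c'=1}^{k}\exp(x_i^\top\theta_{c'})}

\nabla_\Theta J = \frac{1}{m}\left((P - Y)^\top X + \lambda\Theta\right)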

Files changed (2)

data/iris.scale

+1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667 
+1 1:-0.666667 2:-0.166667 3:-0.864407 4:-0.916667 
+1 1:-0.777778 3:-0.898305 4:-0.916667 
+1 1:-0.833333 2:-0.0833334 3:-0.830508 4:-0.916667 
+1 1:-0.611111 2:0.333333 3:-0.864407 4:-0.916667 
+1 1:-0.388889 2:0.583333 3:-0.762712 4:-0.75 
+1 1:-0.833333 2:0.166667 3:-0.864407 4:-0.833333 
+1 1:-0.611111 2:0.166667 3:-0.830508 4:-0.916667 
+1 1:-0.944444 2:-0.25 3:-0.864407 4:-0.916667 
+1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 
+1 1:-0.388889 2:0.416667 3:-0.830508 4:-0.916667 
+1 1:-0.722222 2:0.166667 3:-0.79661 4:-0.916667 
+1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1 
+1 1:-1 2:-0.166667 3:-0.966102 4:-1 
+1 1:-0.166667 2:0.666667 3:-0.932203 4:-0.916667 
+1 1:-0.222222 2:1 3:-0.830508 4:-0.75 
+1 1:-0.388889 2:0.583333 3:-0.898305 4:-0.75 
+1 1:-0.555556 2:0.25 3:-0.864407 4:-0.833333 
+1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333 
+1 1:-0.555556 2:0.5 3:-0.830508 4:-0.833333 
+1 1:-0.388889 2:0.166667 3:-0.762712 4:-0.916667 
+1 1:-0.555556 2:0.416667 3:-0.830508 4:-0.75 
+1 1:-0.833333 2:0.333333 3:-1 4:-0.916667 
+1 1:-0.555556 2:0.0833333 3:-0.762712 4:-0.666667 
+1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667 
+1 1:-0.611111 2:-0.166667 3:-0.79661 4:-0.916667 
+1 1:-0.611111 2:0.166667 3:-0.79661 4:-0.75 
+1 1:-0.5 2:0.25 3:-0.830508 4:-0.916667 
+1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667 
+1 1:-0.777778 3:-0.79661 4:-0.916667 
+1 1:-0.722222 2:-0.0833334 3:-0.79661 4:-0.916667 
+1 1:-0.388889 2:0.166667 3:-0.830508 4:-0.75 
+1 1:-0.5 2:0.75 3:-0.830508 4:-1 
+1 1:-0.333333 2:0.833333 3:-0.864407 4:-0.916667 
+1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 
+1 1:-0.611111 3:-0.932203 4:-0.916667 
+1 1:-0.333333 2:0.25 3:-0.898305 4:-0.916667 
+1 1:-0.666667 2:-0.0833334 3:-0.830508 4:-1 
+1 1:-0.944444 2:-0.166667 3:-0.898305 4:-0.916667 
+1 1:-0.555556 2:0.166667 3:-0.830508 4:-0.916667 
+1 1:-0.611111 2:0.25 3:-0.898305 4:-0.833333 
+1 1:-0.888889 2:-0.75 3:-0.898305 4:-0.833333 
+1 1:-0.944444 3:-0.898305 4:-0.916667 
+1 1:-0.611111 2:0.25 3:-0.79661 4:-0.583333 
+1 1:-0.555556 2:0.5 3:-0.694915 4:-0.75 
+1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333 
+1 1:-0.555556 2:0.5 3:-0.79661 4:-0.916667 
+1 1:-0.833333 3:-0.864407 4:-0.916667 
+1 1:-0.444444 2:0.416667 3:-0.830508 4:-0.916667 
+1 1:-0.611111 2:0.0833333 3:-0.864407 4:-0.916667 
+2 1:0.5 3:0.254237 4:0.0833333 
+2 1:0.166667 3:0.186441 4:0.166667 
+2 1:0.444444 2:-0.0833334 3:0.322034 4:0.166667 
+2 1:-0.333333 2:-0.75 3:0.0169491 4:-4.03573e-08 
+2 1:0.222222 2:-0.333333 3:0.220339 4:0.166667 
+2 1:-0.222222 2:-0.333333 3:0.186441 4:-4.03573e-08 
+2 1:0.111111 2:0.0833333 3:0.254237 4:0.25 
+2 1:-0.666667 2:-0.666667 3:-0.220339 4:-0.25 
+2 1:0.277778 2:-0.25 3:0.220339 4:-4.03573e-08 
+2 1:-0.5 2:-0.416667 3:-0.0169491 4:0.0833333 
+2 1:-0.611111 2:-1 3:-0.152542 4:-0.25 
+2 1:-0.111111 2:-0.166667 3:0.0847457 4:0.166667 
+2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25 
+2 1:-1.32455e-07 2:-0.25 3:0.254237 4:0.0833333 
+2 1:-0.277778 2:-0.25 3:-0.118644 4:-4.03573e-08 
+2 1:0.333333 2:-0.0833334 3:0.152542 4:0.0833333 
+2 1:-0.277778 2:-0.166667 3:0.186441 4:0.166667 
+2 1:-0.166667 2:-0.416667 3:0.0508474 4:-0.25 
+2 1:0.0555554 2:-0.833333 3:0.186441 4:0.166667 
+2 1:-0.277778 2:-0.583333 3:-0.0169491 4:-0.166667 
+2 1:-0.111111 3:0.288136 4:0.416667 
+2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08 
+2 1:0.111111 2:-0.583333 3:0.322034 4:0.166667 
+2 1:-1.32455e-07 2:-0.333333 3:0.254237 4:-0.0833333 
+2 1:0.166667 2:-0.25 3:0.118644 4:-4.03573e-08 
+2 1:0.277778 2:-0.166667 3:0.152542 4:0.0833333 
+2 1:0.388889 2:-0.333333 3:0.288136 4:0.0833333 
+2 1:0.333333 2:-0.166667 3:0.355932 4:0.333333 
+2 1:-0.0555556 2:-0.25 3:0.186441 4:0.166667 
+2 1:-0.222222 2:-0.5 3:-0.152542 4:-0.25 
+2 1:-0.333333 2:-0.666667 3:-0.0508475 4:-0.166667 
+2 1:-0.333333 2:-0.666667 3:-0.0847458 4:-0.25 
+2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333 
+2 1:-0.0555556 2:-0.416667 3:0.38983 4:0.25 
+2 1:-0.388889 2:-0.166667 3:0.186441 4:0.166667 
+2 1:-0.0555556 2:0.166667 3:0.186441 4:0.25 
+2 1:0.333333 2:-0.0833334 3:0.254237 4:0.166667 
+2 1:0.111111 2:-0.75 3:0.152542 4:-4.03573e-08 
+2 1:-0.277778 2:-0.166667 3:0.0508474 4:-4.03573e-08 
+2 1:-0.333333 2:-0.583333 3:0.0169491 4:-4.03573e-08 
+2 1:-0.333333 2:-0.5 3:0.152542 4:-0.0833333 
+2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333 
+2 1:-0.166667 2:-0.5 3:0.0169491 4:-0.0833333 
+2 1:-0.611111 2:-0.75 3:-0.220339 4:-0.25 
+2 1:-0.277778 2:-0.416667 3:0.0847457 4:-4.03573e-08 
+2 1:-0.222222 2:-0.166667 3:0.0847457 4:-0.0833333 
+2 1:-0.222222 2:-0.25 3:0.0847457 4:-4.03573e-08 
+2 1:0.0555554 2:-0.25 3:0.118644 4:-4.03573e-08 
+2 1:-0.555556 2:-0.583333 3:-0.322034 4:-0.166667 
+2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08 
+3 1:0.111111 2:0.0833333 3:0.694915 4:1 
+3 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 
+3 1:0.555555 2:-0.166667 3:0.661017 4:0.666667 
+3 1:0.111111 2:-0.25 3:0.559322 4:0.416667 
+3 1:0.222222 2:-0.166667 3:0.627119 4:0.75 
+3 1:0.833333 2:-0.166667 3:0.898305 4:0.666667 
+3 1:-0.666667 2:-0.583333 3:0.186441 4:0.333333 
+3 1:0.666667 2:-0.25 3:0.79661 4:0.416667 
+3 1:0.333333 2:-0.583333 3:0.627119 4:0.416667 
+3 1:0.611111 2:0.333333 3:0.728813 4:1 
+3 1:0.222222 3:0.38983 4:0.583333 
+3 1:0.166667 2:-0.416667 3:0.457627 4:0.5 
+3 1:0.388889 2:-0.166667 3:0.525424 4:0.666667 
+3 1:-0.222222 2:-0.583333 3:0.355932 4:0.583333 
+3 1:-0.166667 2:-0.333333 3:0.38983 4:0.916667 
+3 1:0.166667 3:0.457627 4:0.833333 
+3 1:0.222222 2:-0.166667 3:0.525424 4:0.416667 
+3 1:0.888889 2:0.5 3:0.932203 4:0.75 
+3 1:0.888889 2:-0.5 3:1 4:0.833333 
+3 1:-0.0555556 2:-0.833333 3:0.355932 4:0.166667 
+3 1:0.444444 3:0.59322 4:0.833333 
+3 1:-0.277778 2:-0.333333 3:0.322034 4:0.583333 
+3 1:0.888889 2:-0.333333 3:0.932203 4:0.583333 
+3 1:0.111111 2:-0.416667 3:0.322034 4:0.416667 
+3 1:0.333333 2:0.0833333 3:0.59322 4:0.666667 
+3 1:0.611111 3:0.694915 4:0.416667 
+3 1:0.0555554 2:-0.333333 3:0.288136 4:0.416667 
+3 1:-1.32455e-07 2:-0.166667 3:0.322034 4:0.416667 
+3 1:0.166667 2:-0.333333 3:0.559322 4:0.666667 
+3 1:0.611111 2:-0.166667 3:0.627119 4:0.25 
+3 1:0.722222 2:-0.333333 3:0.728813 4:0.5 
+3 1:1 2:0.5 3:0.830508 4:0.583333 
+3 1:0.166667 2:-0.333333 3:0.559322 4:0.75 
+3 1:0.111111 2:-0.333333 3:0.38983 4:0.166667 
+3 1:-1.32455e-07 2:-0.5 3:0.559322 4:0.0833333 
+3 1:0.888889 2:-0.166667 3:0.728813 4:0.833333 
+3 1:0.111111 2:0.166667 3:0.559322 4:0.916667 
+3 1:0.166667 2:-0.0833334 3:0.525424 4:0.416667 
+3 1:-0.0555556 2:-0.166667 3:0.288136 4:0.416667 
+3 1:0.444444 2:-0.0833334 3:0.491525 4:0.666667 
+3 1:0.333333 2:-0.0833334 3:0.559322 4:0.916667 
+3 1:0.444444 2:-0.0833334 3:0.38983 4:0.833333 
+3 1:-0.166667 2:-0.416667 3:0.38983 4:0.5 
+3 1:0.388889 3:0.661017 4:0.833333 
+3 1:0.333333 2:0.0833333 3:0.59322 4:1 
+3 1:0.333333 2:-0.166667 3:0.423729 4:0.833333 
+3 1:0.111111 2:-0.583333 3:0.355932 4:0.5 
+3 1:0.222222 2:-0.166667 3:0.423729 4:0.583333 
+3 1:0.0555554 2:0.166667 3:0.491525 4:0.833333 
+3 1:-0.111111 2:-0.166667 3:0.38983 4:0.416667 
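
The file above is the scaled iris data in LIBSVM's sparse text format: each row is "label index:value ...", feature indices are 1-based, and zero-valued features are simply omitted (which is why some rows skip "2:"). The loader data.load_svm used below is not shown in this diff; the following is a minimal sketch, under that format assumption, of how such a file can be read into dense NumPy arrays (load_svm_sketch is a hypothetical name, not the project's function):

import numpy as np

def load_svm_sketch(path, n_features=4):
    # Parse LIBSVM-format lines like '1 1:-0.5 3:0.25' into X, y.
    xs, ys = [], []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            ys.append(float(parts[0]))
            row = np.zeros(n_features)
            for item in parts[1:]:
                idx, val = item.split(':')
                row[int(idx) - 1] = float(val)  # 1-based -> 0-based index
            xs.append(row)
    return np.array(xs), np.array(ys)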

mlclass/logistic_regression.py

     def predict(self, X):
         return sigmoid(X.dot(self.theta))
 
+class SoftmaxRegressionGD:
+    """Softmax (multinomial logistic) regression trained with L-BFGS."""
+
+    def __init__(self, lambda_=1, **fmin_args):
+        self.lambda_ = lambda_          # L2 regularization strength
+        self.fmin_args = fmin_args      # passed through to fmin_l_bfgs_b
+
+    def cost_grad(self, Theta_flat):
+        m, n = self.X.shape
+        k = self.Y.shape[1]
+
+        Theta = Theta_flat.reshape((k, n))
+
+        # Row-wise softmax of the scores X.dot(Theta.T); subtracting the
+        # row maximum before exponentiating guards against overflow
+        # without changing the result.
+        S = self.X.dot(Theta.T)
+        P = np.exp(S - S.max(axis=1)[:, None])
+        P /= np.sum(P, axis=1)[:, None]
+
+        # Regularized cross-entropy, averaged over the m examples.
+        j = -np.sum(np.log(P) * self.Y)
+        j += self.lambda_ * Theta_flat.dot(Theta_flat) / 2.0
+        j /= m
+
+        # Gradient: (P - Y)^T X plus the regularization term.
+        grad = self.X.T.dot(P - self.Y).T
+        grad += self.lambda_ * Theta
+        grad /= m
+
+        return j, grad.ravel()
+
+    def fit(self, X, Y):
+        # X is the (m, n) design matrix, Y an (m, k) one-hot label matrix.
+        self.X, self.Y = X, Y
+        theta = np.zeros(Y.shape[1] * X.shape[1])
+        theta, j, ret = fmin_l_bfgs_b(self.cost_grad, theta, **self.fmin_args)
+        if ret['warnflag'] != 0:
+            warnings.warn('L-BFGS failed to converge')
+        self.Theta = theta.reshape((Y.shape[1], X.shape[1]))
+
+    def predict(self, X):
+        # Return the 0-based index of the most probable class per row.
+        P = np.exp(X.dot(self.Theta.T))
+        P /= np.sum(P, axis=1)[:, None]
+
+        return np.argmax(P, axis=1)
+
+
 if __name__ == '__main__':
+    def softmax():
+        from data import load_svm
+
+        X, y = load_svm('../data/iris.scale')
+        # Labels are the floats 1.0, 2.0, 3.0; shift to 0-based ints
+        # (the + 0.5 guards against float round-off) and one-hot encode.
+        Y = np.eye(3)[(y - 1 + 0.5).astype(np.int)]
+        m = SoftmaxRegressionGD(lambda_=0)
+        m.fit(X, Y)
+        # Training-set accuracy (the data is not split here).
+        print np.mean(m.predict(X) == (y - 1 + 0.5).astype(np.int))
+
+    softmax()
+
     def ex2():
         from data import load_txt
 
         plt.plot(x, y)
 
         plt.show()
-
-    ps1_1()
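
Since cost_grad returns both the cost and its analytic gradient for fmin_l_bfgs_b, the gradient can be smoke-tested against finite differences. A hypothetical check, not part of the commit (it assigns model.X and model.Y directly because cost_grad reads them from the instance):

import numpy as np
from scipy.optimize import approx_fprime

m, n, k = 20, 4, 3
rng = np.random.RandomState(0)
X = rng.randn(m, n)
Y = np.eye(k)[rng.randint(k, size=m)]  # random one-hot labels

model = SoftmaxRegressionGD(lambda_=0.1)
model.X, model.Y = X, Y

theta0 = rng.randn(k * n) * 0.01
j, analytic = model.cost_grad(theta0)
numeric = approx_fprime(theta0, lambda t: model.cost_grad(t)[0], 1e-6)
print np.max(np.abs(analytic - numeric))  # should be tiny, ~1e-6 or smaller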