Lan Zagar avatar Lan Zagar committed de40db0

Classification support and more general extraction of weights from base learners.

Comments (0)

Files changed (1)

_multitask/mtfeat.py

 import _multitask as multitask
 
 
+def get_weights(c):
+    if isinstance(c, Orange.classification.svm.SVMClassifier):
+        weightsa = Orange.classification.svm.get_linear_svm_weights(c)
+        weights = [wd[f] for f in c.domain]
+    elif isinstance(c, Orange.classification.svm.LinearClassifier):
+        #intercept is the last weight (check by Ales)!
+        weights = c.weights[0][:-1]
+        assert len(c.domain.features) == len(weights)
+    else:
+        weights = array(c.coefficients)
+    return weights
+
 def duplicate(data, groups):
     domain = Orange.data.Domain(
         [data.domain.features[i] for g in groups for i in g],
                 if self.learner:
                     dt = Orange.data.Table(domain, np.column_stack((fX.T, y)))
                     c = self.learner(dt)
-                    w = array(c.coefficients)
+                    w = get_weights(c)
                 else:
                     K = dot(fX.T, fX)
                     a = dot(np.linalg.inv(K + self.gamma * eye(K.shape[0])), y)
             s[s < 1e-10] = 0
             D = diag(s) if self.selection else dot(U, dot(diag(s), U.T))
 
-            self.cb(**locals())
+            self.cb(iter=i, max_iter=self.max_iter, **locals())
             if np.linalg.norm(W - Wold) / W.size < self.tol:
                 break
 
         f = dot(self.W[:, t], ins.native()[:-1])
         if self.intercepts:
             f += self.intercepts[t]
-        val = self.domain.class_var(f)
+        if isinstance(self.domain.class_var, Orange.feature.Continuous):
+            val = self.domain.class_var(f)
+        else:
+            val = self.domain.class_var(int(f > 0))
         dist = Orange.statistics.distribution.Distribution(val.variable)
-        dist[val] = 1.
+        dist[val] = 1. / (1. + np.exp(-f))
         if return_type == Orange.core.GetValue:
             return val
         elif return_type == Orange.core.GetBoth:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.