Lan Zagar avatar Lan Zagar committed 562c338

Bug fixes for classification.

Comments (0)

Files changed (1)

_multitask/mtfeat.py

 
 def get_weights(c):
     if isinstance(c, Orange.classification.svm.SVMClassifier):
-        weightsa = Orange.classification.svm.get_linear_svm_weights(c)
-        weights = [wd[f] for f in c.domain]
+        wd = Orange.classification.svm.get_linear_svm_weights(c)
+        weights = [wd[f] for f in c.domain.features]
     elif isinstance(c, Orange.classification.svm.LinearClassifier):
         assert len(c.weights) == 1
         # logreg.LibLinearLogRegLearner
         weights = c.weights[0]
         if len(c.domain.features) != len(weights):
             # svm.LinearSVMLearner
-            #intercept is the last weight (check by Ales)!
             weights = weights[:-1]
         assert len(c.domain.features) == len(weights)
     else:
         training = [datas[t].to_numpy() for t in tasks]
         intercepts = []
         if self.intercept:
-            intercepts = [y.mean() for _, y, _ in training]
-            training = [(X, y - y.mean(), w) for X, y, w in training]
+            if isinstance(data.domain.class_var, Orange.feature.Continuous):
+                intercepts = [y.mean() for _, y, _ in training]
+                training = [(X, y - y.mean(), w) for X, y, w in training]
+            else:
+                raise Exception('Not implemented.')
         dim = len(data.domain.features)
         domain = Orange.data.Domain([Orange.feature.Continuous(
             'f%i' % (i + 1)) for i in range(dim)], data.domain.class_var)
             val = self.domain.class_var(f)
             dist[val] = 1.
         else:
-            val = self.domain.class_var(int(f > 0))
-            dist[1] = 1. / (1. + np.exp(-f))
-            dist[0] = 1 - dist[1]
+            val = self.domain.class_var(int(f < 0))
+            dist[0] = 1. / (1. + np.exp(-f))
+            dist[1] = 1 - dist[0]
         if return_type == Orange.core.GetValue:
             return val
         elif return_type == Orange.core.GetBoth:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.