Commits

Lan Zagar committed 8dd440f

Support for overlapping groups (closes #12).

Comments (0)

Files changed (1)

_multitask/mtfeat.py

 import _multitask as multitask
 
 
+def duplicate(data, groups):
+    domain = Orange.data.Domain(
+        [data.domain.features[i] for g in groups for i in g],
+        data.domain.class_var, class_vars=data.domain.class_vars)
+    domain.add_metas(data.domain.get_metas())
+    return Orange.data.Table(domain, data)
+
 def transform_domain(domain, transformation):
     """Construct a new domain from a transformation matrix."""
     features = []
         self.__dict__.update(kwargs)
 
     def __call__(self, data, weights=0):
+        groups = self.groups
+        if groups and (sum(len(g) for g in self.groups) >
+                       len(set(x for g in self.groups for x in g))):
+            data = duplicate(data, groups)
+            groups, i = [], 0
+            for g in self.groups:
+                groups.append(range(i, i + len(g)))
+                i += len(g)
         datas = multitask.split_by_task(data)
         tasks = sorted(datas.keys())
         training = [datas[t].to_numpy() for t in tasks]
             # Update D (solve for fixed W)
             if self.selection:
                 s = sqrt(np.sum(W**2, 1))
-            elif self.groups:
+            elif groups:
                 U = zeros((dim, dim))
                 s = zeros(dim)
-                for g in self.groups:
+                for g in groups:
                     Uj, sj, _ = np.linalg.svd(W[g, :])
                     if len(g) > len(tasks):
                         sj = np.hstack((sj, zeros(len(g) - len(tasks))))