orange-multitask / _multitask / mtfeat.py

"""
Implementation of the multi-task feature learning algorithm described in
[Argyriou_etal_2008]_.

.. [Argyriou_etal_2008] Argyriou, A., Evgeniou, T., Pontil, M. (2008).
   Convex multi-task feature learning. Machine Learning, 73(3), 243-272.
        
"""

import numpy as np
from numpy import dot, sqrt, array, diag, zeros, eye, ix_
from scipy.spatial.distance import pdist

import Orange
import _multitask as multitask


def get_weights(c):
    """Return one weight per feature of a fitted linear model *c*.

    Three kinds of models are handled:

    * ``Orange.classification.svm.SVMClassifier`` -- weights come from
      ``get_linear_svm_weights`` and are ordered like ``c.domain.features``;
    * ``Orange.classification.svm.LinearClassifier`` (LIBLINEAR-based
      learners) -- the single row of ``c.weights``, with a trailing bias
      term dropped when present;
    * anything else -- ``c.coefficients`` converted to a numpy array.

    :param c: a fitted classifier.
    :returns: a sequence with one weight per feature of ``c.domain``.
    :raises ValueError: if a ``LinearClassifier`` does not have exactly one
        weight row, or its weight count does not match the feature count.
    """
    if isinstance(c, Orange.classification.svm.SVMClassifier):
        wd = Orange.classification.svm.get_linear_svm_weights(c)
        # Only features carry weights; iterating the whole domain would also
        # visit the class variable, which has no entry in ``wd``.
        # NOTE(review): confirm that the default aggregation of
        # get_linear_svm_weights keeps the sign of each weight.
        weights = [wd[f] for f in c.domain.features]
    elif isinstance(c, Orange.classification.svm.LinearClassifier):
        if len(c.weights) != 1:
            raise ValueError('expected a single weight vector, got %d'
                             % len(c.weights))
        # logreg.LibLinearLogRegLearner: one weight per feature.
        weights = c.weights[0]
        n_features = len(c.domain.features)
        if len(weights) != n_features:
            # svm.LinearSVMLearner: the intercept is stored as the last
            # weight (per the original author's note); drop it.
            weights = weights[:-1]
        if len(weights) != n_features:
            raise ValueError('got %d weights for %d features'
                             % (len(weights), n_features))
    else:
        weights = array(c.coefficients)
    return weights

def duplicate(data, groups):
    """Copy *data* into a domain whose features are laid out group by group.

    Features are listed in the order given by *groups* (a sequence of
    sequences of feature indices), so a feature listed in several groups
    appears once per group. The class variable, the class variables and
    the meta attributes of the original domain are carried over.

    :param data: the table to convert.
    :param groups: sequence of feature-index groups.
    :returns: a new ``Orange.data.Table`` in the expanded domain.
    """
    old_domain = data.domain
    expanded = []
    for group in groups:
        expanded.extend(old_domain.features[idx] for idx in group)
    new_domain = Orange.data.Domain(expanded, old_domain.class_var,
                                    class_vars=old_domain.class_vars)
    new_domain.add_metas(old_domain.get_metas())
    return Orange.data.Table(new_domain, data)

def transform_domain(domain, transformation):
    """Construct a new domain from a transformation matrix.

    Row ``k`` (0-based) of *transformation* defines a continuous feature
    named ``f<k+1>``. Its value is computed on demand through
    ``get_value_from``: the instance is converted to *domain*, and the dot
    product of the row with all of its native values except the last one is
    taken (presumably the class value -- the code assumes it comes last).
    The class variable and meta attributes of *domain* are kept.
    """
    def make_projection(coefs):
        # A factory gives every feature its own ``coefs`` binding, avoiding
        # the late-binding closure pitfall inside the loop below.
        def project(ins, ret):
            converted = Orange.data.Instance(domain, ins)
            return dot(coefs, converted.native()[:-1])
        return project

    new_features = []
    for k, coefs in enumerate(transformation, 1):
        feature = Orange.feature.Continuous('f%i' % k)
        feature.get_value_from = make_projection(coefs)
        new_features.append(feature)
    result = Orange.data.Domain(new_features, domain.class_var)
    result.add_metas(domain.get_metas())
    return result

class MTFeatLearner(Orange.regression.base.BaseRegressionLearner):
    """Multi-task feature learning algorithm from [Argyriou_etal_2008]_.

    The learner alternates between two steps until the weight matrix ``W``
    stops changing or ``max_iter`` iterations have been run:

    1. for fixed ``D``, fit one regularized linear model per task on the
       features transformed by ``D^{1/2}``;
    2. for fixed ``W``, update ``D`` (diagonal, block-diagonal per group,
       or full ``U diag(s) U^T``) from ``W`` and normalise it to trace 1.

    :param learner: optional Orange learner used to fit each task; when not
        given, a closed-form ridge solution with penalty ``gamma`` is used.
    :param selection: if true, ``D`` is kept diagonal (feature selection);
        otherwise ``D = U diag(s) U^T`` (feature learning).
    :param gamma: ridge penalty of the built-in closed-form solver.
    :param intercept: if true, each task's targets are centred and their
        mean is stored as that task's intercept.
    :param groups: optional sequence of feature-index groups; ``D`` is then
        block-diagonal with one block per group. Features listed in more
        than one group are duplicated so the groups become disjoint.
    :param max_iter: maximum number of alternating iterations.
    :param tol: convergence threshold on ``||W - W_old|| / W.size``.
    :param norm_covered: passed on to :class:`MTFeatClassifier`.
    :param cb: optional callback called after every iteration with ``iter``,
        ``max_iter`` and all local variables of :meth:`__call__` as keyword
        arguments.
    :param name: learner name.
    """
    def __new__(cls, data=None, weights=0, **kwargs):
        self = Orange.regression.base.BaseRegressionLearner.__new__(
            cls, **kwargs)
        if data:
            # Called with data: initialise, fit, and return the fitted
            # classifier instead of the learner.
            self.__init__(**kwargs)
            return self.__call__(data, weights)
        else:
            return self

    def __init__(self, learner=None, selection=False, gamma=1,
                 intercept=False, groups=None, max_iter=50, tol=1e-5,
                 norm_covered=0.99, cb=None, name='MTFeat', **kwargs):
        super(MTFeatLearner, self).__init__()
        self.learner = learner
        self.selection = selection
        self.gamma = gamma
        self.intercept = intercept
        # ``None`` default instead of a shared mutable ``[]``.
        self.groups = [] if groups is None else groups
        self.max_iter = max_iter
        self.tol = tol
        self.norm_covered = norm_covered
        self.cb = cb if cb else lambda *args, **kwargs: None
        self.name = name
        self.__dict__.update(kwargs)

    def __call__(self, data, weights=0):
        """Fit the model on *data* and return an :class:`MTFeatClassifier`.

        :param data: training table; it is split into per-task tables with
            ``multitask.split_by_task``.
        :param weights: accepted for compatibility with the learner call
            signature; not used.
        :returns: an :class:`MTFeatClassifier`.
        """
        groups = self.groups
        if groups and (sum(len(g) for g in self.groups) >
                       len(set(x for g in self.groups for x in g))):
            # Some feature index occurs more than once: give every
            # occurrence its own column so the groups become disjoint,
            # then renumber the groups to match the new column layout.
            data = duplicate(data, groups)
            groups, i = [], 0
            for g in self.groups:
                groups.append(list(range(i, i + len(g))))
                i += len(g)
        datas = multitask.split_by_task(data)
        tasks = sorted(datas.keys())
        # One (X, y, w) triple per task, in sorted task order.
        training = [datas[t].to_numpy() for t in tasks]
        intercepts = []
        if self.intercept:
            intercepts = [y.mean() for _, y, _ in training]
            training = [(X, y - y.mean(), w) for X, y, w in training]
        dim = len(data.domain.features)
        # Anonymous domain used to wrap transformed data for ``learner``.
        domain = Orange.data.Domain([Orange.feature.Continuous(
            'f%i' % (i + 1)) for i in range(dim)], data.domain.class_var)
        W = zeros((dim, len(tasks)))
        U = eye(dim)
        D = U / dim
        s = diag(D)
        i = 0  # keeps ``i`` defined for the return below if max_iter < 1

        for i in range(self.max_iter):
            Wold = W.copy()

            # Compute D^{1/2}. Build a new array instead of taking the root
            # in place: on the first iteration ``s`` is ``diag(D)``, which
            # NumPy (>= 1.9) returns as a read-only view.
            s = sqrt(s)
            D_sqrt = diag(s) if self.selection else dot(U, dot(diag(s), U.T))

            # Solve the regularization problem for fixed D
            for t, (X, y, _) in enumerate(training):
                fX = dot(D_sqrt, X.T)
                if self.learner:
                    dt = Orange.data.Table(domain, np.column_stack((fX.T, y)))
                    c = self.learner(dt)
                    w = get_weights(c)
                else:
                    # Ridge regression in dual form: a = (K + gamma I)^-1 y,
                    # w = fX a. ``solve`` avoids forming the explicit inverse.
                    K = dot(fX.T, fX)
                    a = np.linalg.solve(K + self.gamma * eye(K.shape[0]), y)
                    w = dot(fX, a)
                W[:, t] = w
            # Map the weights back from the D^{1/2}-transformed space.
            W = dot(D_sqrt, W)

            # Update D (solve for fixed W)
            if self.selection:
                # Diagonal D from the row (per-feature) norms of W.
                s = sqrt(np.sum(W**2, 1))
            elif groups:
                # Block-diagonal D: one SVD per feature group.
                U = zeros((dim, dim))
                s = zeros(dim)
                for g in groups:
                    Uj, sj, _ = np.linalg.svd(W[g, :])
                    if len(g) > len(tasks):
                        sj = np.hstack((sj, zeros(len(g) - len(tasks))))
                    U[ix_(g, g)] = Uj
                    s[g] = sj
            else:
                # Full D = U diag(s) U^T from the SVD of W.
                U, s, _ = np.linalg.svd(W)
                if dim > len(tasks):
                    s = np.hstack((s, zeros(dim - len(tasks))))
            # Normalise to trace 1; skip when W is all zeros to avoid 0/0.
            total = s.sum()
            if total > 0:
                s /= total
            s[s < 1e-10] = 0
            D = diag(s) if self.selection else dot(U, dot(diag(s), U.T))

            self.cb(iter=i, max_iter=self.max_iter, **locals())
            if np.linalg.norm(W - Wold) / W.size < self.tol:
                break

        return MTFeatClassifier(W, U, s, tasks, i, self.selection, data.domain,
                                intercepts, self.norm_covered, name=self.name)

class MTFeatClassifier(Orange.classification.Classifier):
    """Predictor returned by :class:`MTFeatLearner`.

    For an instance of task ``t`` it computes ``f = W[:, t] . x``, plus the
    task intercept when intercepts were learned. The learned representation
    is exposed through ``new_domain`` and :meth:`transform_data`. It keeps
    the smallest number of leading components whose cumulative share of
    ``s`` reaches ``norm_covered``; with ``selection`` these components are
    original features.
    """
    def __init__(self, W, U, s, tasks, iter, selection, domain, intercepts,
                 norm_covered, **kwargs):
        self.W = W  # one column of feature weights per task
        self.U = U
        self.s = s
        self.tasks = tasks
        self.iter = iter
        self.selection = selection
        self.domain = domain
        self.intercepts = intercepts
        self.class_var = domain.class_var
        self.norm_covered = norm_covered
        self.__dict__.update(kwargs)

        # Rank components by decreasing ``s`` with a stable sort, so ties
        # keep their original order. ``s`` is already sorted only when it
        # comes from a single SVD. For selection or grouped learning it is
        # not, so U's columns must be reordered together with ``s``.
        strengths = np.asarray(s)
        order = np.argsort(-strengths, kind='mergesort')
        nrelevant = np.searchsorted(np.cumsum(strengths[order]),
                                    norm_covered) + 1
        if selection:
            self.new_domain = Orange.data.Domain(
                [domain.features[i] for i in order[:nrelevant]],
                domain.class_var)
            self.new_domain.add_metas(domain.get_metas())
        else:
            # Rows of ``transform`` are the leading columns of U.
            self.transform = U[:, order].T[:nrelevant, :]
            self.new_domain = transform_domain(domain, self.transform)

    def __call__(self, instance, return_type=Orange.core.GetValue):
        """Predict *instance* with the weights of its task.

        The task is read from the instance's ``'task'`` value.

        :returns: the predicted value, ``(value, distribution)`` or the
            distribution, depending on *return_type*.
        """
        ins = Orange.data.Instance(self.domain, instance)
        t = self.tasks.index(ins['task'].value)
        # All native values except the last (assumed to be the class).
        f = dot(self.W[:, t], ins.native()[:-1])
        if self.intercepts:
            f += self.intercepts[t]
        dist = Orange.statistics.distribution.Distribution(self.domain.class_var)
        if isinstance(self.domain.class_var, Orange.feature.Continuous):
            val = self.domain.class_var(f)
            dist[val] = 1.
        else:
            # Assumes a binary class: the sign of f picks the label and the
            # logistic function of f gives P(class 1).
            val = self.domain.class_var(int(f > 0))
            dist[1] = 1. / (1. + np.exp(-f))
            dist[0] = 1. - dist[1]
        if return_type == Orange.core.GetValue:
            return val
        elif return_type == Orange.core.GetBoth:
            return val, dist
        else:
            return dist

    def transform_data(self, data):
        """Transform data to the new (reduced) domain.

        With ``selection`` the table is simply converted to ``new_domain``.
        Otherwise the feature values are projected with ``transform`` in
        NumPy, and meta values are copied instance by instance.
        """
        if self.selection:
            return Orange.data.Table(self.new_domain, data)
        else: # transform in numpy for speed
            data = Orange.data.Table(self.domain, data)
            X, y, _ = data.to_numpy()
            mat = np.column_stack((dot(X, self.transform.T), y))
            fdata = Orange.data.Table(self.new_domain, mat)
            for fins, ins in zip(fdata, data):
                for m in ins.get_metas():
                    fins[m] = ins[m]
            return fdata

    def task_distances(self, metric='euclidean', **kwargs):
        """Compute the task distance matrix.

        Distances are computed between the columns of ``W`` (one per task).
        See scipy.spatial.distance.pdist for metrics and additional
        parameters. The matrix items are a table with one string ``task``
        value per task.
        """
        dm = pdist(self.W.T, metric, **kwargs)
        t = len(self.tasks)
        dsm = Orange.misc.SymMatrix(t)
        # pdist returns the condensed upper triangle in row-major order.
        i = 0
        for ti in range(t - 1):
            for tj in range(ti + 1, t):
                dsm[ti, tj] = dm[i]
                i += 1
        items = Orange.data.Table(Orange.data.Domain([Orange.feature.String(
            'task')], False), [[str(task)] for task in self.tasks])
        dsm.items = items
        return dsm


if __name__ == '__main__':
    # Demo: compare a mean predictor, a single ridge model on all data,
    # per-task ridge models (multitask.MultiTaskLearner) and MTFeat by RMSE
    # on the datasets/train1 -> datasets/test1 split.
    train = Orange.data.Table('datasets/train1')
    test = Orange.data.Table('datasets/test1')
    mean = Orange.regression.mean.MeanLearner(name='Mean')
    ridge = Orange.regression.linear.LinearRegressionLearner(
        intercept=False, ridge_lambda=1, name='Ridge')
    mt_ridge = multitask.MultiTaskLearner(learner=ridge, name='Independent')
    # ``max_iter`` is MTFeatLearner's iteration limit; an unknown keyword
    # (e.g. ``max_iterations``) would be stored as an unused attribute.
    mtf = MTFeatLearner(learner=ridge, max_iter=30, name='MTFeat')
    res = Orange.evaluation.testing.learn_and_test_on_test_data(
        learners=[mean, ridge, mt_ridge, mtf], learn_set=train, test_set=test)
    print(list(zip(res.classifier_names, Orange.evaluation.scoring.RMSE(res))))
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.