orange-multitask / _multitask / __init__.py

from __future__ import absolute_import
from collections import defaultdict
from pkg_resources import resource_filename

import Orange
from . import mtfeat


def datasets():
    """Yield one ``(name, path)`` pair for this package's ``datasets`` resource.

    The name is the literal ``'multitask'``; the path is resolved with
    ``resource_filename`` relative to this module.
    """
    location = resource_filename(__name__, 'datasets')
    yield ('multitask', location)

def split_by_task(data, task_id='task'):
    """
    Partition a multi-task data set into one data set per task.

    :param data: Data set with all instances from all tasks.
    :type data: :obj:`~Orange.data.Table`
    :param task_id: Key used to read the task id from each instance; the
        value read is converted with ``int``.

    :return: Dictionary mapping each integer task id to the subset
        returned by ``data.select`` for that id.
    """
    # One integer task label per instance, in instance order.
    labels = []
    for instance in data:
        labels.append(int(instance[task_id]))

    # Build the subsets in ascending task-id order.
    subsets = {}
    for label in sorted(set(labels)):
        subsets[label] = data.select(labels, label)
    return subsets
    
def get_groups(domain, feat_att=None):
    """
    Group feature positions by the value of a feature attribute.

    :param domain: Object whose ``features`` sequence is grouped; every
        feature must expose an ``attributes`` mapping.
    :param feat_att: Key into each feature's ``attributes`` whose value
        identifies the group. If ``None``, the first key of the first
        feature's ``attributes`` is used (when there is one).

    :return: List of groups; each group is a list of indices into
        ``domain.features`` of the features sharing one attribute value.
    :raises KeyError: If a feature's ``attributes`` has no ``feat_att`` key.
    """
    if feat_att is None and domain.features and domain.features[0].attributes:
        # next(iter(...)) picks the same first key that keys()[0] did, but
        # also works where keys() returns a non-indexable view.
        feat_att = next(iter(domain.features[0].attributes))
    groups = defaultdict(list)
    for i, f in enumerate(domain.features):
        groups[f.attributes[feat_att]].append(i)
    # Materialise so callers always receive a real list, never a dict view.
    return list(groups.values())

def join(datas):
    """Join data sets with same domains into a single multi-task data set.

    Every instance of ``datas[i]`` is appended to the result and its
    ``'task'`` value is set to ``float(i)``.

    :param datas: Non-empty sequence of data sets sharing one domain.

    :return: A new table built on the domain of ``datas[0]``, holding all
        instances from all data sets.
    :raises ValueError: If any data set's domain differs from the first
        one, either in length or in any variable.
    :raises IndexError: If ``datas`` is an empty sequence.

    NOTE: the ``'task'`` meta attribute is added via ``add_meta`` on the
    domain object of ``datas[0]`` itself, not on a copy.
    """
    dom = datas[0].domain
    reference = list(dom)
    for d in datas[1:]:
        candidate = list(d.domain)
        # zip() alone stops at the shorter domain, so a domain that is a
        # strict prefix of the other would pass; compare lengths first.
        if len(candidate) != len(reference) or not all(
                f1 == f2 for f1, f2 in zip(reference, candidate)):
            raise ValueError('Different domains')
    task = Orange.feature.Continuous('task')
    mid = Orange.feature.Descriptor.new_meta_id()
    dom.add_meta(mid, task)
    data = Orange.data.Table(dom)
    for i, d in enumerate(datas):
        for ins in d:
            data.append(ins)
            # Tag the instance just appended with its source index.
            data[-1]['task'] = float(i)
    return data

class MultiTaskLearner(Orange.classification.Learner):
    """Fit a separate model per task using one base learner.

    Calling the class with ``data`` returns the trained
    :class:`MultiTaskClassifier` directly; calling it without data (or
    with falsy data) returns the learner instance.

    NOTE: ``weights`` is accepted by ``__new__`` and ``__call__`` but is
    not forwarded to the base learner.
    """

    def __new__(cls, data=None, weights=0, **kwargs):
        self = Orange.classification.Learner.__new__(cls, **kwargs)
        if not data:
            return self
        # The value returned below is not an instance of cls, so Python
        # will not run __init__ for us; do it explicitly before training.
        self.__init__(**kwargs)
        return self.__call__(data, weights)

    def __init__(self, learner, **kwargs):
        # Base learner invoked once per task data set.
        self.learner = learner
        # Any extra keyword arguments become instance attributes.
        self.__dict__.update(kwargs)

    def __call__(self, data, weights=0):
        per_task = split_by_task(data)
        models = {}
        for task in sorted(per_task):
            models[task] = self.learner(per_task[task])
        return MultiTaskClassifier(classifiers=models)

class MultiTaskClassifier(Orange.classification.Classifier):
    """Route each instance to the classifier stored for its task."""

    def __init__(self, classifiers):
        # Mapping looked up by the integer task id of an instance.
        self.classifiers = classifiers

    def __call__(self, instance, return_type=Orange.core.GetValue):
        task = int(instance['task'])
        task_classifier = self.classifiers[task]
        return task_classifier(instance, return_type)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.