# orange-multitask / tests / generate_data.py

from numpy import (dot, zeros, column_stack, vstack, concatenate,
                   exp, sqrt, random as rnd)

import Orange


def sigmoid(x):
    """Return the logistic function ``1 / (1 + e**-x)`` of ``x``.

    ``x`` is negated and passed to numpy's ``exp``, so numbers and numpy
    arrays are handled alike (arrays element-wise).
    """
    denominator = 1 + exp(-x)
    return 1. / denominator

def to_table(X, y=None):
    """Wrap a feature matrix (and optional targets) in an Orange table.

    One continuous feature named ``f0``, ``f1``, ... is created per column
    of the 2-D matrix ``X``.

    Without ``y`` the domain is built with ``False`` as its second argument
    (presumably "no class variable" -- confirm against the Orange API) and
    the table holds ``X`` alone.  With ``y`` a class variable named
    ``class`` is appended as the last column: discrete with values ``'0'``
    and ``'1'`` when the distinct values of ``y`` are exactly 0 and 1,
    continuous otherwise.
    """
    _rows, n_features = X.shape
    features = [Orange.feature.Continuous('f%i' % index)
                for index in range(n_features)]

    # Unlabelled data: nothing to append to X.
    if y is None:
        unlabelled_domain = Orange.data.Domain(features, False)
        return Orange.data.Table(unlabelled_domain, X)

    # Only a target holding both 0 and 1 (and nothing else) is binary.
    is_binary = set(y) == {0, 1}
    if is_binary:
        class_var = Orange.feature.Discrete('class', values=['0', '1'])
    else:
        class_var = Orange.feature.Continuous('class')

    labelled_domain = Orange.data.Domain(features, class_var)
    return Orange.data.Table(labelled_domain, column_stack((X, y)))

class Generator(object):
    """Base class for data generators with two output formats.

    ``generate_matrix`` and ``generate_table`` are each implemented in
    terms of the other, so a subclass has to override at least one of them;
    calling either on a class that overrides neither recurses without end.
    """

    def generate_matrix(self, **kwargs):
        """Return the first two items of the generated table's ``to_numpy()``."""
        table = self.generate_table(**kwargs)
        arrays = table.to_numpy()
        return arrays[:2]

    def generate_table(self, **kwargs):
        """Return the generated matrices converted with ``to_table``."""
        matrices = self.generate_matrix(**kwargs)
        return to_table(*matrices)

    def __call__(self, orange=True, **kwargs):
        """Generate data as a table (``orange`` true) or as matrices.

        Remaining keyword arguments are forwarded to the chosen method.
        """
        produce = self.generate_table if orange else self.generate_matrix
        return produce(**kwargs)

class Group(Generator):
    """Synthetic multi-task regression data built from grouped latent factors.

    The ``m`` input features are laid out as consecutive groups of
    ``pergroup`` features, with neighbouring groups sharing ``overlap``
    features.  The i-th entry of ``factors`` is the number of latent
    factors that load on the i-th feature group; feature groups beyond
    ``len(factors)`` keep all-zero loadings.  Every task has its own weight
    vector over the latent factors and its own intercept.

    Both the constructor and ``generate_matrix`` reseed numpy's *global*
    random state (``numpy.random.seed``).  Output is therefore reproducible
    for a given seed, but any other code drawing from the global generator
    is affected.
    """

    def __init__(self, groups=10, pergroup=20, factors=(5, 4, 3, 2, 1),
                 tasks=20, overlap=2, seed=42):
        """Draw the model (loadings, task weights and task intercepts).

        groups   -- number of feature groups
        pergroup -- number of features in each group
        factors  -- sequence with the number of latent factors per group
        tasks    -- number of tasks
        overlap  -- number of features shared by neighbouring groups
        seed     -- seed for numpy's global random state; also the default
                    seed used by ``generate_matrix``
        """
        self.groups = groups
        self.pergroup = pergroup
        # Keep a private copy so that neither the default value nor the
        # caller's sequence is shared between instances.
        self.factors = list(factors)
        self.m = groups * (pergroup - overlap) + overlap
        self.tasks = tasks
        self.seed = seed

        n_factors = sum(self.factors)

        rnd.seed(seed)
        # Loading matrix, features x latent factors: one dense block per
        # feature group, zeros elsewhere.
        self.transform = zeros((self.m, n_factors))
        column = 0
        for i, k in enumerate(self.factors):
            start = i * (pergroup - overlap)
            self.transform[start:start + pergroup, column:column + k] = \
                rnd.normal(0, 1, (pergroup, k))
            column += k
        # Task weights are scattered around one shared mean per factor.
        mus = rnd.normal(0, 1, n_factors)
        self.weights = rnd.normal(mus, 1, (tasks, n_factors))
        self.intercepts = rnd.normal(100, 20, tasks)

    def get_model(self):
        """Return the tuple ``(transform, weights, intercepts)``."""
        return self.transform, self.weights, self.intercepts

    def generate_matrix(self, pertask=50, seed=None):
        """Return ``(X, y)`` holding ``pertask`` examples for every task.

        The per-task blocks are stacked in task order.  Each target is the
        task's linear function of the latent factors plus its intercept and
        unit-variance normal noise.  ``seed=None`` means "use the seed given
        to the constructor"; the global random state is reseeded either way.
        """
        if seed is None:
            seed = self.seed
        rnd.seed(seed)
        Xs = [rnd.normal(0, 1, (pertask, self.m))
              for _ in range(self.tasks)]
        ys = [dot(dot(X, self.transform), w) + intercept
              for X, w, intercept in zip(Xs, self.weights, self.intercepts)]
        X = vstack(Xs)
        y = concatenate(ys)
        y += rnd.normal(0, 1, len(y))
        return (X, y)

    def __call__(self, orange=True, **kwargs):
        """Generate data; tables additionally get a ``task`` meta attribute.

        With ``orange`` true (the default) the generated table is returned
        after every instance has been tagged with the index of its task,
        stored as a string in a discrete meta attribute named ``task``.
        With ``orange`` false the plain ``generate_matrix`` result is
        returned, matching ``Generator.__call__``.  Remaining keyword
        arguments are forwarded to the generating method.
        """
        if not orange:
            return self.generate_matrix(**kwargs)

        data = self.generate_table(**kwargs)
        values = [str(i) for i in range(self.tasks)]
        task = Orange.feature.Discrete('task', values=values)
        meta_id = Orange.feature.Descriptor.new_meta_id()
        data.domain.add_meta(meta_id, task)
        # Rows come in equally sized consecutive blocks, one per task, so
        # each label is repeated once for every row of its block.
        labels = (value
                  for value in values
                  for _ in range(len(data) // self.tasks))
        for ins, label in zip(data, labels):
            ins[task] = label
        return data


class Group_binary(Group):
    """Variant of ``Group`` that produces 0/1 targets.

    The model drawn by ``Group`` is rescaled in the constructor, and each
    generated target is replaced by a random 0/1 draw whose probability of
    being 1 is the sigmoid of the original target.
    """

    def __init__(self, mu=0, sigma=1, **kwargs):
        """Build the ``Group`` model from ``kwargs`` and rescale it.

        ``mu`` shifts the intercepts (after removing the offset of 100 used
        by ``Group``) before they are divided by 40; ``sigma`` scales the
        task weights.
        """
        Group.__init__(self, **kwargs)
        shifted = self.intercepts - 100 + mu
        self.intercepts = shifted / 40.
        scale = 4 * sigma / sqrt(sum(self.factors) * self.pergroup)
        self.weights *= scale

    def generate_matrix(self, pertask=50, seed=None):
        """Return ``(X, y)`` where ``y`` holds integer 0/1 labels.

        The uniform draws continue from the global random state left behind
        by ``Group.generate_matrix``, so the labels are reproducible for a
        given seed.
        """
        X, scores = Group.generate_matrix(self, pertask, seed)
        uniform = rnd.rand(len(scores))
        labels = (uniform < sigmoid(scores)).astype(int)
        return X, labels
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.