# orange-multitarget/_multitarget/binary.py

import Orange.core as orange
import Orange
import random
import copy
from operator import add

class BinaryRelevanceLearner(orange.Learner):
    """
    Expands single-target classification techniques into multi-target
    classification techniques using binary relevance. For every class
    variable of the data, a separate copy of the base learner is trained on
    the data's features with that variable set as the (single) class.
    The per-target models are trained independently: predictions for one
    class variable are not fed to the others, so class correlations are not
    taken into account.

    :param learner: A single-target learner that is applied to each class
        variable in turn.
    :type learner: :class:`Orange.core.Learner`

    :param name: learner name.
    :type name: string

    :param rand: random generator. If None (default), ``random.Random(42)``
        is used. Its initial state is restored at the start of every call;
        the current implementation does not otherwise draw from it.

    :param callback: a function to be called after each per-target
        classifier is induced. It receives a single argument (from 0.0 to
        1.0) estimating the fraction of learning completed.

    :rtype: :class:`BinaryRelevanceClassifier` when the class is called
        with data, otherwise :class:`BinaryRelevanceLearner`
    """

    def __new__(cls, data=None, weight=0, **kwargs):
        # Calling the class with data trains immediately and returns the
        # resulting classifier; without data the learner itself is returned.
        self = Orange.classification.Learner.__new__(cls, **kwargs)
        if data:
            self.__init__(**kwargs)
            return self(data, weight)
        return self

    def __init__(self, learner=None, name="Binary Relevance", rand=None, callback=None):
        if not learner:
            raise TypeError("Wrong specification, learner not defined")

        self.name = name
        self.learner = learner
        self.callback = callback
        self.rand = rand if rand else random.Random(42)
        # Remember the initial generator state so every call starts from it.
        self.randstate = self.rand.getstate()

    def __call__(self, instances, weight=0):
        """
        Learn from the given table of multi-target data instances.

        :param instances: data for learning; every variable in
            ``instances.domain.class_vars`` gets its own classifier.
        :type instances: :class:`Orange.data.Table`

        :param weight: weight argument passed unchanged to the base learner
            (presumably the id of a weight meta attribute, 0 meaning
            unweighted -- TODO confirm).
        :type weight: int

        :rtype: :class:`BinaryRelevanceClassifier`
        """
        # Work on a copy so that pick_class below does not alter the
        # caller's table.
        instances = Orange.data.Table(instances.domain, instances)

        self.rand.setstate(self.randstate)

        orig_domain = copy.copy(instances.domain)
        class_order = list(instances.domain.class_vars)
        m = len(class_order)

        classifiers = []
        domains = []
        for i, class_var in enumerate(class_order):
            # Make this target the table's class variable.
            instances.pick_class(class_var)

            # Save the single-target domain (iterating a domain yields its
            # features followed by its class variable) so the classifier
            # can rebuild matching instances at prediction time.
            domains.append(Orange.data.Domain(list(instances.domain)))

            classifiers.append(self.learner(instances, weight))

            if self.callback:
                self.callback(float(i + 1) / m)

        return BinaryRelevanceClassifier(classifiers=classifiers, class_order=class_order,
                                         domains=domains, name=self.name, orig_domain=orig_domain)


class BinaryRelevanceClassifier(orange.Classifier):
    """
    Multi-target classifier induced by :obj:`BinaryRelevanceLearner`.
    It holds one single-target classifier per class variable, and each
    class variable is predicted by its own classifier only.

    It is normally created by :obj:`BinaryRelevanceLearner` rather than
    constructed manually.

    :param classifiers: single-target classifiers, one per class variable,
        in the same order as ``class_order``.
    :type classifiers: list

    :param class_order: the class variables, in the order in which the
        classifiers were trained.
    :type class_order: list of :class:`Orange.feature.Descriptor`

    :param domains: for each classifier, the single-target domain it was
        trained on.
    :type domains: list of :class:`Orange.data.Domain`

    :param name: name of the resulting classifier.
    :type name: str

    :param orig_domain: copy of the multi-target domain of the training data.
    :type orig_domain: :class:`Orange.data.Domain`
    """

    def __init__(self, classifiers, class_order, domains, name, orig_domain):
        self.classifiers = classifiers
        self.class_order = class_order
        self.name = name
        self.domains = domains
        self.orig_domain = orig_domain

    def __call__(self, instance, result_type=orange.GetValue):
        """
        Predict every class variable of an instance.

        :param instance: instance to be classified. Its values are copied
            positionally into each per-target domain followed by a missing
            class value, so it is assumed to list the training features in
            the same order and without a single class value (TODO confirm).
        :type instance: :class:`Orange.data.Instance`

        :param result_type: :class:`Orange.classification.Classifier.GetValue` or \
              :class:`Orange.classification.Classifier.GetProbabilities` or
              :class:`Orange.classification.Classifier.GetBoth`

        :rtype: a tuple of :class:`Orange.data.Value` (one per class
              variable) for ``GetValue``; a tuple of
              :class:`Orange.statistics.Distribution` for ``GetProbabilities``;
              otherwise a two-element list ``[values, distributions]``.
        """
        values = []
        probs = []

        for classifier, domain in zip(self.classifiers, self.domains):
            # The target is unknown at prediction time, so a missing ('?')
            # class value is appended after the instance's values.
            inst = Orange.data.Instance(domain, list(instance) + ['?'])

            res = classifier(inst, orange.GetBoth)
            values.append(res[0])
            probs.append(res[1])

        if result_type == orange.GetValue:
            return tuple(values)
        if result_type == orange.GetProbabilities:
            return tuple(probs)
        return [tuple(values), tuple(probs)]

    def __reduce__(self):
        # Pickle support: rebuild from the constructor arguments, then
        # restore any remaining instance attributes.
        return type(self), (self.classifiers, self.class_order, self.domains, self.name, self.orig_domain), dict(self.__dict__)

if __name__ == '__main__':
    # Demo: cross-validate binary relevance with several base learners and a
    # multi-target tree, then print each learner's mt_average_score (RMSE).
    import time
    print("STARTED")
    global_timer = time.time()

    data = Orange.data.Table('bridges.v2.nm')

    learners = [
        BinaryRelevanceLearner(learner=Orange.classification.tree.SimpleTreeLearner),
        BinaryRelevanceLearner(learner=Orange.classification.bayes.NaiveLearner),
        BinaryRelevanceLearner(learner=Orange.classification.majority.MajorityLearner),
        Orange.multitarget.tree.MultiTreeLearner(),
    ]

    # NOTE(review): an earlier comment here read "cross_validation not
    # working" -- confirm whether that is still the case.
    res = Orange.evaluation.testing.cross_validation(learners, data, folds=2)

    scores = Orange.evaluation.scoring.mt_average_score(res, Orange.evaluation.scoring.RMSE)

    print(scores)
    print(res.classifierNames)
    for classifier_name, score in zip(res.classifierNames, scores):
        print("%s %s" % (classifier_name, score))

    print("--DONE %.2f --" % (time.time() - global_timer))
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.