Aleš Erjavec avatar Aleš Erjavec committed 856bb21

Added Multi Target Tree widget.

Comments (0)

Files changed (1)

_multitarget/widgets/OWMultiTargetTree.py

+"""
+<name>Multi target tree</name>
+<description>Classification tree learner/classifier for multi
+target classification</description>
+<priority>100</priority>
+<tags>tree,multitarget</tags>
+
+"""
+
+import Orange
+import Orange.multitarget
+
+from Orange.tuning import PreprocessedLearner
+
+from OWWidget import *
+import OWGUI
+
+
+class OWMultiTargetTree(OWWidget):
+    settingsList = ["name", "max_depth"]
+
+    BINARIZATION = ["No binarization",
+                    "Exhaustive search for optimal split",
+                    "One value against others"]
+
+    def __init__(self, parent=None, signalManager=None,
+                 title="Multi target tree"):
+        OWWidget.__init__(self, parent, signalManager, title,
+                          wantMainArea=False)
+
+        self.inputs = [("Data", Orange.data.Table, self.set_data),
+                       ("Preprocess", PreprocessedLearner,
+                        self.set_preprocessor)
+                       ]
+        self.outputs = [("Learner", Orange.classification.Learner),
+                        ("Classifier", Orange.classification.Classifier)
+                        ]
+
+        self.name = "Multi Target Tree"
+        self.binarization = 0
+        self.use_min_subset = True
+        self.min_subset = 2
+        self.use_min_instances = False
+        self.min_instances = 5
+        self.use_max_majority = False
+        self.max_majority = 95
+        self.use_max_depth = False
+        self.max_depth = 100
+        self.same_majority_pruning = True
+        self.use_m_pruning = True
+        self.m_pruning = 2
+
+        box = OWGUI.widgetBox(self.controlArea, "Classifier/Learner Name")
+        OWGUI.lineEdit(box, self, "name")
+
+        box = OWGUI.widgetBox(self.controlArea, "Binarization")
+        OWGUI.radioButtonsInBox(box, self, "binarization", self.BINARIZATION)
+
+        box = OWGUI.widgetBox(self.controlArea, "Pre-Pruning")
+
+        OWGUI.checkWithSpin(box, self, "Min. instances in leaves",
+                            1, 1000,
+                            "use_min_subset", "min_subset"
+                            )
+
+        OWGUI.checkWithSpin(box, self,
+                            "Stop splitting nodes with less instances than",
+                            1, 1000,
+                            "use_min_instances", "min_instances"
+                            )
+
+        OWGUI.checkWithSpin(box, self,
+                            "Stop splitting nodes with a majority class of (%)",
+                            1, 100,
+                            "use_max_majority", "max_majority")
+
+        OWGUI.checkWithSpin(box, self,
+                            "Stop splitting nodes at depth",
+                            1, 1000,
+                            "use_max_depth", "max_depth")
+
+        box = OWGUI.widgetBox(self.controlArea, "Post-Pruning")
+
+        OWGUI.checkBox(box, self, "same_majority_pruning",
+                       "Recursively merge leaves with same majority class")
+
+        OWGUI.checkWithSpin(box, self, "Pruning with m-estimate, m=",
+                            0, 1000,
+                            "use_m_pruning", 'm_pruning')
+
+        OWGUI.button(self.controlArea, self, "&Apply",
+                     callback=self.apply,
+                     tooltip="Create the learner and apply it on input data.",
+                     autoDefault=True
+                     )
+
+        self.data = None
+        self.preprocessor = None
+        self.apply()
+
+    def set_data(self, data=None):
+        self.data = data
+        self.error([0])
+        if data is not None and not data.domain.class_vars:
+            data = None
+            self.error(0, "Input data must have multi target domain.")
+
+        self.data = data
+        self.apply()
+
+    def set_preprocessor(self, preprocessor=None):
+        self.preprocessor = preprocessor
+
+    def apply(self):
+        def choice(name, default=0):
+            if getattr(self, "use_" + name):
+                return getattr(self, name)
+            else:
+                return default
+
+        params = \
+            {"binarization": self.binarization,
+             "min_subset": choice("min_subset"),
+             "min_instances": choice("min_instances"),
+             "max_majority": choice("max_majority"),
+             "max_depth": choice("max_depth"),
+             "same_majority_pruning": self.same_majority_pruning,
+             "m_pruning": choice("m_pruning"),
+             "name": self.name,
+             }
+
+        learner = Orange.multitarget.tree.MultiTreeLearner(**params)
+
+        if self.preprocessor is not None:
+            learner = self.preprocessor.wrapLearner(learner)
+
+        classifier = None
+        self.error([1])
+        if self.data is not None:
+            try:
+                classifier = learner(self.data)
+                classifier.name = self.name
+            except Exception, ex:
+                self.error(1, str(ex))
+
+        self.send("Learner", learner)
+        self.send("Classifier", classifier)
+
+    def sendReport(self):
+        self.reportSettings(
+            "Parameters",
+            [("Binarization", self.BINARIZATION[self.binarization]),
+             ("Pruning",
+              ", ".join(s for s, c in (
+                             ("%i instances in leaves" % self.min_subset,
+                              self.use_min_subset),
+                             ("%i instance in node" % self.min_instances,
+                              self.use_min_instances),
+                             ("stop on %i%% purity" % self.max_majority,
+                              self.use_max_majority),
+                             ("maximum depth %i" % self.max_depth,
+                              self.use_max_depth))
+                        if c)
+                or "None"
+                ),
+             ("Recursively merge leaves with same majority class",
+              OWGUI.YesNo[self.same_majority_pruning]),
+             ("Pruning with m-estimate",
+              ["No", "m=%i" % self.m_pruning][self.use_m_pruning])]
+        )
+
+        self.reportData(self.data)
+
+
+if __name__ == "__main__":
+    app = QApplication([])
+    w = OWMultiTargetTree()
+    data = Orange.data.Table("multitarget:emotions.tab")
+    w.set_data(data)
+    w.set_data(None)
+    w.set_data(data)
+    w.show()
+    app.exec_()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.