Commits

Aleš Erjavec committed 0453b7a

Added back single tree output using the simple tree converter.

Comments (0)

Files changed (2)

Orange/OrangeWidgets/Classify/OWRandomForest.py

                         self.setPreprocessor)]
 
         self.outputs = [("Learner", orange.Learner),
-                        ("Random Forest Classifier", orange.Classifier)]
+                        ("Random Forest Classifier", orange.Classifier),
+                        ("Selected Tree", Orange.classification.tree.TreeClassifier)]
 
         self.name = 'Random Forest'
         self.trees = 10
         self.limitDepth = 0
         self.limitDepthP = 3
         self.rseed = 0
+        self.outtree = 0
 
         self.maxTrees = 10000
 
 
         OWGUI.separator(self.controlArea)
 
+        self.streesBox = OWGUI.spin(self.controlArea, self, "outtree", -1,
+                                    self.maxTrees,
+                                    orientation="horizontal",
+                                    label="Index of tree on the output",
+                                    callback=[self.period, self.extree])
+        self.streeEnabled(False)
+
         OWGUI.separator(self.controlArea)
 
         self.btnApply = OWGUI.button(self.controlArea, self,
         self.data = data
 
         #self.setLearner()
-
+        self.streeEnabled(False)
         if self.data:
             learner = self.constructLearner()
             self.progressBarInit()
             learner.callback = lambda v: self.progressBarSet(100.0 * v)
             try:
                 self.classifier = learner(self.data)
+                self.streeEnabled(True)
                 self.classifier.name = self.name
             except Exception, (errValue):
                 self.error(str(errValue))
         self.setLearner()
         self.setData(self.data)
 
+    def period(self):
+        if self.outtree == -1:
+            self.outtree = self.claTrees - 1
+        elif self.outtree >= self.claTrees:
+            self.outtree = 0
+
+    def extree(self):
+        stc = self.classifier.classifiers[self.outtree]
+        if self.preprocessor:
+            # TODO: get the transformed data at learning step from the
+            # wrapped learner (or at least cache it here)
+            train_data = self.data.translate(self.classifier.domain)
+        else:
+            train_data = self.data
+
+        # Replay the bootstrap sampling as done by RandomForestLearner
+        rand = random.Random(self.claSeed)
+        n = len(train_data)
+        selection = [rand.randrange(n)
+                     for _ in range((self.outtree + 1) * n)]
+        # need the last n samples
+        selection = selection[-n:]
+        train_data = train_data.get_items_ref(selection)
+
+        tree = Orange.classification.tree._simple_tree_convert(
+            stc, self.classifier.domain, train_data)
+
+        self.send("Selected Tree", tree)
+
+    def streeEnabled(self, status):
+        if status:
+            self.claTrees = self.trees
+            self.claSeed = self.rseed
+            self.streesBox.setDisabled(False)
+            self.period()
+            self.extree()
+        else:
+            self.streesBox.setDisabled(True)
+
 
 if __name__ == "__main__":
     a = QApplication(sys.argv)

Orange/OrangeWidgets/Regression/OWRandomForestRegression.py

 <icon>icons/RandomForestRegression.svg</icon>
 <contact>Marko Toplak (marko.toplak(@at@)gmail.com)</contact>
 <priority>320</priority>
-<keywords>bagging, ensemble</keywords>
+<tags>bagging,ensemble</tags>
 
 """
 
                        ("Preprocess", PreprocessedLearner, self.setPreprocessor)]
 
         self.outputs = [("Learner", orange.Learner),
-                        ("Random Forest Classifier", orange.Classifier)]
+                        ("Random Forest Classifier", orange.Classifier),
+                        ("Selected Tree", Orange.classification.tree.TreeClassifier)]
 
     def setData(self, data):
-        self.data = self.isDataWithClass(data, orange.VarTypes.Continuous, checkMissing=True) and data or None
-        
+        if not self.isDataWithClass(data, orange.VarTypes.Continuous,
+                                    checkMissing=True):
+            data = None
+        self.data = data
+
+        self.streeEnabled(False)
         if self.data:
             learner = self.constructLearner()
             self.progressBarInit()
             try:
                 self.classifier = learner(self.data)
                 self.classifier.name = self.name
+                self.streeEnabled(True)
             except Exception, (errValue):
                 self.error(str(errValue))
                 self.classifier = None