Aleš Erjavec avatar Aleš Erjavec committed 32f2571

Added Neural Network widget

Comments (0)

Files changed (3)

_multitarget/widgets/OWNeuralNetwork.py

+"""
+<name>Neural Network</name>
+<description>Neural network learner/classifier supporting multi target problems.</description>
+<category>Multi Target</category>
+<priority>50<priority>
+<tags>neural,network,multitarget</tags>
+
+
+"""
+
+import Orange
+import Orange.multitarget
+from orngWrap import PreprocessedLearner
+
+from OWWidget import *
+import OWGUI
+
+class OWNeuralNetwork(OWWidget):
+    settingsList = ["name", "n_mid", "reg_fact", "max_iter"]
+
+    def __init__(self, parent=None, signalManager=None,
+                 title="Neural Network"):
+        OWWidget.__init__(self, parent, signalManager, title,
+                          wantMainArea=False)
+
+        self.inputs = [("Data", Orange.data.Table, self.set_data),
+                       ("Preprocess", PreprocessedLearner,
+                        self.set_preprocessor)
+                       ]
+        self.outputs = [("Learner", Orange.classification.Learner),
+                        ("Classifier", Orange.classification.Classifier)
+                        ]
+
+        self.name = "Neural Network"
+        self.n_mid = 20
+        self.reg_fact = 1
+        self.max_iter = 1000
+
+        self.loadSettings()
+
+        box = OWGUI.widgetBox(self.controlArea, "Name", addSpace=True)
+        OWGUI.lineEdit(box, self, "name")
+
+        box = OWGUI.widgetBox(self.controlArea, "Settings", addSpace=True)
+        OWGUI.spin(box, self, "n_mid", 2, 10000, 1,
+                   label="Hidden layer neurons",
+                   tooltip="Number of neurons in the hidden layer."
+                   )
+
+        OWGUI.doubleSpin(box, self, "reg_fact", 0.1, 10.0, 0.1,
+                         label="Regularization factor",
+                         )
+
+        OWGUI.spin(box, self, "max_iter", 100, 10000, 1,
+                   label="Max iterations",
+                   tooltip="Maximal number of optimization iterations."
+                   )
+
+        OWGUI.button(self.controlArea, self, "&Apply",
+                     callback=self.apply,
+                     tooltip="Create the learner and apply it on input data.",
+                     autoDefault=True
+                     )
+
+        self.data = None
+        self.preprocessor = None
+        self.apply()
+
+    def set_data(self, data=None):
+        self.data = data
+        self.error([0])
+        if data is not None and not data.domain.class_vars:
+            data = None
+            self.error(0, "Input data must have multi target domain.")
+
+        self.data = data
+        self.apply()
+
+    def set_preprocessor(self, preprocessor=None):
+        self.preprocessor = preprocessor
+
+    def apply(self):
+        learner = Orange.multitarget.neural.NeuralNetworkLearner(
+            name=self.name, n_mid=self.n_mid,
+            reg_fact=self.reg_fact, max_iter=self.max_iter
+        )
+
+        if self.preprocessor is not None:
+            learner = self.preprocessor.wrapLearner(learner)
+
+        classifier = None
+        self.error([1])
+        if self.data is not None:
+            try:
+                classifier = learner(self.data)
+                classifier.name = self.name
+            except Exception, ex:
+                self.error(1, str(ex))
+
+        self.send("Learner", learner)
+        self.send("Classifier", classifier)
+
+    def sendReport(self):
+        self.reportSettings("Parameters",
+                            [("Hidden layer neurons", self.n_mid),
+                             ("Regularization factor", self.reg_fact),
+                             ("Max iterations", self.max_iter)]
+                            )
+
+
+if __name__ == "__main__":
+    app = QApplication([])
+    w = OWNeuralNetwork()
+    data = Orange.data.Table("multitarget:emotions.tab")
+    w.set_data(data)
+    w.set_data(None)
+    w.set_data(data)
+    w.show()
+    app.exec_()

_multitarget/widgets/__init__.py

+"""
+Multi-target widgets.
+
+"""
     'orange.addons': (
         'multitarget = _multitarget',
     ),
+    'orange.widgets': (
+        'Multitarget = _multitarget.widgets'
+    ),
 	'orange.data.io.search_paths': (
 		'multitarget = _multitarget:datasets',
 	),
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.