Jure Žbontar avatar Jure Žbontar committed 05d6847

Add normalize flag for NNs

Comments (0)

Files changed (1)

Orange/classification/neural.py

 
 class NeuralNetworkLearner(Orange.classification.Learner):
     """
-    NeuralNetworkLearner uses jzbontar's implementation of neural networks and wraps it in
-    an Orange compatible learner. 
+    NeuralNetworkLearner implements a multilayer perceptron. Learning is performed by minimizing an L2-regularized
+    cost function with scipy's implementation of L-BFGS. The current implementations is limited to a single
+    hidden layer. 
 
-    NeuralNetworkLearner supports all types of data and returns a classifier, regression is currently not supported.
-
-    More information about neural networks can be found at http://en.wikipedia.org/wiki/Artificial_neural_network.
+    Regression is currently not supported.
 
     :param name: learner name.
     :type name: string
 
     :param n_mid: Number of nodes in the hidden layer
-    :type n_mid: integer
+    :type n_mid: int
 
     :param reg_fact: Regularization factor.
     :type reg_fact: float
 
     :param max_iter: Maximum number of iterations.
-    :type max_iter: integer
+    :type max_iter: int
+
+    :param normalize: Normalize the data prior to learning (subtract each column by the mean and divide by the standard deviation)
+    :type normalize: bool
 
     :rtype: :class:`Orange.multitarget.neural.neuralNetworkLearner` or 
             :class:`Orange.multitarget.chain.NeuralNetworkClassifier`
             self.__init__(**kwargs)
             return self(data,weight)
 
-    def __init__(self, name="NeuralNetwork", n_mid=10, reg_fact=1, max_iter=1000, rand=None):
+    def __init__(self, name="NeuralNetwork", n_mid=10, reg_fact=1, max_iter=300, normalize=True, rand=None):
         """
         Current default values are the same as in the original implementation (neural_networks.py)
         """
-
         self.name = name
         self.n_mid = n_mid
         self.reg_fact = reg_fact
         self.max_iter = max_iter
         self.rand = rand
+        self.normalize = normalize
 
         if not self.rand:
             self.rand = random.Random(42)
         #converts attribute data
         X = data.to_numpy()[0] 
 
+        mean = X.mean(axis=0)
+        std = X.std(axis=0)
+        if self.normalize:
+            X = (X - mean) / std
+
         #converts multi-target or single-target classes to numpy
-
-
         if data.domain.class_vars:
             for cv in data.domain.class_vars:
                 if cv.var_type == Orange.feature.Continuous:
         
         self.nn.fit(X,Y)
                
-        return NeuralNetworkClassifier(classifier=self.nn.predict, domain = data.domain)
+        return NeuralNetworkClassifier(classifier=self.nn.predict,
+            domain=data.domain, normalize=self.normalize, mean=mean, std=std)
 
 class NeuralNetworkClassifier():
     """    
         if not self.domain.class_vars: example = [example[i] for i in range(len(example)-1)]
         input = np.array([[float(e) for e in example]])
 
+        if self.normalize:
+            input = (input - self.mean) / self.std
+
         # transform results from numpy
         results = self.classifier(input).tolist()[0]
+        if len(results) == 1:
+            prob_positive = results[0]
+            results = [1 - prob_positive, prob_positive]
         mt_prob = []
         mt_value = []
           
     print "STARTED"
     global_timer = time.time()
 
-    data = Orange.data.Table('iris')
-    l1 = NeuralNetworkLearner(n_mid=10, reg_fact=1, max_iter=1000)
-    res = Orange.evaluation.testing.cross_validation([l1],data, 3)
+    data = Orange.data.Table('wine')
+    l1 = NeuralNetworkLearner(n_mid=40, reg_fact=1, max_iter=200)
+
+#    c1 = l1(data)
+#    print c1(data[0], 3), data[0]
+
+    l2 = Orange.classification.bayes.NaiveLearner()
+    res = Orange.evaluation.testing.cross_validation([l1, l2],data, 5)
     scores = Orange.evaluation.scoring.CA(res)
-
     for i in range(len(scores)):
         print res.classifierNames[i], scores[i]
 
-    print "--DONE %.2f --" % (time.time()-global_timer)
+    print "--DONE %.2f --" % (time.time()-global_timer)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.