Commits

Jiangge Zhang committed 1ec4d0f

finished the test script.

  • Participants
  • Parent commits 150a03f

Comments (0)

Files changed (3)

 *.pyc
+*.pyo
 *~
 *.sublime-*
 #!/usr/bin/env python
 #-*- coding:utf-8 -*-
 
+from __future__ import division
 from collections import namedtuple
 
 import numpy as np

File testbpann.py

 #!/usr/bin/env python
 #-*- coding:utf-8 -*-
 
+from __future__ import division
 from collections import namedtuple
-from unittest import TestCase, main as runtest
 
 import numpy as np
 
         yield flower(*record)
 
 
-class AnnTestCase(TestCase):
+class AnnTestCase(object):
     """Test Case for ANN."""
 
     neuron_number = NeuronNumber(input=4, hidden=3, output=3)
     data_source = "./test-data.dat"
     error_less_than = 0.0001
     plot_error = True
-    plot_error_step = 1000
+    plot_error_step = 10000
 
-    def setUp(self):
+    def __init__(self):
         # create a neuron net
         self.ann = BackPropagationNetwork(self.neuron_number,
                                           self.learning_rate)
             plot_error(training_result, step=self.plot_error_step)
 
         # prepare test data
-        self.data = data[:115, :]
-        self.result = desired_result[:115, :]
+        self.data = data[115:, :]
+        self.result = desired_result[115:, :]
 
     def test_data(self):
+        error_count = 0
         for index, item in enumerate(self.data):
             desired_result = self.result[index]
             output_activation = self.ann.calculate(item)[1]
-            real_result = np.round_(output_activation.reshape(-1))
-            print desired_result, real_result, output_activation.reshape(-1)
-            self.assertEqual(desired_result, real_result)
+            real_result = np.round_(output_activation[0, ::-1])
+            if not np.all(desired_result == real_result):
+                error_count += 1
+        return error_count, len(self.data)
+
+    def run(self):
+        error_count, iter_count = self.test_data()
+        print("=" * 50)
+        print("Error: %d/%d" % (error_count, iter_count))
+        print("Error Rate: %f%%" % ((error_count / iter_count) * 100))
+        print("=" * 50)
 
 
 if __name__ == "__main__":
-    #runtest()
-    AnnTestCase.__init__ = AnnTestCase.setUp
-    a = AnnTestCase()
-    a.test_data()
+    test = AnnTestCase()
+    test.run()