Commits

Jiangge Zhang  committed 150a03f

leave a problem.

  • Participants
  • Parent commits 2b4ee8f

Comments (0)

Files changed (1)

File testbpann.py

 #-*- coding:utf-8 -*-
 
 from collections import namedtuple
+from unittest import TestCase, main as runtest
 
 import numpy as np
 
         yield flower(*record)
 
 
-def train_network(plot=False):
-    # create a neuron net
-    ann = BackPropagationNetwork(NeuronNumber(input=4, hidden=3, output=3),
-                                 LearningRate(0.1, 0.1, 0.85))
+class AnnTestCase(TestCase):
+    """Test Case for ANN."""
 
-    # load and preprocess train data
-    raw_data = load_and_preprocess_data("./test-data.dat")
-    raw_data = np.array(list(raw_data), dtype=np.double)
-    train_data = raw_data[:, 1:]
-    desired_result = np.array([[(1 if i == typeid else 0) for i in range(3)]
-                               for typeid in raw_data[:, 0]], dtype=np.double)
+    neuron_number = NeuronNumber(input=4, hidden=3, output=3)
+    learning_rate = LearningRate(0.1, 0.1, 0.85)
+    data_source = "./test-data.dat"
+    error_less_than = 0.0001
+    plot_error = True
+    plot_error_step = 1000
 
-    # train with 115 records
-    training = ann.train_until(data=train_data[:115, :],
-                               desired_result=desired_result[:115, :],
-                               error_less_than=0.0005)
-    training_result = list(training)
+    def setUp(self):
+        # create a neuron net
+        self.ann = BackPropagationNetwork(self.neuron_number,
+                                          self.learning_rate)
 
-    # plot error
-    if plot:
-        plot_error(training_result, step=10)
+        # load and preprocess train data
+        raw_data = load_and_preprocess_data(self.data_source)
+        raw_data = np.array(list(raw_data), dtype=np.float64)
+        data = raw_data[:, 1:]
+        desired_result = np.array([[(1 if i == t else 0) for i in range(3)]
+                                   for t in raw_data[:, 0]], dtype=np.float64)
 
-    # ! TODO: write a unit test
+        # train with 115 records
+        training_data = data[:115, :]
+        training_desired_result = desired_result[:115, :]
+        training = self.ann.train_until(self.error_less_than,
+                                        training_data,
+                                        training_desired_result)
+        training_result = list(training)
 
-    # return a trained network
-    return ann
+        if self.plot_error:
+            plot_error(training_result, step=self.plot_error_step)
+
+        # prepare test data
+        self.data = data[:115, :]
+        self.result = desired_result[:115, :]
+
+    def test_data(self):
+        for index, item in enumerate(self.data):
+            desired_result = self.result[index]
+            output_activation = self.ann.calculate(item)[1]
+            real_result = np.round_(output_activation.reshape(-1))
+            print desired_result, real_result, output_activation.reshape(-1)
+            self.assertEqual(desired_result, real_result)
 
 
 if __name__ == "__main__":
-    train_network(plot=True)
+    #runtest()
+    AnnTestCase.__init__ = AnnTestCase.setUp
+    a = AnnTestCase()
+    a.test_data()