Commits

Jiangge Zhang committed 2b4ee8f

add test module.

Comments (0)

Files changed (3)

+*.pyc
+*~
+*.sublime-*
 from collections import namedtuple
 
 import numpy as np
-import matplotlib.pyplot as plt
+try:
+    import matplotlib.pyplot as plt
+except ImportError:
+    plt = None
 
 
 NeuronNumber = namedtuple("NeuronNumber", ["input", "hidden", "output"])
         data = ((data - min_values) / (max_values - min_values) - 0.5) * 2
 
         # iterate training network
-        alpha, beta, gamma = self.learning_rate  
+        alpha, beta, gamma = self.learning_rate
         evaluate_error = 0
         for index, item in enumerate(data):
             # desired result of current item
         return hidden_activation, output_activation
 
 
-# -------------
-# Training Case
-# -------------
-
-def load_and_preprocess_data(filepath):
-    """Load data from a text format file."""
-    flower = namedtuple("flower", ["typeid", "a", "b", "c", "d"])
-    with open(filepath, "r") as datafile:
-        dataset = datafile.readlines()
-    for record in dataset:
-        record = record.strip().split(" ")[1:]
-        record = (int(item) for item in record)
-        yield flower(*record)
-
-
-def main():
-    # create a neuron net
-    bpnet = BackPropagationNetwork(NeuronNumber(input=4, hidden=3, output=3),
-                                   LearningRate(0.1, 0.1, 0.85))
-
-    # load and preprocess train data
-    raw_data = load_and_preprocess_data("./test-data.dat")
-    raw_data = np.array(list(raw_data), dtype=np.double)
-    train_data = raw_data[:, 1:]
-    desired_result = np.array([[(1 if i == typeid else 0) for i in range(3)]
-                               for typeid in raw_data[:, 0]], dtype=np.double)
-
-    # train
-    training = bpnet.train_until(data=train_data[:115, :],
-                                 desired_result=desired_result[:115, :],
-                                 error_less_than=0.0005)
-    training_result = list(training)
-    
-    x = np.arange(len(training_result), step=10)
-    y = np.array([training_result[ix] for ix in x])
-    plt.figure(figsize=(8, 4))
+def plot_error(training_error, step=10, title=""):
+    """Plot the error with matplotlib"""
+    assert plt, "You should install the matplotlib."
+    x = np.arange(len(training_error), step=10)
+    y = np.array([training_error[ix] for ix in x])
+    plt.figure()
     plt.plot(x, y, label="error (x)", color="blue")
     plt.xlabel("training times")
-    plt.ylabel("error")
-    plt.title("BP Demo")
+    plt.ylabel("training error")
+    plt.title(title)
     plt.legend()
     plt.show()
-
-
-if __name__ == "__main__":
-    main()
+#!/usr/bin/env python
+#-*- coding:utf-8 -*-
+
+from collections import namedtuple
+
+import numpy as np
+
+from bpann import BackPropagationNetwork, NeuronNumber, LearningRate
+from bpann import plot_error
+
+
+def load_and_preprocess_data(filepath):
+    """Load data from a text format file."""
+    flower = namedtuple("flower", ["typeid", "a", "b", "c", "d"])
+    with open(filepath, "r") as datafile:
+        dataset = datafile.readlines()
+    for record in dataset:
+        record = record.strip().split(" ")[1:]
+        record = (int(item) for item in record)
+        yield flower(*record)
+
+
+def train_network(plot=False):
+    # create a neuron net
+    ann = BackPropagationNetwork(NeuronNumber(input=4, hidden=3, output=3),
+                                 LearningRate(0.1, 0.1, 0.85))
+
+    # load and preprocess train data
+    raw_data = load_and_preprocess_data("./test-data.dat")
+    raw_data = np.array(list(raw_data), dtype=np.double)
+    train_data = raw_data[:, 1:]
+    desired_result = np.array([[(1 if i == typeid else 0) for i in range(3)]
+                               for typeid in raw_data[:, 0]], dtype=np.double)
+
+    # train with 115 records
+    training = ann.train_until(data=train_data[:115, :],
+                               desired_result=desired_result[:115, :],
+                               error_less_than=0.0005)
+    training_result = list(training)
+
+    # plot error
+    if plot:
+        plot_error(training_result, step=10)
+
+    # ! TODO: write a unit test
+
+    # return a trained network
+    return ann
+
+
+if __name__ == "__main__":
+    train_network(plot=True)