draft-ann / testbpann.py

#!/usr/bin/env python
#-*- coding:utf-8 -*-

from __future__ import division
from collections import namedtuple

import numpy as np

from bpann import BackPropagationNetwork, NeuronNumber, LearningRate
from bpann import plot_error


def load_and_preprocess_data(filepath):
    """Yield one ``flower`` record per non-blank line of a whitespace-separated file.

    The first column of every line is discarded. The remaining five columns
    are parsed with ``int`` and mapped, in order, to the fields
    ``typeid, a, b, c, d``.

    Args:
        filepath: path of the text data file to read.

    Yields:
        flower: namedtuple with fields ``typeid, a, b, c, d`` (all ``int``).

    Raises:
        ValueError: if a column is not an integer literal.
        TypeError: if a non-blank line does not have exactly six columns.
    """
    flower = namedtuple("flower", ["typeid", "a", "b", "c", "d"])
    # Read everything up front so the file is closed before the first yield.
    with open(filepath, "r") as datafile:
        lines = datafile.readlines()
    for line in lines:
        # split() with no argument tolerates runs of spaces and tabs,
        # unlike split(" ") which produces empty tokens for them.
        fields = line.split()
        if not fields:
            # Skip blank lines (e.g. trailing empty lines) instead of
            # failing on an argument-less flower().
            continue
        yield flower(*(int(item) for item in fields[1:]))


class AnnTestCase(object):
    """Train a BackPropagationNetwork on a data file and score it on a hold-out split.

    Construction loads ``data_source``, trains the network on the first
    ``training_size`` records and keeps the remaining records as the test
    set. ``run()`` evaluates the test set and prints the error rate.
    Override the class attributes in a subclass to change the configuration.
    """

    neuron_number = NeuronNumber(input=4, hidden=3, output=3)
    learning_rate = LearningRate(0.1, 0.1, 0.85)
    data_source = "./test-data.dat"
    # Threshold passed as the first argument to BackPropagationNetwork.train_until.
    error_less_than = 0.0001
    # Whether to plot the training result, and the step handed to plot_error().
    plot_error = True
    plot_error_step = 10000
    # Number of leading records used for training; the rest form the test set.
    training_size = 115

    def __init__(self):
        """Build the network, train it, and keep the held-out records."""
        # create a neuron net
        self.ann = BackPropagationNetwork(self.neuron_number,
                                          self.learning_rate)

        # Load all records as float64: column 0 is the class id,
        # columns 1: are the input features.
        raw_data = np.array(list(load_and_preprocess_data(self.data_source)),
                            dtype=np.float64)
        data = raw_data[:, 1:]
        # One-hot encode the class id over 3 columns; an id outside 0..2
        # produces an all-zero row.
        desired_result = np.array([[(1 if i == t else 0) for i in range(3)]
                                   for t in raw_data[:, 0]], dtype=np.float64)

        # Train on the leading `training_size` records.
        split = self.training_size
        training = self.ann.train_until(self.error_less_than,
                                        data[:split, :],
                                        desired_result[:split, :])
        # Materialize unconditionally: if train_until is lazy (e.g. a
        # generator), consuming it is what actually drives training.
        # TODO(review): confirm against bpann.
        training_result = list(training)

        if self.plot_error:
            # Inside a method this name resolves to the module-level
            # plot_error function, not the boolean class attribute.
            plot_error(training_result, step=self.plot_error_step)

        # Keep the remaining records as the test set.
        self.data = data[split:, :]
        self.result = desired_result[split:, :]

    def test_data(self):
        """Count misclassified records in the test set.

        Returns:
            tuple: ``(error_count, total)`` where ``total`` is the number of
            test records.
        """
        error_count = 0
        for item, desired_result in zip(self.data, self.result):
            # Element [1] of calculate()'s result is used as the output activation.
            output_activation = self.ann.calculate(item)[1]
            # NOTE(review): the first output row is reversed before comparing,
            # presumably because the network emits classes in reverse order;
            # confirm against bpann. np.round replaces np.round_, which was
            # removed in NumPy 2.0.
            real_result = np.round(output_activation[0, ::-1])
            if not np.all(desired_result == real_result):
                error_count += 1
        return error_count, len(self.data)

    def run(self):
        """Evaluate the test set and print the error count and error rate."""
        error_count, iter_count = self.test_data()
        # Guard against an empty test set (data file has <= training_size
        # records), which would otherwise raise ZeroDivisionError.
        error_rate = (error_count / iter_count) * 100 if iter_count else 0.0
        print("=" * 50)
        print("Error: %d/%d" % (error_count, iter_count))
        print("Error Rate: %f%%" % error_rate)
        print("=" * 50)


if __name__ == "__main__":
    # Constructing the case loads data and trains the network;
    # run() then scores the held-out records and prints the report.
    AnnTestCase().run()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.