# Snippets
#
# 7shi: [Python] MNISTを認識するNNの画像化
# (Visualizing a neural network that recognizes MNIST digits)
#
# Created by 7shi (last-modified date not preserved)
import pickle
import numpy as np
from PIL import Image

def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x)), applied elementwise to *x*."""
    denom = 1 + np.exp(-x)
    return 1 / denom

def softmax(x):
    """Numerically stable softmax taken over all elements of *x*.

    The global maximum is subtracted before exponentiating so np.exp cannot
    overflow; softmax is shift-invariant, so the result is unchanged.
    """
    shifted = x - np.max(x)
    weights = np.exp(shifted)
    total = np.sum(weights)
    return weights / total

def predict(x):
    """Run the three-layer network forward on input *x*.

    Uses the module-level parameters W1, W2, W3 and b1, b2, b3.

    Returns:
        (z1, z2, a3, z3): the sigmoid activations of the two hidden layers,
        the raw output-layer scores, and their softmax probabilities.
    """
    hidden1 = sigmoid(np.dot(x, W1) + b1)
    hidden2 = sigmoid(np.dot(hidden1, W2) + b2)
    logits = np.dot(hidden2, W3) + b3
    probs = softmax(logits)
    return hidden1, hidden2, logits, probs

def toimg(img, imin, imax, shape=(28, 28)):
    """Linearly map *img* from [imin, imax] onto 8-bit grey levels [0, 255].

    Passing imin > imax inverts the mapping (imin -> 255, imax -> 0).
    Values outside the range are clamped to 0/255 instead of wrapping in
    the uint8 cast. A degenerate range (imin == imax) carries no contrast.
    It is rendered as uniform mid-grey (127, where 0 lands on a symmetric
    [-m, m] scale) instead of dividing by zero.

    Args:
        img: array-like with prod(shape) elements (a flat 784-vector by default).
        imin: input value mapped to 0.
        imax: input value mapped to 255.
        shape: output shape; defaults to the 28x28 MNIST layout.

    Returns:
        uint8 array of the given shape.
    """
    # Compute in float64 so integer (e.g. uint8) input cannot overflow in
    # the "* 255" step.
    img = np.asarray(img, dtype=np.float64)
    if imax == imin:
        return np.full(shape, 127, dtype=np.uint8)
    scaled = (img - imin) * 255 / (imax - imin)
    # Truncate like the plain uint8 cast did, but clamp first: casting
    # out-of-range floats to uint8 wraps around or is undefined.
    return np.clip(scaled, 0, 255).astype(np.uint8).reshape(shape)

def amax(imgs):
    """Return the largest absolute value found in any of *imgs*.

    *imgs* is an iterable of arrays, e.g. the rows of a 2-D array; the rows
    may differ in length. toimgs uses the result as the half-width of a
    symmetric [-amax, amax] grey scale.

    Raises:
        ValueError: if *imgs* or any of its rows is empty.
    """
    # np.max/np.abs reduce each row in C; the builtin max/abs walked every
    # element in Python.
    return max(np.max(np.abs(img)) for img in imgs)

def concat(imgs, w, h):
    """Tile a row-major sequence of 2-D arrays into a w x h grid image.

    Each grid column is as wide as its widest tile and each grid row as
    tall as its tallest tile. Smaller tiles are centred in their cell.
    Adjacent cells are separated by a 1-pixel gap of the white RGB
    background.

    Returns:
        PIL RGB image of size (sum of column widths + w - 1,
        sum of row heights + h - 1).
    """
    def cell(col, row):
        return imgs[col + row * w]

    col_widths = [max((cell(c, r).shape[1] for r in range(h)), default=0)
                  for c in range(w)]
    row_heights = [max((cell(c, r).shape[0] for c in range(w)), default=0)
                   for r in range(h)]
    canvas_size = (sum(col_widths) + w - 1, sum(row_heights) + h - 1)
    canvas = Image.new("RGB", canvas_size, "white")
    top = 0
    for r, cell_h in enumerate(row_heights):
        left = 0
        for c, cell_w in enumerate(col_widths):
            tile = cell(c, r)
            off_x = (cell_w - tile.shape[1]) // 2
            off_y = (cell_h - tile.shape[0]) // 2
            canvas.paste(Image.fromarray(tile), (left + off_x, top + off_y))
            left += cell_w + 1
        top += cell_h + 1
    return canvas

def hist(img):
    """Render the intensity histogram of *img* as a 28x28 bar-chart image.

    Pixel values, expected in 0..255 (e.g. a raw uint8 image), are split
    into 28 equal-width bins. Each bin becomes one image column. In that
    column a black bar of ceil(count / 28) pixels rises from the bottom
    on a white background.

    Returns:
        28x28 uint8 image (via toimg).
    """
    # Work in float64: multiplying a uint8 scalar by 28 overflows under
    # NumPy 2 promotion rules (NEP 50), which scrambled the bin indices.
    vals = np.asarray(img, dtype=np.float64).ravel()
    # Truncate toward zero like int(). Then clamp, so out-of-range values
    # land in an end bin instead of raising IndexError (>= 256) or wrapping
    # into the last bin through a negative index (< 0).
    idx = np.clip((vals * 28 / 256).astype(int), 0, 27)
    counts = np.bincount(idx, minlength=28)
    heights = (counts + 27) // 28  # ceil(count / 28) pixels per column
    rows = np.arange(28)[:, np.newaxis]
    # 1 (white) above each bar, 0 (black) for the bar itself.
    h = (rows < 28 - heights).astype(np.float64)
    return toimg(h, 0, 1)

def toimgs(imgs, w, h):
    """Render every row of *imgs* as a grey tile and lay the tiles out w x h.

    All tiles share one symmetric scale [-m, m], where m is the largest
    absolute value across *imgs*. Equal values therefore get the same grey
    in every tile, so the tiles are directly comparable.
    """
    limit = amax(imgs)
    tiles = []
    for row in imgs:
        tiles.append(toimg(row, -limit, limit))
    return concat(tiles, w, h)

# Trained network parameters (dict keyed "W1".."b3").
# NOTE(review): pickle.load can execute arbitrary code; only load these
# files from a trusted source.
with open("sample_weight.pkl", "rb") as f:
    network = pickle.load(f)

# Dataset dict; only its "test_label" and "test_img" entries are read here.
with open("mnist.pkl", "rb") as f:
    mnist = pickle.load(f)

# (784, 50), (50, 100), (100, 10)
W1, W2, W3, b1, b2, b3 = (
    network["W1"], network["W2"], network["W3"],
    network["b1"], network["b2"], network["b3"])

# (10000,), (10000, 784)
test_label, test_img = (
    mnist["test_label"], mnist["test_img"])

# Text glyphs used by chimg, each 28 rows x 4 columns:
#   sp: all white (space).
#   pt: white with a black 4x4 square in rows 20-23 (decimal point).
sp = np.uint8([[255] * 4] * 28)
pt = np.uint8([[255] * 4] * 20 + [[0,0,0,0]] * 4 + [[255] * 4] * 4)
# Pick one test sample per digit 0-9:
#   num[d]: the image divided by 255 (presumably raw 0-255 pixels, giving
#           [0, 1]).
#   nimg[d]: an inverted 28x28 rendering (high values draw dark).
# The scan index j never rewinds, so digit d's sample is the first match
# that comes after digit d-1's sample. If there is none, test_label[j]
# runs past the end and raises IndexError.
num, nimg = [], []
i, j = 0, 0
while i < 10:
    if i == test_label[j]:
        img = test_img[j] / 255
        num += [img]
        nimg += [toimg(img, 1, 0)]
        i += 1
    j += 1

def chimg(ch):
    """Return the glyph image used to draw character *ch* in a number label.

    "." maps to the dot glyph ``pt``; a digit maps to that digit's sample
    image in ``nimg``; anything else maps to the blank spacer ``sp``.
    """
    if ch == ".":
        return pt
    if not ("0" <= ch <= "9"):
        return sp
    return nimg[int(ch)]

# Network outputs for an all-zero (blank) 784-pixel input. diff() reports
# every per-pixel effect relative to this baseline.
z1_0, z2_0, a3_0, z3_0 = predict(np.zeros(784))

def diff(x):
    """Measure each input pixel's isolated effect on every network output.

    For pixel i, the network is run on an image that is zero everywhere
    except x[i]. The change relative to the all-zero baseline (z1_0, z2_0,
    a3_0, z3_0) is stored in column i of the matching result matrix.

    Args:
        x: 1-D input image.

    Returns:
        (z1s, z2s, a3s, z3s): arrays of shape (layer_size, x.size), one for
        each value returned by predict.
    """
    n_pixels = x.size
    baselines = (z1_0, z2_0, a3_0, z3_0)
    results = tuple(np.zeros((base.size, n_pixels)) for base in baselines)
    for i in range(n_pixels):
        if x[i] == 0:
            # An all-zero image reproduces the baseline exactly, so the
            # column's difference is 0 and is already in place.
            continue
        img = np.zeros(n_pixels)
        img[i] = x[i]
        for out, layer, base in zip(results, predict(img), baselines):
            out[:, i] = layer - base
    return results

def diffimg(n, x):
    """Render the per-pixel effect analysis of input *x* and save it as "<n>.png".

    Panels, left to right:
      1. the input image;
      2. per-pixel effect maps for the first hidden layer (5x10 grid);
      3. the same for the second hidden layer (10x10 grid);
      4. the same for the output scores (1x10);
      5. the digit sample images 0-9 as row labels;
      6. the per-pixel effect maps for the softmax probabilities (1x10);
      7. each predicted probability printed as text.
    Panels are separated by blank spacer strips where ``sp`` is inserted.
    """
    z1s, z2s, a3s, z3s = diff(x)
    probs = predict(x)[3]  # only the softmax output is displayed
    # "%0.5f" of a probability in [0, 1] is always 7 characters wide,
    # which is why the text grid below is 7 glyphs across.
    text = [chimg(ch) for v in probs for ch in "%0.5f" % v]
    imgs = [
        toimg(x, 0, 1), sp,
        np.asarray(toimgs(z1s,  5, 10)), sp,
        np.asarray(toimgs(z2s, 10, 10)), sp,
        np.asarray(toimgs(a3s,  1, 10)),
        np.asarray(concat(nimg[:10], 1, 10)),
        np.asarray(toimgs(z3s,  1, 10)),
        np.asarray(concat(text, 7, probs.size)),
    ]
    concat(imgs, len(imgs), 1).save(n + ".png")

# Analyse a fully white image first, then one sample of each digit.
diffimg("white", np.ones(784))
for digit, sample in enumerate(num):
    diffimg(str(digit), sample)

# Comments (0)