多層パーセプトロンをPythonで実装してみた-手書き数字認識-
前回(↓)の続き。
簡単な分類で動作を確認できたので次は数字認識をやってみる。
多層パーセプトロンをPythonで実装してみた-1- - 日曜技術者のメモ
で書いたクラスは出力層が二個以上に対応していなかったので修正して使った。
ブログは修正済み。
数字の分類では出力層にソフトマックス関数を使って出力を
[-1,1]ではなく確立にするのが普通の様だが今回は出力層をそのままにして
学習データは正しい数字は0.5、その他の数字は0として学習した。
sklearn.datasets
sklearn.datasetsに手書き文字があったのでまずはそれでやってみる。
The Digit Dataset — scikit-learn 0.18 documentation
8x8の数字なのでかなり荒い。
入力層は8pix×8pix=64個
出力層は0から9の10個
学習率0.03
600文字学習を繰り返す。
実装コード
def digit_sepalate(): #テストデータ作成 iteration_num = 100000 sample_num = 600 digits = load_digits() x = digits.data y = digits.target x_max = x.max() x /= x.max() x *= 0.5 iLayer = nn.inputLayer(64,10,10) oLayer = nn.outputLayer(64,10,1,0.05) inputLayerY = np.zeros(64) dut_output = np.zeros(10) x1_array = np.arange(0,8,1) x2_array = np.arange(0,8,1) pos_x1,pos_x2 = np.meshgrid(x1_array,x2_array) val_y = np.identity(40) plt.ion() match_list = [] cnt_list = [] shuffle_list = list(range(sample_num)) plt.pause(0.05) for i in range(iteration_num): match = 0 for j in range(sample_num): inputLayerY = x[j] iLayer.foward(inputLayerY) oLayer.foward(iLayer.getY()) exp_y = np.zeros(10) - 1 dut_output = oLayer.getY() exp_y[y[j]] = 1.0 oLayer.backprop(exp_y) oLayer.updateWeight() if(dut_output.argmax() == digits.target[j]): match += 1 #次の学習時に順番をシャッフルする random.shuffle(shuffle_list) if (1): plt.clf() match_list.append(match/sample_num) cnt_list.append(i) for i in range(10): plt.subplot(3,5,6+i) plt.xlim(0,7) plt.ylim(0,7) plt.pcolor(pos_x1,pos_x2,oLayer._weight[i].reshape(8,8)) plt.subplot(3,1,1) plt.plot(cnt_list,match_list) plt.pause(0.05) while True: plt.pause(0.05)
↓は実行結果
10個のヒートマップは各出力層の重み値を出力している。
うまくいけば数字に見えるが・・・微妙。
数回の学習でかなりの精度になった。
MNIST
次は有名なMNISTでやってみる。
データは以下にあるがsklearnで関数が容易されているのでそっちを使った。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
入力層は28pix×28pix=784個
出力層は0から9の10個
学習率0.03
1000文字学習を繰り返す。
↓は実行結果
10回の学習で正解率91%!!
幾ら何でも学習速すぎないかと思っていたが
これは学習した数字で学習率を見てるので速いのは当たり前だった
次は学習する数字と正解率を評価する数字を別々にしてみる。
正解率は100文字中の正解数
1000文字学習
5000文字学習
10000文字学習
50000文字学習
さっきよりは学習に時間がかかっているが問題なく分類できてる
重みのヒートマップも0,3,8はそれっぽい。