NeuralNetwork_Discriminant

5031 days ago by takepwave

Hiroshi TAKEMOTO (take@pwv.co.jp)

第5章-多層パーセプトロン分類問題をSageで試す

PRMLの第5章-多層パーセプトロン分類問題の例題、図5.4 (赤が多層パーセプトロンの結果、緑が最適な分類結果)

をSageを使って試してみます。

データの準備

上巻の付録Aにあるように http://research.microsoft.com/~cmbishop/PRML から 例題のデータをダウンロードすることができます。

ここでは人工的に作成された分類問題用のデータclassfication.txtを使用します。

データの読み込みは、numpyのloadtxtを使い、sageのmatrixに変換します。

データの分布は(0を赤、1を青でプロット)します。

# ニューラルネットワークで識別を行う from numpy import loadtxt data = loadtxt(DATA+"classification.txt") # 1,2カラムはx, y、3カラムには、0:red cosses×, 1:blue circle ○がセットされている # これをx, yと識別子に変換する X = [[x[0], x[1]] for x in data[:,0:2]] T = data[:,2] 
       
data_plt = Graphics() for i in range(len(T)): if T[i] == 0: data_plt += point(X[i], rgbcolor='red') else: data_plt += point(X[i], rgbcolor='blue') data_plt.show() 
       

計算条件

分類問題に使用するニューラルネットワーク(多層パーセプトロン)は、

  • 入力ユニット数2個(バイアス項を除く)
  • 隠れ層ユニット2個(バイアス項を除く)、活性化関数はtanh
  • 出力ユニット1個、活性化関数はロジスティックシグモイド
  • $$ \sigma(a) = \frac{1}{1 + exp(-a)} $$
としました。

# 入力1点、隠れ層3個、出力1点 N_in = 2 N_h = 6 N_out = 1 # 隠れ層の活性化関数 h_1(x) = tanh(x) # 出力層の活性化関数 h_2(x) = 1/(1 + e^(-x)) 
       

プログラムのロード

SCG_Classで 使った、 ニューラルネットワーク(NeuralNetwork.sage)、スケール共役勾配法(SCG.sage) をプログラムとして読み込みます。

# ニューラルネットワーク用のクラスのロード attach "NeuralNetwork.sage" # Scaled Conjugate Gradients(スケール共役勾配法)のロード attach "SCG.sage" 
       
# 出力関数 def _f(x, net): y = net.feedForward(x) return y[0] 
       

人工データの識別(分類問題を解く)

人工データの識別はとても簡単です。共役勾配法(SCG_Class)で使用した設定 の内、

  • 出力層の活性化関数をロジスティックシグモイドに変更
  • 初期値の分散を0.5に変更
するだけです。

最終的な収束には、1時間程度かかります。収束の過程を500回毎に表示し、最終的な 識別面と最後に表示します。

# 定数設定 LEARN_COUNT = 500 RESET_COUNT = 25 beta = 0.5 # ニューラルネットワークの生成 net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta) # wの初期値を保存 saved_w = net.getW() # 逐次勾配降下法による学習 _learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT) 
       
2 0.216907574411 18.1512434768
3 0.367274198164 18.0991262322
4 0.342563144731 18.0575836323
途中省略
497 0.510679849338 16.7768761226
498 0.510738410001 16.7760287748
499 0.510796526139 16.775182439
500 0.510854199995 16.7743371123
2 0.216907574411 18.1512434768
3 0.367274198164 18.0991262322
4 0.342563144731 18.0575836323
途中省略
497 0.510679849338 16.7768761226
498 0.510738410001 16.7760287748
499 0.510796526139 16.775182439
500 0.510854199995 16.7743371123
var('x y') f = lambda x, y : _f([x, y], net) # sage 4.6.2の場合、fill=Falseを追加する cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3]) (cnt_plt + data_plt).show() # 500回目の図 
       
saved_w = net.getW(); print saved_w 
       
(0.448813729769, -0.611233027386, 0.486835724148, -0.290585284889,
-1.03630665196, 0.452577662537, -1.14585774724, 0.571240013162,
-1.03099854342, 0.1503080603, 0.817838469072, 0.467411325316,
-0.0277714945845, 0.979474053476, -0.039090385248, 0.702582727597,
-0.880592325533, -0.783731911532, 0.0687552540301, -0.19079101757,
-0.0922401202569, 0.561182751992, 0.512563236679, -0.78366734313,
-0.583498381549, -1.00840952578, 0.726987191536, 0.471177919157)
(0.448813729769, -0.611233027386, 0.486835724148, -0.290585284889, -1.03630665196, 0.452577662537, -1.14585774724, 0.571240013162, -1.03099854342, 0.1503080603, 0.817838469072, 0.467411325316, -0.0277714945845, 0.979474053476, -0.039090385248, 0.702582727597, -0.880592325533, -0.783731911532, 0.0687552540301, -0.19079101757, -0.0922401202569, 0.561182751992, 0.512563236679, -0.78366734313, -0.583498381549, -1.00840952578, 0.726987191536, 0.471177919157)
# ニューラルネットワークの生成 net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta) net.setW(saved_w) # 逐次勾配降下法による学習 _learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT) saved_w = net.getW(); print saved_w 
       
2 0.201388807123 16.0539445534
2 -0.0226802116084 16.0909970491
4 0.336295264963 15.9617385159
途中省略
499 0.45831937529 15.735211345
500 0.458188960582 15.7350770604
(0.463683030527, -0.56293207292, 0.218480621204, -0.31491635651,
-1.27681774227, -0.0716986165936, -1.10320850555, 0.402202160422,
-1.20023849386, 0.12705414136, 0.845197018616, 0.107344161584,
0.247807528573, 1.39970048134, 0.0420225004071, 1.37754376505,
-0.541677357113, -1.04233421748, 0.116739905427, -0.116106177509,
-0.479015772634, 0.627074347258, 0.905236086429, -1.31173882589,
-0.316982725192, -1.49418534171, 1.15029017764, 0.57536654947)
2 0.201388807123 16.0539445534
2 -0.0226802116084 16.0909970491
4 0.336295264963 15.9617385159
途中省略
499 0.45831937529 15.735211345
500 0.458188960582 15.7350770604
(0.463683030527, -0.56293207292, 0.218480621204, -0.31491635651, -1.27681774227, -0.0716986165936, -1.10320850555, 0.402202160422, -1.20023849386, 0.12705414136, 0.845197018616, 0.107344161584, 0.247807528573, 1.39970048134, 0.0420225004071, 1.37754376505, -0.541677357113, -1.04233421748, 0.116739905427, -0.116106177509, -0.479015772634, 0.627074347258, 0.905236086429, -1.31173882589, -0.316982725192, -1.49418534171, 1.15029017764, 0.57536654947)
# sage 4.6.2の場合、fill=Falseを追加する cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3]) (cnt_plt + data_plt).show() # 1000回目の図 
       
# ニューラルネットワークの生成 net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta) net.setW(saved_w) # 逐次勾配降下法による学習 _learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT) saved_w = net.getW(); print saved_w 
       
2 0.208822794387 15.5938250686
2 -0.673620650045 15.8350168407
4 0.405394880115 15.5670978861
途中省略
30 0.652380586982 15.5650231922
30 -3.26190293491 15.5650231922
31 -4.93211531559 15.5650231922
(0.484677123159, -0.437904913112, 0.00510633275617, -0.289549653944,
-1.21616826687, 0.0898450756053, -0.795481159001, 0.463839739874,
-1.17314384739, 0.138351578124, 0.820705220596, 0.0979314611414,
0.219918450255, 1.30551657334, -0.2256257687, 1.46921682627,
-0.474870542331, -1.32689691961, 0.128038180889, -0.034594526316,
-0.581518239396, 0.707970942864, 0.994897607746, -1.20098576529,
-0.243835214816, -1.63574350655, 1.37831269014, 0.715120352325)
2 0.208822794387 15.5938250686
2 -0.673620650045 15.8350168407
4 0.405394880115 15.5670978861
途中省略
30 0.652380586982 15.5650231922
30 -3.26190293491 15.5650231922
31 -4.93211531559 15.5650231922
(0.484677123159, -0.437904913112, 0.00510633275617, -0.289549653944, -1.21616826687, 0.0898450756053, -0.795481159001, 0.463839739874, -1.17314384739, 0.138351578124, 0.820705220596, 0.0979314611414, 0.219918450255, 1.30551657334, -0.2256257687, 1.46921682627, -0.474870542331, -1.32689691961, 0.128038180889, -0.034594526316, -0.581518239396, 0.707970942864, 0.994897607746, -1.20098576529, -0.243835214816, -1.63574350655, 1.37831269014, 0.715120352325)
# sage 4.6.2の場合、fill=Falseを追加する cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3]) (cnt_plt + data_plt).show() # 1273回目の図 
       
cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3], contours=(0.5,), fill=False) (cnt_plt + data_plt).show() # z=0.5の等高線