FrontPage

2011/06/05からのアクセス回数 2578

ここで紹介したSageワークシートは、以下のURLからダウンロードできます。

http://sage.math.canterbury.ac.nz/home/pub/102/

また、Sageのサーバを公開しているサイト(http://sage.math.canterbury.ac.nz/http://www.sagenb.org/)にユーザIDを作成することで、ダウンロードしたワークシートを アップロードし、実行したり、変更していろいろ動きを試すことができます。

第5章-多層パーセプトロン分類問題をSageで試す

PRMLの第5章-多層パーセプトロン分類問題の例題、図5.4 (赤が多層パーセプトロンの結果、緑が最適な分類結果)

Figure5.4.jpg

をSageを使って試してみます。

データの準備

上巻の付録Aにあるように http://research.microsoft.com/~cmbishop/PRML から 例題のデータをダウンロードすることができます。

ここでは人工的に作成された分類問題用のデータclassfication.txtを使用します。

データの読み込みは、numpyのloadtxtを使い、sageのmatrixに変換します。

データの分布は(0を赤、1を青でプロット)します。

sageへの入力:

# ニューラルネットワークで識別を行う
from numpy import loadtxt 

data = loadtxt(DATA+"classification.txt")
# 1,2カラムはx, y、3カラムには、0:red cosses×, 1:blue circle ○がセットされている
# これをx, yと識別子に変換する
X = [[x[0], x[1]] for x in data[:,0:2]]
T = data[:,2]

sageへの入力:

data_plt = Graphics()
for i in range(len(T)):
    if T[i] == 0:
        data_plt += point(X[i], rgbcolor='red')
    else:
        data_plt += point(X[i], rgbcolor='blue')
data_plt.show()

計算条件

分類問題に使用するニューラルネットワーク(多層パーセプトロン)は、

  • 入力ユニット数2個(バイアス項を除く)
  • 隠れ層ユニット2個(バイアス項を除く)、活性化関数はtanh
  • 出力ユニット1個、活性化関数はロジスティックシグモイド $$ \sigma(a) = \frac{1}{1 + exp(-a)} $$ としました。

sageへの入力:

# 入力1点、隠れ層3個、出力1点
N_in = 2   
N_h = 6     
N_out = 1
# 隠れ層の活性化関数
h_1(x) = tanh(x)
# 出力層の活性化関数
h_2(x) = 1/(1 + e^(-x))

プログラムのロード

SCG_Class で使った、 ニューラルネットワーク(NeuralNetwork.sage)、スケール共役勾配法(SCG.sage) をプログラムとして読み込みます。

sageへの入力:

# ニューラルネットワーク用のクラスのロード
attach "NeuralNetwork.sage"
# Scaled Conjugate Gradients(スケール共役勾配法)のロード
attach "SCG.sage"

sageへの入力:

# 出力関数
def _f(x, net):
    y = net.feedForward(x)
    return y[0]

人工データの識別(分類問題を解く)

人工データの識別はとても簡単です。共役勾配法(SCG_Class)で使用した設定 の内、

  • 出力層の活性化関数をロジスティックシグモイドに変更
  • 初期値の分散を0.5に変更

するだけです。

最終的な収束には、1時間程度かかります。収束の過程を500回毎に表示し、最終的な 識別面と最後に表示します。

sageへの入力:

# 定数設定
LEARN_COUNT = 500
RESET_COUNT = 25
beta = 0.5
# ニューラルネットワークの生成
net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta)
# wの初期値を保存
saved_w = net.getW()
# 逐次勾配降下法による学習
_learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT)
2 0.216907574411 18.1512434768
3 0.367274198164 18.0991262322
4 0.342563144731 18.0575836323
途中省略
497 0.510679849338 16.7768761226
498 0.510738410001 16.7760287748
499 0.510796526139 16.775182439
500 0.510854199995 16.7743371123

sageへの入力:

var('x y')
f = lambda x, y : _f([x, y], net)
# sage 4.6.2の場合、fill=Falseを追加する
cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3])
(cnt_plt + data_plt).show()
# 500回目の図

sage0.png

sageへの入力:

saved_w = net.getW(); print saved_w
(0.448813729769, -0.611233027386, 0.486835724148, -0.290585284889, -1.03630665196, 0.452577662537, -1.14585774724, 
0.571240013162, -1.03099854342, 0.1503080603, 0.817838469072, 0.467411325316, -0.0277714945845, 0.979474053476, 
-0.039090385248, 0.702582727597, -0.880592325533, -0.783731911532, 0.0687552540301, -0.19079101757, -0.0922401202569, 
0.561182751992, 0.512563236679, -0.78366734313, -0.583498381549, -1.00840952578, 0.726987191536, 0.471177919157)

sageへの入力:

# ニューラルネットワークの生成
net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta)
net.setW(saved_w)
# 逐次勾配降下法による学習
_learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT)
saved_w = net.getW(); print saved_w
2 0.201388807123 16.0539445534
2 -0.0226802116084 16.0909970491
4 0.336295264963 15.9617385159
途中省略
499 0.45831937529 15.735211345
500 0.458188960582 15.7350770604
(0.463683030527, -0.56293207292, 0.218480621204, -0.31491635651, -1.27681774227, -0.0716986165936, -1.10320850555, 
0.402202160422, -1.20023849386, 0.12705414136, 0.845197018616, 0.107344161584, 0.247807528573, 1.39970048134, 
0.0420225004071, 1.37754376505, -0.541677357113, -1.04233421748, 0.116739905427, -0.116106177509, -0.479015772634, 
0.627074347258, 0.905236086429, -1.31173882589, -0.316982725192, -1.49418534171, 1.15029017764, 0.57536654947)

sageへの入力:

# sage 4.6.2の場合、fill=Falseを追加する
cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3])
(cnt_plt + data_plt).show()
# 1000回目の図

sage0-1.png

sageへの入力:

# ニューラルネットワークの生成
net = NeuralNetwork(N_in, N_h, N_out, h_1, h_2, beta)
net.setW(saved_w)
# 逐次勾配降下法による学習
_learn_csg(net, X, T, LEARN_COUNT, RESET_COUNT)
saved_w = net.getW(); print saved_w
2 0.208822794387 15.5938250686
2 -0.673620650045 15.8350168407
4 0.405394880115 15.5670978861
途中省略
30 0.652380586982 15.5650231922
30 -3.26190293491 15.5650231922
31 -4.93211531559 15.5650231922
(0.484677123159, -0.437904913112, 0.00510633275617, -0.289549653944, -1.21616826687, 0.0898450756053, -0.795481159001, 
0.463839739874, -1.17314384739, 0.138351578124, 0.820705220596, 0.0979314611414, 0.219918450255, 1.30551657334, 
-0.2256257687, 1.46921682627, -0.474870542331, -1.32689691961, 0.128038180889, -0.034594526316, -0.581518239396, 
0.707970942864, 0.994897607746, -1.20098576529, -0.243835214816, -1.63574350655, 1.37831269014, 0.715120352325)

sageへの入力:

# sage 4.6.2の場合、fill=Falseを追加する
cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3])
(cnt_plt + data_plt).show()
# 1273回目の図

sage0-2.png

sageへの入力:

cnt_plt = contour_plot(f, [x, -3, 3], [y, -3, 3], contours=(0.5,), fill=False)
(cnt_plt + data_plt).show()
# z=0.5の等高線

sage0-3.png

コメント

選択肢 投票
おもしろかった 2  
そうでもない 0  
わかりずらい 0  

皆様のご意見、ご希望をお待ちしております。


(Input image string)


添付ファイル: filesage0-3.png 389件 [詳細] filesage0-2.png 365件 [詳細] filesage0-1.png 351件 [詳細] filesage0.png 391件 [詳細]

トップ   編集 凍結解除 差分 バックアップ 添付 複製 名前変更 リロード   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2013-02-18 (月) 10:58:22 (1587d)
SmartDoc