FrontPage

2010/03/16からのアクセス回数 14119

このページのsageノートブックは、以下のURLにあります。

http://www.sagenb.org/home/pub/1762

はじめに

SVMは、オーバーフィッティングを避けて効率よく識別関数を求めることができる 手法です。

「集合知」の9章で紹介されているSVMをSageを使ってまとめてみます。

簡単な例題

いきなりSVMに進む前にクラス分けを簡単な例を使って解いてみます。

下図は、赤(-1)のグループと青(1)のグループの分布です。

sageへの入力:

# 線形クラス分類の例
# データの用意
c1 = [[1,2],[1,4],[2,4]]
c2 = [[2,1],[5,1],[4,2]]
# プロットして分布を確認
pl1 = list_plot(c1, rgbcolor='red')
pl2 = list_plot(c2, rgbcolor ='blue')
(pl1+pl2).show(xmin=0, xmax=5, ymin=0)

1.png

線形分類

青点の値に1、赤点の値に-1をセットし、\(w_1 x_1 + w_2 x_2 + b\)の線形モデルを find_fit関数を使って解いて、\(w_1, w_2, b\)の値を求めます。

find_fitのデータは、\(x_1, x_2, 値\)の順にセットします。

sageへの入力:

# 各クラスに判別値をセット
v1 = [-1, -1, -1]
v2 = [1, 1, 1]
# (x, y, 判別値)のリストを作成
data = [flatten((pt, v)) for (pt, v) in zip(c1 + c2, v1 + v2)]
data
[[1, 2, -1], [1, 4, -1], [2, 4, -1], [2, 1, 1], [5, 1, 1], [4, 2, 1]]

sageへの入力:

# 最もフィットするw1, w2, bを求める
var('x1 x2 w1 w2 b')
model(x1, x2) = w1*x1 + w2*x2 + b
fit = find_fit(data, model, solution_dict=True); print fit
# 求まった解(判別式)を返す関数を定義します
f(x1, x2) = model.subs(fit);
{b: 0.19629629629454115, w1: 0.32592592592445591, w2: -0.43333333333645996}

結果の表示

implicit_plot関数を使ってデータ(赤、青)と判別式が0となる線を表示します。 上手く赤と青の点を分離しているのが分かります。

sageへの入力:

# 判別式が0の線を表示
pl6 = implicit_plot(f(x1, x2) == 0, (x1, 0, 5), (x2, 0, 5))
(pl1 + pl2 +pl6).show(xmin=0, xmax=5, ymin=0)

2.png

求まった判別式がどのような形なのかplot3dを使って表示してみます。 データ(赤、青)の判別式の値も合わせて表示してみます。

一つの平面上として表されています。

sageへの入力:

pl3 = plot3d(f(x1, x2), (x1, 0, 5), (x2, 0, 5))
pl4 = list_plot([(x, y, f(x, y)) for (x, y) in c1], rgbcolor='red')
pl5 = list_plot([(x, y, f(x, y)) for (x, y) in c2], rgbcolor='blue')
(pl3+pl4+pl5).show(xmin=0, xmax=5)

3.jpeg

SVMを使った分類

SVM(サポートベクターマシン)は、クラスの境界線と分離平面(超平面) の距離(マージン)を最大になるようにクラスを分類します。

先ほどの例では、下図のように境界線(半線)の中間に\(y = x\)の分離平面が、 あります。(線形分類の境界線とずれていることに注意して下さい)

境界線を求めるとき使った学習データのことを「サポートベクター」と呼びます。

sageへの入力:

# SVMでは、クラスの境界線と分離平面(超平面)の距離を最大になるように求めます
sv = plot(lambda x : x, (x, 0, 5))
cl1 = plot(lambda x : x + 1, (x, 0, 5), linestyle='dashed')
cl2 = plot(lambda x : x - 1, (x, 0, 5), linestyle='dashed')
(sv+cl1+cl2+pl1+pl2).show(xmin=0, xmax=5)
# 青の実線が分離平面で、半線がクラスの境界線です

4.png

LIBSVMのインストール

sageを使ってSVMを計算したいところですが、sageにはSVMを計算する関数がありません。 そこで、LIBSVMをインストールします。

LIBSVMの ホームページ のDownload LIBSVMから最新のtar.gzファイルをダウンロードします。

ダウンロードしたファイル(filelibsvm-2.9.tar.gz)を適当な場所(~/local)で解凍します。

$ tar xzvf libsvm-2.9.tar.gz
$ cd libsvm-2.9/python	

MacOSXの場合、Makefileの一部を変更します。

#LDFLAGS = -shared
# Mac OS
LDFLAGS = -framework Python -bundle			

また、そのままではsageで動かなかったので、svm.pyの以下の2点にfloatのキャストを 追加するように修正してください。

126c126
< 		svmc.svm_node_array_set(data,j,k,x[k])
---
> 		svmc.svm_node_array_set(data,j,k,float(x[k]))
138c138
< 			svmc.double_setitem(y_array,i,y[i])			
---
> 			svmc.double_setitem(y_array,i,float(y[i]))

sage内部のpythonにLIBSVMをインストールするには、sageのpythonで seup.pyを実行します。私は、sageを~/localにインストールしているので

$ ~/local/sage/local/bin/python setup.py install			

と実行しました。

正しく動くか、集合知の9.9.2の例題で確認します。 無事、同じ結果がでたので、次に進みます。

sageへの入力:

# svm.pyを2カ所修正した
# 動作を確認するために、集合知の9.9.2の例題を実行する
from svm import *
prob = svm_problem([1,-1],[[1,0,1],[-1,0,-1]])
param = svm_parameter(kernel_type = LINEAR, C=float(10))
m = svm_model(prob, param)
*
optimization finished, #iter = 1
nu = 0.025000
obj = -0.250000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2

sageへの入力:

m.predict([1,1,1])
1.0

LIBSVMで簡単な例題を解く

LIBSVMへの入力は、各データどのクラスに属するかを示すclsと データのベクトル値datを引数に取ります。

sageへの入力:

# LIBSVM用にデータを加工する
cls = v1 + v2; print cls
dat = c1 + c2; print dat
[-1, -1, -1, 1, 1, 1]
[[1, 2], [1, 4], [2, 4], [2, 1], [5, 1], [4, 2]]

sageへの入力:

# 例題にSVMを適応
prob = svm_problem(cls, dat)
m = svm_model(prob, param)
*
optimization finished, #iter = 4
nu = 0.033333
obj = -1.000000, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 2

データが正しく分類されているか赤(1,4)の座標を入力すると、-1.0が返され、 正しい結果が返ってきます。

sageへの入力:

m.predict([1, 4])
-1.0

結果の表示

contour_plotを使って、SVMの予想をコンタマップで表示すると、\(y = x\) の線でうまく、分類されているのが、分かります。

sageへの入力:

contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5))

5.png

さらに、各点(赤、青)の予想値と境界線の形状をlist_plot3で表示します。

sageへの入力:

# plot3dがうまく表示できないので、list_plot3dで代替
pl7 = list_plot3d([(x, y, m.predict([x, y])) for x in srange(0, 5, 0.1) for y in srange(0, 5, 0.1)])
pl8 = list_plot([(x, y, m.predict([x, y])) for (x, y) in c1], rgbcolor='red')
pl9 = list_plot([(x, y, m.predict([x, y])) for (x, y) in c2], rgbcolor='blue')
(pl8 + pl9 +pl7).show(xmin=0, xmax=5, ymin=0)

6.jpeg

SVMのカーネルメソッドのすごさを確かめるために点の値がチェスボードのように 格子状に部分布する点を分析することにします。

まずは、チェスボードのマスの値を返すchessBox関数を作成します。 ただしく、値がセットされるか、conour_plotでみてみましょう。

sageへの入力:

# チェスボックスの例
def chessBox(x, y):
    if ((int(x)+int(y))%2) == 0: 
        return 1 
    else: 
        return -1  
#[chessBox(x, y) for x in range(0,5) for y in range(0,5)]
contour_plot(chessBox, (x, 0, 5), (y, 0, 5))

7.png

テストデータの作成

0から5の範囲にランダムに点を生成し、chessBox関数で赤(red)と青(blue) に振り分けます。

sageへの入力:

# ランダムな点を生成
rndPts = [[5*random(), 5*random()] for i in range(0,1000)];

Kernel関数LINEARの場合

kernel_type=LINEAR(線形を意味する)を使って分類すると、 うまく格子状のデータを分類することができません。

sageへの入力:

red  = [pt for pt in rndPts if chessBox(pt[0], pt[1]) == -1];
blue = [pt for pt in rndPts if chessBox(pt[0], pt[1]) == 1];
redCls  = [-1 for pt in red]
blueCls = [1 for pt in blue]
chessRedPlt  = list_plot([(x, y) for (x, y) in red], rgbcolor='red')
chessBluePlt = list_plot([(x, y) for (x, y) in blue], rgbcolor='blue')
(chessRedPlt+chessBluePlt).show(xmin=0, xmax=5)

8.png

sageへの入力:

# カーネルメソッドをLINEARだとうまく分けられない
prob = svm_problem(redCls + blueCls, red + blue)
param = svm_parameter(kernel_type = LINEAR, C=float(1000))
m = svm_model(prob, param)
contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5))
...................................................................
Warning: using -h 0 may be faster
*..................................................................
Warning: using -h 0 may be faster
*..................................................................
optimization finished, #iter = 2746026
nu = 0.992055
obj = -992226.807639, rho = 1.007356
nSV = 994, nBSV = 991
Total nSV = 994

9.png

Kernel関数RBFの場合

次に、kernel_type=RBF(ガウシアンカーネル関数)を使って分類すると、 なんとなく格子状に分類しているように見えます。

ガウシアンカーネル関数は、 $$ K(x, x') = e ^{\left( - \frac{| x - x' |^2}{\sigma^2} \right)} $$

sageへの入力:

# カーネルメソッドをRBFにするとある程度格子の形が見て取れる
prob = svm_problem(redCls + blueCls, red + blue)
param = svm_parameter(kernel_type = RBF, C=float(1000))
m = svm_model(prob, param)
contour_plot(lambda x, y : m.predict([x, y]), (x, 0, 5), (y, 0, 5))
..........................................*...............................*...............*
optimization finished, #iter = 87847
nu = 0.199029
obj = -148750.124372, rho = 13.104982
nSV = 223, nBSV = 174
Total nSV = 223

10.png

簡単な画像認識

最後に簡単な画像認識をLIBSVMを使って確かめてみます。

5x5のマス目に数字の0から4までを書いた画像5個(本当に少ないです!) をテストデータとして使用します。

sageへの入力:

# 文字認識
m0 = [[0,1,1,1,0], [1,0,0,0,1], [1,0,0,0,1], [1,0,0,0,1], [0,1,1,1,0]]
m1 = [[0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0]]
m2 = [[0,1,1,1,1], [1,0,0,1,0], [0,0,1,0,0], [0,1,0,0,0], [1,1,1,1,1]]
m3 = [[0,1,1,1,0], [1,0,0,0,1], [0,0,1,1,0], [1,0,0,0,1], [0,1,1,1,0]]
m4 = [[0,0,1,0,0], [0,1,0,0,0], [1,0,0,1,0], [1,1,1,1,1], [0,0,0,1,0]]
p0 = flatten(m0); p1 = flatten(m1); p2 = flatten(m2); p3 = flatten(m3)
p4 = flatten(m4)
v  = [0,1,2,3,4]

作成したデータをcontour_plotを使って表示して、確認します。 なんとなくそれらしく見えるでしょう。(笑)

sageへの入力:

# 画像作成用の関数
def mesh(x, y, tbl):
    idX = int(x)
    idY = int(4.9-y)
    return tbl[idY][idX]
# 学習用画像を表示
trn0 = contour_plot(lambda x, y : mesh(x, y, m0), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
trn1 = contour_plot(lambda x, y : mesh(x, y, m1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
trn2 = contour_plot(lambda x, y : mesh(x, y, m2), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
trn3 = contour_plot(lambda x, y : mesh(x, y, m3), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
trn4 = contour_plot(lambda x, y : mesh(x, y, m4), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
html.table([[trn0, trn1, trn2], [trn3, trn4]])

11.png

画像認識

いよいよ画像認識の準備が整いました。 えいや〜で、モデルを作成します。

sageへの入力:

# 画像認識
param = svm_parameter(kernel_type = RBF, C=float(10))
prob = svm_problem(v,[p0,p1,p2,p3,p4])
m = svm_model(prob, param)
*
optimization finished, #iter = 1
nu = 0.246622
obj = -2.466216, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.676333
obj = -6.763327, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.280928
obj = -2.809276, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.330771
obj = -3.307713, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.202682
obj = -2.026823, rho = 0.000000
nSV = 2, nBSV = 0
*
optimization finished, #iter = 1
nu = 0.262318
obj = -2.623181, rho = 0.000000
nSV = 2, nBSV = 0
Total nSV = 5

テスト用の画像

テストに使う画像は、以下の3個です。 判別結果は、ただしく0, 1, 1となっています。(めでたし、めでたし)

sageへの入力:

# テスト用の画像を作成する
m0_1 = [[0,1,1,1,0], [0,1,0,0,1], [0,1,0,0,1], [0,1,0,0,1], [0,1,1,1,0]]
m1_1 = [[0,0,1,0,0], [0,1,1,0,0], [0,0,1,0,0], [0,0,1,0,0], [0,1,1,1,0]]
m1_2 = [[0,0,0,1,0], [0,0,0,1,0], [0,0,1,0,0], [0,0,1,0,0], [0,0,1,0,0]]
t0_1 = flatten(m0_1)
t1_1 = flatten(m1_1)
t1_2 = flatten(m1_2)

sageへの入力:

# テスト用画像の表示
cnt1 = contour_plot(lambda x, y : mesh(x, y, m0_1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
cnt2 = contour_plot(lambda x, y : mesh(x, y, m1_1), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
cnt3 = contour_plot(lambda x, y : mesh(x, y, m1_2), (x, 0, 4.9), (y, 0, 4.9), aspect_ratio=1, figsize=(2, 1.23))
html.table([[cnt1, cnt2, cnt3]])

12.png

sageへの入力:

# テスト用画像の認識
print m.predict(t0_1)
print m.predict(t1_1)
print m.predict(t1_1)
0.0
1.0
1.0

念のため、学習に使ったデータが正しく識別されるかもみてみました。

こんなに少ないデータでよく認識できるものだと感心しました。これがSVMのすごさなのでしょうか?

sageへの入力:

# 学習用画像の認識
print m.predict(p0)
print m.predict(p1)
print m.predict(p2)
print m.predict(p3)
print m.predict(p4)
0.0
1.0
2.0
3.0
4.0

コメント

選択肢 投票
おもしろかった 26  
そうでもない 0  
わかりずらい 3  

皆様のご意見、ご希望をお待ちしております。


(Input image string)


添付ファイル: filelibsvm-2.9.tar.gz 1356件 [詳細] file12.png 1812件 [詳細] file11.png 1770件 [詳細] file10.png 1736件 [詳細] file9.png 1657件 [詳細] file8.png 1770件 [詳細] file7.png 1684件 [詳細] file6.jpeg 1665件 [詳細] file5.png 1791件 [詳細] file4.png 1750件 [詳細] file3.jpeg 1813件 [詳細] file2.png 1725件 [詳細] file1.png 1735件 [詳細]

トップ   編集 凍結解除 差分 バックアップ 添付 複製 名前変更 リロード   新規 一覧 単語検索 最終更新   ヘルプ   最終更新のRSS
Last-modified: 2022-12-21 (水) 20:55:33 (492d)
SmartDoc