パターン認識と機械学習 は、機械学習に関するとても優れた教科書です。ここでは、Sageを使って教科書の例題を実際に試してみます。
第1章の例題として、$ sin(2 \pi x) $の回帰問題について見てみます。
上巻の付録Aにあるように http://research.microsoft.com/~cmbishop/PRML から 例題のデータをダウンロードすることができます。
これから計算に使うSin曲線は、 $$ y = sin(2 \pi x) + \mathcal{N}(0,0.3) $$ で与えられたデータです。ここでは、curve fitting dataで提供されているデータを使用します。
最初に、データを座標Xと目的値tにセットし、sin曲線と一緒に表示してみます。
|
3章の式(3.15)によると最小自乗法の解は、 $$ W_{ML} = ( \Phi^T \Phi )^{-1} \Phi^T t $$ で与えられます。ここで、$\Phi$ は計画行列(design matrix)と呼ばれ、その要素は、$\Phi_{nj} = \phi_j(x_n)$ 与えらます。
また、行列 $ \Phi^{\dagger} $は、ムーア_ベンローズの疑似逆行列と呼ばれています。 $$ \Phi^{\dagger} = \left( \Phi^T \Phi \right)^{-1} \Phi^T $$
多項式フィッティングの式(1.1)では、$y(x, w)$を以下のように定義します。 $$ y(x, w) = \sum^{M}_{j=0} w_j x^j $$
そこで、$\phi_j(x)$を以下のように定義します。 $$ \phi_j(x) = x^j $$
|
定義に従って計画行列$\Phi$、計画行列の転置行列$\Phi^T$、ムーア_ベンローズの疑似逆行列$\Phi^\dagger$を求め、 平均の重み$W_ML$を計算します。
(0.313702727439780, 7.98537103199157, -25.4261022423404, 17.3740765279263) (0.313702727439780, 7.98537103199157, -25.4261022423404, 17.3740765279263) |
sageでは、多項式y(x)を以下のように定義します。
|
$M=3$の時の、多項式回帰の結果(赤)をサンプリング(青)とオリジナルの$sin(2 \pi x)$を合わせて プロットします。
|
図1.3に相当する図を表示します。スライダーでMの値を変えてみてください。
$M=9$ではすべての点を通る曲線になりますが、元のsin曲線とはかけ離れた形になります。 これは、過学習(over fitting)と呼ばれる現象で、機械学習はこの過学習との戦いになります。
Click to the left again to hide and once more to show the dynamic interactive window |
サンプルを100個に増やした検証用データを生成します。
|
サンプルを100個に増やし、$M=9$の多項式フィッティングを行ったのが、以下の図です(原書の図1.6に対応)。
サンプルを増やせば、$M=9$でも元のsin曲線に近い形になりますが、少数の貴重なサンプルではこのような方法は使えません。
|
原書の1.1では訓練データとテスト用データの平均自乗平方根誤差、 $$ E(w) = \frac{1}{2} \sum^N_{n=1} \{ y(x_n, w) - t_n \}^2 $$ を使って最適なMの値を求める手法を紹介しています。
以下に訓練用データの$E_{RMS}$(青)、テスト用データの$E_{RMS}$(赤)でMに対する 値の変化を示します。M=3で訓練、テスト共に低い値となることが見て取れます。
|
誤差関数にペナルティ項を追加し、過学習を防ぐ例として、リッジ回帰が紹介されています。 $$ \tilde{E}(w) = \frac{1}{2} \sum^N_{n=1} \{ y(x_n, w) - t_n \}^2 + \frac{\lambda}{2} ||w||^2 $$ を使って正規化します。
$y(x_n,w)$を$\phi$と$w$で書き替えると、原書式(3.27)になる。 $$ \tilde{E}(w) = \frac{1}{2} \sum^N_{n=1} \{ t_n - w^T\phi(x_n) \}^2 + \frac{\lambda}{2} ||w||^2 $$
wに関する勾配を0にすると、原書式(3.28)が求まる。 $$ w = \left( \lambda I + \Phi^T\Phi \right)^{-1} \Phi^T t $$
以下にM=9の練習用データにリッジ回帰を行った結果を示します。
|
リッジ回帰に対する訓練用データの$E_{RMS}$(青)、テスト用データの$E_{RMS}$(赤)でMに対する 値の変化を示します。
|
もっとも良い結果を示すのが、ベイズ的なフィッティングの結果です(原著図1.17)。 $$ m_N = \beta S_N \Phi^T t $$ $$ S^{-1}_{N} = \alpha I + \beta \Phi^T \Phi $$ で与えられます。($m_N$は、直線回帰の$W_{ML}$に相当します)
以下に$\alpha = 5 \times 10^{-3}, \beta = 11.1$の練習用データのベイズ的なフィッティング結果を示します。
|
鈴木由宇さんから、最近はlasso回帰も注目されていると聞き、 smlyさんのブログ「線形回帰モデルとか」 から、cvx_optを使った方法を使わせて頂き、Sageでプロットしてみました。
lasso回帰は、リッジ回帰のペナルティ項が2次から絶対値になったものです。 $$ \tilde{E}(w) = \frac{1}{2} \sum^N_{n=1} \{ y(x_n, w) - t_n \}^2 + \frac{\lambda}{2} ||w|| $$
pcost dcost gap pres dres 0: -7.3129e-01 -1.5730e+04 2e+04 5e-17 8e+01 1: -1.1908e+00 -2.3198e+02 2e+02 2e-16 1e+00 2: -1.6934e+00 -1.1863e+01 1e+01 2e-16 5e-02 3: -1.8817e+00 -2.4425e+00 6e-01 2e-16 2e-03 4: -1.8890e+00 -1.9346e+00 5e-02 2e-16 2e-04 5: -1.8949e+00 -1.9100e+00 2e-02 2e-16 4e-05 6: -1.8990e+00 -1.9027e+00 4e-03 3e-16 4e-14 7: -1.8998e+00 -1.9009e+00 1e-03 2e-16 3e-14 8: -1.9003e+00 -1.9005e+00 2e-04 3e-16 4e-14 9: -1.9004e+00 -1.9005e+00 1e-04 3e-16 3e-14 10: -1.9004e+00 -1.9004e+00 1e-05 5e-16 4e-14 11: -1.9004e+00 -1.9004e+00 2e-07 2e-16 3e-14 Optimal solution found. [ 3.50668337e-01 4.72355549e+00 -5.66598952e-01 -3.33176635e+01 5.28640500e-04 5.46111228e+01 2.04434686e+01 -9.73943631e-04 -1.30982583e+02 8.49936038e+01] pcost dcost gap pres dres 0: -7.3129e-01 -1.5730e+04 2e+04 5e-17 8e+01 1: -1.1908e+00 -2.3198e+02 2e+02 2e-16 1e+00 2: -1.6934e+00 -1.1863e+01 1e+01 2e-16 5e-02 3: -1.8817e+00 -2.4425e+00 6e-01 2e-16 2e-03 4: -1.8890e+00 -1.9346e+00 5e-02 2e-16 2e-04 5: -1.8949e+00 -1.9100e+00 2e-02 2e-16 4e-05 6: -1.8990e+00 -1.9027e+00 4e-03 3e-16 4e-14 7: -1.8998e+00 -1.9009e+00 1e-03 2e-16 3e-14 8: -1.9003e+00 -1.9005e+00 2e-04 3e-16 4e-14 9: -1.9004e+00 -1.9005e+00 1e-04 3e-16 3e-14 10: -1.9004e+00 -1.9004e+00 1e-05 5e-16 4e-14 11: -1.9004e+00 -1.9004e+00 2e-07 2e-16 3e-14 Optimal solution found. [ 3.50668337e-01 4.72355549e+00 -5.66598952e-01 -3.33176635e+01 5.28640500e-04 5.46111228e+01 2.04434686e+01 -9.73943631e-04 -1.30982583e+02 8.49936038e+01] |
|
|