PRMLの第5章-混合密度ネットワークの例題、図5.19
区間(0, 1)に一様分布する変数$x_n$をサンプリングし、目的値$t_n$ を$x_n + 0.3 sin(2\pi x_n)$に区間(-0.1, 0.1)の一様乱数を 加えたデータを200点生成します。
|
Done Done |
生成したデータとx,yを入れ替えたグラフをプロットします。
|
|
生成したデータをニューラルネットワークを使って、フィッティングさせたグラフを 以下に示します。(入力層1個、隠れ層6個、出力層1個で計算)
図5.19のbの様にx,tを入れ替えた場合には上手くフィッティングしません。
|
|
|
WARNING: Output truncated! full_output.txt 2 0.961554425907 3.77799819103 3 1.04117097273 2.41177356492 途中省略 499 1.14719236412 0.281449667896 500 0.626481765163 0.281421869851 WARNING: Output truncated! full_output.txt 2 0.961554425907 3.77799819103 3 1.04117097273 2.41177356492 途中省略 499 1.14719236412 0.281449667896 500 0.626481765163 0.281421869851 |
|
WARNING: Output truncated! full_output.txt 2 1.00671085821 5.493956126 3 0.983878474346 5.09315153282 途中省略 487 -19.5656631037 4.5870585875 488 -117.393981949 4.5870585875 489 -156.525310456 4.5870585875 WARNING: Output truncated! full_output.txt 2 1.00671085821 5.493956126 3 0.983878474346 5.09315153282 途中省略 487 -19.5656631037 4.5870585875 488 -117.393981949 4.5870585875 489 -156.525310456 4.5870585875 |
|
混合モデルの条件付き密度関数$p(t|x)$を式(5.148) $$ p(t|x) = \sum_{k=1}^{K} \pi_k(x) \mathcal{N}(t|\mu_k(x), \sigma_k^2(x) I) $$ で表し、モデルパラメータ
混合計数は、式(5.149) $$ \sum_{k=1}^{K} \pi_k(x) = 1, 0 \le \pi_k(x) \le 1 $$ の制約を満たすために、ソフトマックス関数、式(5.150) $$ \pi_k(x) = \frac{exp(a_k^{\pi})}{\sum_{k=1}^{K}} $$ とし、分散$\sigma_k^2(x) \ge 0$を満たすように、式(5.151) $$ \sigma_k(x) = exp(a_k^{\sigma}) $$ 平均は、ニューラルネットワークの出力をそのまま使って、式(5.152) $$ \mu_{kj}(x) = a_{kj}^{\mu} $$ とします。
誤差関数を式(5.153) $$ E(w) = - \sum_{n=1}^{N} ln \left \{ \sum_{k=1}^K \pi_k(x_n, w) \mathcal{N}(t_n|\mu_k(x_n, w), \sigma_k^2(x_n, w) I) \right \} $$ とすると、事後分布式(5.154) $$ \gamma_{nk}(t_n | x_n) = \frac{\pi_k \mathcal{N}_{nk}}{\sum_{l=1}^{K} \pi_l \mathcal{N}_{nl}} $$ を使って各係数の微分を求めます。
混合係数の微分は、式(5.155) $$ \frac{\partial E_n}{\partial a_k^{\pi}} = \pi_k - \gamma_{nk} $$ 平均の微分は、式(5.156) $$ \frac{\partial E_n}{\partial a_{kl}^{\mu}} = \gamma_{nk} \left ( \frac{\mu_{kl} - t_{nl}}{\sigma_k^2} \right ) $$ 分散の微分は、式(5.157) $$ \frac{\partial E_n}{\partial a_{k}^{\sigma}} = \gamma_{nk} \left ( L - \frac{||t_{n} - \mu_k||^2}{\sigma_k^2} \right ) $$
重みwへの部分は、 $$ \frac{\partial E_n}{\partial w_{ji}^{1}} = \delta_j x_i $$ ここで、$\delta_j$は、 $$ \delta_j = (1 - z_j^2) \sum_{k=1}^2 \left ( w_k^{pi} \frac{\partial E_n}{\partial a_k^{\pi}} + w_k^{\mu} \frac{\partial E_n}{\partial w_{k}^{\sigma}} + \sum_{l=1}^{L} w_{kl}^{\mu} \frac{\partial E_n}{\partial w_{kl}^{\mu}} \right ) $$ 混合係数の重み$w_{kj}^{\pi}$微分は、 $$ \frac{\partial E_n}{\partial w_{kj}^{\pi}} = \frac{\partial E_n}{\partial a_k^{\pi}} z_j $$ 平均の微分の重み$w_{kj}^{\pi}$微分は、 $$ \frac{\partial E_n}{\partial w_{kj}^{\mu}} = \sum_{l=1}^{L} \frac{\partial E_n}{\partial w_{kl}^{\mu}} z_j $$ 分散の微分の重み$w_{kj}^{\sigma}$微分は、 $$ \frac{\partial E_n}{\partial w_{kj}^{\sigma}} = \frac{\partial E_n}{\partial w_{k}^{\sigma}} z_j $$
MixtureDensityNetworkクラスを以下のように定義します。
|
MixtureDensityNetworkの計算にランダムな重みを使った結果を以下に示します。
式(5.148)の分布が平均や分散に強く依存しているため、すべてランダムな重みで計算すると なかなか収束しません。
|
2 0.853606593211 220.691284013135 2 -1.02380954822 295.681734779221 4 49619.0430098 220.688261274285 途中省略 29 0.97008460868 -5.34286542507738 30 0.799581594259 -10.4327483758505 2 0.853606593211 220.691284013135 2 -1.02380954822 295.681734779221 4 49619.0430098 220.688261274285 途中省略 29 0.97008460868 -5.34286542507738 30 0.799581594259 -10.4327483758505 |
|
収束をよくするために、重みのバイアス項の初期値をk-meansクラスター分析の結果を 使うことにします。
データをクラスタ分析した結果を以下に示します。
Iteration 0 Iteration 1 Iteration 2 Iteration 3 Iteration 4 Iteration 5 Iteration 6 [[0.755995147678366, 0.906635821415391], [0.489412268723978, 0.535817638206947], [0.325834642861869, 0.133408609458438]] [[0.0184016725675581, 0.00227560546930994], [0.00822710110832403, 0.0257078799298951], [0.0280504201990768, 0.00434725503090522]] Iteration 0 Iteration 1 Iteration 2 Iteration 3 Iteration 4 Iteration 5 Iteration 6 [[0.755995147678366, 0.906635821415391], [0.489412268723978, 0.535817638206947], [0.325834642861869, 0.133408609458438]] [[0.0184016725675581, 0.00227560546930994], [0.00822710110832403, 0.0257078799298951], [0.0280504201990768, 0.00434725503090522]] |
|
初期値を変えて回帰分析をし直します。
今回は、バイアス項を
2 1.00140531678 162.746607216641 3 0.978118009568 94.6561189652474 4 1.00006922311 74.3749728302636 途中省略 28 0.989452644705 -76.7534334638356 29 1.32320485458 -81.6847672263680 30 0.876928985373 -86.8153622565841 2 1.00140531678 162.746607216641 3 0.978118009568 94.6561189652474 4 1.00006922311 74.3749728302636 途中省略 28 0.989452644705 -76.7534334638356 29 1.32320485458 -81.6847672263680 30 0.876928985373 -86.8153622565841 |
|
|
2 0.999254054027 -87.3044338521603 3 0.982374480725 -91.6889548386661 途中省略 99 1.01213525834 -188.881647537428 100 1.03434125073 -189.734542888085 2 0.999254054027 -87.3044338521603 3 0.982374480725 -91.6889548386661 途中省略 99 1.01213525834 -188.881647537428 100 1.03434125073 -189.734542888085 |
|
2 1.00010153793 -189.851273562036 3 1.00013157633 -190.041029348136 途中省略 99 0.962650552011 -206.426993641372 100 1.03840031165 -206.497416181699 2 1.00010153793 -189.851273562036 3 1.00013157633 -190.041029348136 途中省略 99 0.962650552011 -206.426993641372 100 1.03840031165 -206.497416181699 |
|
|
2 1.00024395346 -206.557574296743 3 1.00161969024 -206.711515824961 途中省略 99 1.03276294743 -212.624652019776 100 0.971359659233 -212.656180208402 2 1.00024395346 -206.557574296743 3 1.00161969024 -206.711515824961 途中省略 99 1.03276294743 -212.624652019776 100 0.971359659233 -212.656180208402 |
最終結果は、PRMLの図5.21
|
|
|
|
混合密度の平均として、PRMLでは条件付き密度の近似的な条件付きモード
|
|
今回、プログラムが収束しない現象続き、とても悩みました!
そこで、5.3.3逆伝搬の効率で紹介されている $$ \frac{\partial E_n}{\partial w_{ji}} = \frac{En(w_{ji} + \epsilon) - En(w_{ji} - \epsilon )}{2 \epsilon} $$ を使って重みを近似計算してみました。これと手計算の値を比較しながら、プログラムをデバッグしました。
|
|
最初に単純なニューラルネットワークの場合について近似計算とネットワークの計算を比較しました。 非常に良い一致を得ました。
(0.0668611517486, 0.0, 0.0207058923852, 0.0, -1.05811442086, 0.258067961374) (0.0668611517486, 0.0, 0.0207058923852, 0.0, -1.05811442086, 0.258067961374) |
(0.0, 0.0, 0.0207058923896, 0.0, -1.05811442086, 0.258067961278) (0.0, 0.0, 0.0207058923896, 0.0, -1.05811442086, 0.258067961278) |
次に混合密度ネットワークの場合についても近似計算とネットワークの計算を比較しました。 複雑な処理をしているにも関わらず、かなり一致することに驚きました。
|
(-0.024679341566, -0.0, 0.0, -0.0, -0.199641572236, 0.162818802495, -1.20402804974, 0.981951820123) (-0.024679341566, -0.0, 0.0, -0.0, -0.199641572236, 0.162818802495, -1.20402804974, 0.981951820123) |
(-0.0246793415704, 0.0, 0.0, 0.0, -0.199641572274, 0.162818802441, -1.20402804982, 0.981951820078) (-0.0246793415704, 0.0, 0.0, 0.0, -0.199641572274, 0.162818802441, -1.20402804982, 0.981951820078) |
|