しまてく

学んだ技術を書きためるブログ

ITエンジニアのための機械学習理論入門 2-1

毎週少しずつ読んでいるんですが、なかなか理解が遅いので記事としてまとめてみます。

完全に自分用メモ。

2.1.3 数学徒の小部屋

$$ 誤差 E_D = \dfrac{1}{2} \sum_{n=1}^{N} (\sum_{m'=0}^{M}ω_{m'}x_{n}^{m'} - t_{n})^2 \tag{2.4} $$

ここで(2.4)を最小にする $ {ω_m}_{m=0}^M $ を決定する。

$$ \dfrac{∂E_{D}}{∂ω_{m}} = 0     (m = 0, …, M)\tag{2.5} $$

に、(2.4)を代入すると以下の式になる。

$$ \dfrac{∂{ \dfrac{1}{2} \sum_{n=1}^{N} (\sum_{m'=0}^{M}ω_{m'}x_{n}^{m'} - t_{n})^2 }}{∂ω_m} = 0 $$

ここで、$ E_D $に代入するmは$ \sum_{m=0}^{M} $の、冪級数を表すmなので、 偏微分したいωmのmとは違う。

よって、冪級数を表すmをm'として置く。 そこで、

$$ \dfrac{1}{2} \sum_{n=1}^{N} (\sum_{m'=0}^{M}ω_{m'}x_{n}^{m'} - t_{n})^2 $$

を$ ω_m $で偏微分する。

実際の偏微分は、つい2乗を展開したくなるのが これは間違いで、合成関数の偏微分を利用する。


【合成関数の偏微分

$$ \dfrac{∂f(g(x, y))}{∂x} = f'(g(x, y)) \cdot \dfrac{∂g(x, y)}{∂x} $$


ここで、

$$ \begin{eqnarray} g(x) &=& \sum_{m'=0}^{M}ω_{m'}x_{n}^{m'} - tn \nonumber \\ f(x) &=& g(x)^2 \nonumber \end{eqnarray} $$

と置くと、

$$ \begin{eqnarray} f'(x) &=& 2g(x) \nonumber \\ &=& 2(\sum_{m'=0}^{M}ω_{m'}x_n^{m'} - tn) \nonumber \end{eqnarray} $$

となる。

また、$ \dfrac{∂g(x, y)}{∂x} $ は $ \dfrac{∂(\sum_{m'=0}^{M}ω_{m'}x_n^{m'} - tn)}{∂ω_m} = x_n^m $ となる。

従って $ \dfrac{∂{ \dfrac{1}{2} \sum_{n=1}^{N} ( \sum_{m'=0}^{M} ω_{m'} x_n^{m'} - t_n)^2 }}{∂ω_m} $ は

$$ \begin{eqnarray} \dfrac{1}{2} \sum_{n=1}^{N} \cdot 2( \sum_{m'=0}^{M}ω_{m'}x_n^{m'} - t_n) \cdot x_n^m &=& 0 \nonumber \\ \sum_{n=1}^{N}(\sum_{m'=0}^{M}ω_{m'}x_n^{m'} - t_n) \cdot x_n^m &=& 0 \nonumber \tag{2.7} \end{eqnarray} $$

となる。

やや作為的だが、これを次のように変形する。

$$ \sum_{m'=0}^{M}ω_{m'} \sum_{n=1}^{N}x_n^{m'}x_n^m - \sum_{n=1}^{N}t_nx_n^m = 0 \tag{2.8} $$

ここで、$ x_n^m $ を (n, m)成分とするNx(M+1)行列φを用いると、これは行列形式で書き直せる。

$$ w^Tφ^Tφ-t^Tφ = 0 \tag{2.9} $$

(2.8)から(2.9)の変換の正しさを逆方向の式変換で考える。

(2.9)の $ w^Tφ^Tφ $ に注目する。

$$ w = \left( \begin{array}{c} ω_0 \\ ω_1 \\ ︙ \\ ω_M \end{array} \right),      φ = \left( \begin{array}{c} x_1^0 & x_1^1 & \cdots & x_1^M \\ x_2^0 & x_2^1 & \cdots & x_2^M \\ \vdots & \vdots & \ddots & \vdots \\ x_N^0 & x_N^1 & \cdots & x_N^M \end{array} \right),      φ^T = \left( \begin{array}{c} x_1^0 & x_2^0 & \cdots & x_N^0 \\ x_1^1 & x_2^1 & \cdots & x_N^1 \\ \vdots & \vdots & \ddots & \vdots \\ x_1^M & x_2^M & \cdots & x_N^M \end{array} \right) \tag{2.10} $$

具体的な数値で、N = 2, M = 2と置くと、

$$ w = \left( \begin{array}{c} ω_0 \\ ω_1 \\ ω_2 \end{array} \right),      φ = \left( \begin{array}{c} x_1^0 & x_1^1 & x_1^2 \\ x_2^0 & x_2^1 & x_2^2 \end{array} \right),      φ^T = \left( \begin{array}{c} x_1^0 & x_2^0 \\ x_1^1 & x_2^1 \\ x_1^2 & x_2^2 \end{array} \right) $$

ここで、$ φ^Tφ $ を計算すると、

$$ \begin{eqnarray} φ^Tφ &=& \left( \begin{array}{c} x_1^0 & x_2^0 \\ x_1^1 & x_2^1 \\ x_1^2 & x_2^2 \end{array} \right) \left( \begin{array}{c} x_1^0 & x_1^1 & x_1^2 \\ x_2^0 & x_2^1 & x_2^2 \end{array} \right) \nonumber\\ &=& \left( \begin{array}{c} x_1^0x_1^0 + x_2^0x_2^0 & x_1^0x_1^1 + x_2^0x_2^1 & x_1^0x_1^2 + x_2^0x_2^2 \\ x_1^1x_1^0 + x_2^1x_2^0 & x_1^1x_1^1 + x_2^1x_2^1 & x_1^1x_1^2 + x_2^1x_2^2 \\ x_1^2x_1^0 + x_2^2x_2^0 & x_1^2x_1^1 + x_2^2x_2^1 & x_1^2x_1^2 + x_2^2x_2^2 \end{array} \right) \nonumber \\ &=& \left( \begin{array}{c} \sum_{n=1}^{N}x_n^0x_n^0 & \sum_{n=1}^{N}x_n^0x_n^1 & \sum_{n=1}^{N}x_n^0x_n^2 \\ \sum_{n=1}^{N}x_n^1x_n^0 & \sum_{n=1}^{N}x_n^1x_n^1 & \sum_{n=1}^{N}x_n^1x_n^2 \\ \sum_{n=1}^{N}x_n^2x_n^0 & \sum_{n=1}^{N}x_n^2x_n^1 & \sum_{n=1}^{N}x_n^2x_n^2 \end{array} \right) \nonumber \\ \end{eqnarray} $$

次に $ w^Tφ^Tφ $ を計算すると、

$$ \begin{eqnarray} w^Tφ^Tφ &=& \left( \begin{array}{c} ω_0 & ω_1 & ω_2 \end{array} \right) \left( \begin{array}{c} \sum_{n=1}^{N}x_n^0x_n^0 & \sum_{n=1}^{N}x_n^0x_n^1 & \sum_{n=1}^{N}x_n^0x_n^2 \\ \sum_{n=1}^{N}x_n^1x_n^0 & \sum_{n=1}^{N}x_n^1x_n^1 & \sum_{n=1}^{N}x_n^1x_n^2 \\ \sum_{n=1}^{N}x_n^2x_n^0 & \sum_{n=1}^{N}x_n^2x_n^1 & \sum_{n=1}^{N}x_n^2x_n^2 \end{array} \right) \nonumber \\ &=& \left( \begin{array}{c} ω_0(\sum_{n=1}^{N}x_n^0x_n^0) + ω_1(\sum_{n=1}^{N}x_n^1x_n^0) + ω_2(\sum_{n=1}^{N}x_n^2x_n^0) & ω_0(\sum_{n=1}^{N}x_n^0x_n^1) + ω_1(\sum_{n=1}^{N}x_n^1x_n^1) + ω_2(\sum_{n=1}^{N}x_n^2x_n^1) & ω_0(\sum_{n=1}^{N}x_n^0x_n^2) + ω_1(\sum_{n=1}^{N}x_n^1x_n^2) + ω_2(\sum_{n=1}^{N}x_n^2x_n^2) \end{array} \right) \nonumber \\ &=& \left( \begin{array}{c} \sum_{m'=0}^{N}ω_{m'}\sum_{n=1}^{N}x_n^{m'}x_n^0 & \sum_{m'=0}^{N}ω_{m'}\sum_{n=1}^{N}x_n^{m'}x_n^1 & \sum_{m'=0}^{N}ω_{m'}\sum_{n=1}^{N}x_n^{m'}x_n^2 \end{array} \right) \nonumber \\ &=& \sum_{m'=0}^{N}ω_{m'}\sum_{n=1}^{N}x_n^{m'}x_n^m     (m=0,...,M) \nonumber \\ \end{eqnarray} $$

よって、(2.8)は(2.9)に式変換できる。

話は戻って、(2.9) $ w^Tφ^Tφ-t^Tφ = 0 $を転置を取ってwについて式変換すると、

$$ \begin{eqnarray} w &=& \dfrac{φ^Tt}{φ^Tφ} \nonumber \\ &=& {(φ^Tφ)}^{-1}φ^Tt \tag{2.11} \end{eqnarray} $$

φとtはそれぞれトレーニングセットに含まれる観測データから決まるものなので、 (2.11)は、与えられたトレーニングセットを用いて多項式の係数wを決定する式になっている。

(2.11)の $ φ^Tφ $ は逆行列を持つのか?という確認。

$E_D$の2階偏微分係数を表すヘッセ行列を用いて説明する。


二階偏微分

$$ \dfrac{∂^2f}{∂x∂y} = \dfrac{∂}{∂y}\left(\dfrac{∂f}{∂x}\right) = \dfrac{∂}{∂x}\left(\dfrac{∂f}{∂y}\right) = \dfrac{∂^2f}{∂y∂x} $$


ヘッセ行列Hは、次の成分を持つ(M+1)x(M+1)の正方行列となっている。

$$ H_{mm'} = \dfrac{∂^2E_D}{∂ω_m∂ω_{m'}}       (m, m' = 0, ..., M) \tag{2.12} $$

$$ \begin{eqnarray} H_{mm'} &=& \dfrac{∂^2E_D}{∂ω_m∂ω_{m'}} \nonumber \\ &=& \dfrac{∂}{∂ω_{m'}}\left(\dfrac{∂E_D}{∂ω_{m}}\right) \nonumber \\ &=& \dfrac{∂}{∂ω_{m'}}\left(\sum_{m'=0}^{M}ω_{m'} \sum_{n=1}^{N}x_n^{m'}x_n^m - \sum_{n=1}^{N}t_nx_n^m\right) \nonumber \\ &=& \sum_{n=1}^{N}x_n^{m'}x_n^m \nonumber \\ \end{eqnarray} $$ $※ \dfrac{∂E_D}{∂ω_m} はすでに行っている計算なので結果を代入している。 $

(2.10)を用いると、(2.11)で逆行列を取っている部分の行列がヘッセ行列に一致することがわかる。

$$ H = φ^Tφ \tag{2.14} $$

この時、$M+1 \le N $、 すなわち係数の個数 M+1がトレーニングセットのデータ数N以下であれば、 ヘッセ行列は正定値であることがわかる。

正定値というのは、任意のベクトルu ($u \ne 0$)に対して、$u^THu \gt 0 $が成立することを言う。

今の例では、以下の式になる。 $$ u^TH_u = u^Tφ^Tφu = {|| φu ||}^2 > 0 $$

この不等号が成立するのは、 $φu \ne 0$の場合に限るが、$φ$の定義(2.10)を思い出すと、$φu = 0$は 要素数がM+1のベクトルuに対するN本の斉次な連立一次方程式になるので、$M+1 \le N$の場合、 自明でない解$u \ne 0$を見つけることはできない。

したがって、$u \ne 0$は必ず成立して、ヘッセ行列$φ^Tφ$は正定値となる。

さらに、正定値な行列は逆行列をもつことが証明できるので、逆行列$(φ^Tφ)^{-1}$が確かに存在し、 停留点は(2.11)で一意に決まる

そしてヘッセ行列が正定値であることから、この停留点は$E_D$の極小値を与えることが示される。 これで(2.11)は$E_D$を最小にする唯一のwを与える事が示された。

一方、$ M+1 \gt N $、すなわち、係数の個数がトレーニングセットのデータを超える場合は、 ヘッセ行列は半正定値($u^THu \ge 0$)となるため、$E_D$を最小にするwは複数存在し、一意に決定されなくなる。