合成関数の微分には、以下の公式が使えます。
どの変数で微分するかを明確にするために、以下のように表すこともできます。
さらに、y=f(x)、z=g(y)とすると、この公式は、以下のように簡潔に表せます。
合成関数はyで微分したものとxで微分したものの積になっています。これらの式は分数ではありませんが、あたかも分数であるかのように約分できるというわけです。(1)式よりも意味が分かりやすいですね。
この式は、微分が次々とつながっているように見えるので「連鎖律(chain rule)」と呼ばれることもあり、ニューラルネットワークで損失関数(正解と出力値の差を表す関数)の微分を各層に逆順にさかのぼりながら次々と伝えること(=逆伝播)を表すのに使われます。それによって、重みやバイアスの値を調整していくというわけです。
偏微分の場合も同様です。zがyの関数で、yがx1,x2,...の関数であるものとし、zをx1で偏微分するなら、以下のようになります。
合成関数は、1ページ目で見た方法で式を展開してから微分することもできますが、公式を使うと、ステップを分けて微分できます。ニューラルネットワーク(特にディープラーニング)の損失関数を逆伝播する場合のように、何段階にもわたる合成関数の場合、いちいち合成関数を全て展開してから微分するのは現実的ではありません。公式を使ってステップを分けると、途中の式が簡単になり、計算がかなりラクになります。
合成関数の微分の公式の意味を(1)式と(2)式とを例に見ておきましょう。
(1)式では、g'(y)の部分がyでの微分になることに注意してください。
(2)式の書き方だと、どの変数で微分するのかがよく分かりますね。この式をよく見ると、以下の図9のように約分と同じような計算ができることが分かります。逆に言うと、何らかの変数を使って合成関数の微分を右辺のような何段階かの微分に変形できるということです。
では、簡単な例を使って、合成関数の微分の公式の使い方に慣れておきましょう。例によって穴埋め問題にしておくので、考えながら読み進めていってください。答えはオレンジ色の部分をクリックまたはタップすれば表示できます。なお、動画でも解説しているので、ぜひ参照してみてください。
2つの関数が以下のようなものであったとします。
このとき、合成関数g(f(x))をxで微分してみましょう。
{g(f(x))}' = g'(f(x)) ⋅ f'(x)
= g'(y) ⋅ f'(x) ⋯
[A]
= ( y 2)' ⋅ (3 x 2+4)' ⋯
[B]
= 2 y ⋅ 6 x ⋯
[C]
= 2(3x2+4) ⋅ 6 x ⋯
[D]
= (6x2+8) ⋅ 6 x
= 36x3+48x
[A] ⋯
ここまでは合成関数の微分の公式
[B] ⋯
g(y)= y 2、f(x)=3 x 2+4を代入した
[C] ⋯
微分した。積の前の項はyで微分、後ろはxで微分していることに注意(後述)
[D] ⋯
y=3x2+4を代入した
念のため、式を展開してから微分した結果と見比べておきましょう。
[E] ⋯
y=3x2+4を代入した
[F] ⋯
二乗を展開した
[G] ⋯
xで微分した
当然のことながら、ちゃんと答えが一致していますね。少し話を戻しますが「'」を使った書き方だと、[B]式を微分して[C]式にするところが分かりづらかったと思います。[B]式を再掲しておきます。
この場合、y2はyで微分し、3x2+4はxで微分する必要があるわけですが、それがはっきりと分かりませんね。しかし、以下の書き方で表すと、何で微分するかを迷うことはありません。「分母」にあたる部分に書かれた変数(例えばdz/dyならy)で微分するということが分かります。
このとき、y=3x2+4、z=y2なので、以下のようになります。
[H] ⋯
右辺のzにy2を代入し、yに3x2+4を代入した
[I] ⋯
微分した
[j] ⋯
y=3x2+4を代入した
偏微分の場合も計算方法は同じです。これについても具体例で見ておきましょう。なお、こちらも動画での解説を用意してあります。ぜひ参照してみてください。
のとき、zをx1で偏微分すると以下のようになります。
∂z ∂z ∂y
―― = ―― ⋅ ――
∂x1 ∂y ∂x1
∂ ∂
= ―― y 2 ⋅ ―― (3 x1 + 2 x2 ) ⋯
[A]
∂y ∂x1
= 2y ⋅ 3 ⋯
[B]
= 2(3 x1 + 2 x2 ) ⋅ 3 ⋯
[C]
= (6x1 + 4x2) ⋅ 3
= 18x1 + 12x2
[A] ⋯
右辺のzに y 2を代入し、yに3 x1 +2 x2 を代入した
[B] ⋯
微分した
[C] ⋯
y=3 x1 + 2 x2 を代入した
念のため、合成関数の計算を行ってから偏微分して結果が一致することを確認しておきます。
はい、確かに一致していますね。合成関数の微分(連鎖律)も、微分の基本的な計算方法が分かっていればあとは代入、四則演算、べき乗といった単純な計算の積み重ねだけでできることが分かったと思います。
ところで、今回の例として示したXORの計算を行うニューラルネットワークでは、重みを表す変数が6個、バイアスを表す変数が3個あります。そのような単純な例でも、損失関数を重みやバイアスで偏微分し、逆伝播の計算を行う式を全て書くとあまりにも煩雑になってしまいます。とりあえず、同じ構造の例を図で表したものだけ掲載しておきますが、このまま計算することを考えると気が遠くなりそうですね。
損失関数とは正解の値とニューラルネットワークの出力との差(誤差)を表すような関数です。例えば、誤差を二乗した値の総和(二乗和誤差)などが損失関数として使われます。損失関数を使って重みやバイアスを調整するには、損失関数を偏微分し、学習率η(イータと読みます)を掛けて値を更新します。例えば、損失関数をLとし、w(1)11を更新するのであれば、以下のような更新式が考えられます(第7回の記事で簡単に説明した勾配降下法です)。
ここでは、損失関数が合成関数になっていることを確認しておいていただくだけで十分です。例えば、
となっており、さらに、
となっています。枠で囲んだ部分を見れば合成関数であることが分かります(XORを求めるニューラルネットワークの例で具体的に見たことも思い出してください)。これらの更新式を全ての重みとバイアスについて計算する必要があるわけです。
これらの式をもっと簡単に表す方法があれば、重みやバイアスを一気に求めることができるのですが、現時点ではまだ道具が足りていません。そういうわけで、今回は「はじめの一歩」ということで合成関数の偏微分の方法を確認するところまでにとどめておくことにします。……で、実は、その道具というのがベクトルや行列、つまり線形代数です。次回からは線形代数を学んでいき、少しずつモヤモヤを晴らしていきたいと思います。
今回は線形代数の基本として、まずはベクトルの考え方と基本的な計算(和・差・内積・ベクトルの長さ)について見ていきます。ベクトルが使えるようになると多数の値や変数をたった1つの文字で表すことができます。また、ベクトルは「類似度」を求める計算などへの応用もできます。
Copyright© Digital Advantage Corp. All Rights Reserved.