検索
連載

[AI・機械学習の数学]合成関数の微分(連鎖律)とニューラルネットワーク初歩の初歩AI・機械学習の数学入門(2/2 ページ)

微分法は回帰分析だけでなく、機械学習のさまざまなタスクで使われる。特に、合成関数の微分(連鎖律)はニューラルネットワークの学習において必須となる。今回はそのための第一歩として、合成関数がどのようなものであるかを見た後、合成関数の微分法の公式とその計算方法を紹介する。

Share
Tweet
LINE
Hatena
前のページへ |       

目標【その2】: 合成関数の微分

 合成関数の微分には、以下の公式が使えます。

 どの変数で微分するかを明確にするために、以下のように表すこともできます。

 さらに、y=f(x)z=g(y)とすると、この公式は、以下のように簡潔に表せます。

 合成関数はyで微分したものとxで微分したものの積になっています。これらの式は分数ではありませんが、あたかも分数であるかのように約分できるというわけです。(1)式よりも意味が分かりやすいですね。

 この式は、微分が次々とつながっているように見えるので「連鎖律(chain rule)」と呼ばれることもあり、ニューラルネットワークで損失関数(正解と出力値の差を表す関数)の微分を各層に逆順にさかのぼりながら次々と伝えること(=逆伝播)を表すのに使われます。それによって、重みやバイアスの値を調整していくというわけです。

 偏微分の場合も同様です。zyの関数で、yx1,x2,...の関数であるものとし、zx1で偏微分するなら、以下のようになります。

解説【その2】: 合成関数の微分(連鎖律)

 合成関数は、1ページ目で見た方法で式を展開してから微分することもできますが、公式を使うと、ステップを分けて微分できます。ニューラルネットワーク(特にディープラーニング)の損失関数を逆伝播する場合のように、何段階にもわたる合成関数の場合、いちいち合成関数を全て展開してから微分するのは現実的ではありません。公式を使ってステップを分けると、途中の式が簡単になり、計算がかなりラクになります。

 合成関数の微分の公式の意味を(1)式と(2)式とを例に見ておきましょう。

合成関数の微分の公式(1)
図7 合成関数の微分の公式(1)

 (1)式では、g'(y)の部分がyでの微分になることに注意してください。

合成関数の微分の公式(2)
図8 合成関数の微分の公式(2)

 (2)式の書き方だと、どの変数で微分するのかがよく分かりますね。この式をよく見ると、以下の図9のように約分と同じような計算ができることが分かります。逆に言うと、何らかの変数を使って合成関数の微分を右辺のような何段階かの微分に変形できるということです。

合成関数は段階を分けて微分できる
図9 合成関数は段階を分けて微分できる

 では、簡単な例を使って、合成関数の微分の公式の使い方に慣れておきましょう。例によって穴埋め問題にしておくので、考えながら読み進めていってください。答えはオレンジ色の部分をクリックまたはタップすれば表示できます。なお、動画でも解説しているので、ぜひ参照してみてください。

合成関数の微分の例1

動画2 合成関数の微分の例1


 2つの関数が以下のようなものであったとします。

 このとき、合成関数g(f(x))xで微分してみましょう。

  {g(f(x))}' = g'(f(x)) ⋅ f'(x)
       = g'(y) ⋅ f'(x)    ⋯ [A]
       = ( y 2)' ⋅ (3 x 2+4)'    ⋯ [B]
       =  2 y ⋅  6 x    ⋯ [C]
       = 2(3x2+4) ⋅  6 x    ⋯ [D]
       = (6x2+8) ⋅  6 x
       = 36x3+48x

[A] ⋯ ここまでは合成関数の微分の公式
[B]g(y)= y 2f(x)=3 x 2+4を代入した
[C] ⋯ 微分した。積の前の項はyで微分、後ろはxで微分していることに注意(後述)
[D]y=3x2+4を代入した

 念のため、式を展開してから微分した結果と見比べておきましょう。

[E]y=3x2+4を代入した
[F] ⋯ 二乗を展開した
[G]xで微分した

 当然のことながら、ちゃんと答えが一致していますね。少し話を戻しますが「'」を使った書き方だと、[B]式を微分して[C]式にするところが分かりづらかったと思います。[B]式を再掲しておきます。

 この場合、y2yで微分し、3x2+4xで微分する必要があるわけですが、それがはっきりと分かりませんね。しかし、以下の書き方で表すと、何で微分するかを迷うことはありません。「分母」にあたる部分に書かれた変数(例えばdz/dyならy)で微分するということが分かります。

 このとき、y=3x2+4z=y2なので、以下のようになります。

[H] ⋯ 右辺のzy2を代入し、y3x2+4を代入した
[I] ⋯ 微分した
[j]y=3x2+4を代入した

合成関数の微分の例2(偏微分の場合)

 偏微分の場合も計算方法は同じです。これについても具体例で見ておきましょう。なお、こちらも動画での解説を用意してあります。ぜひ参照してみてください。

動画3 合成関数の微分の例2(偏微分の場合)


のとき、zx1で偏微分すると以下のようになります。

   ∂z   ∂z   ∂y
  ―― = ―― ⋅ ――
   ∂x1  ∂y  ∂x1
        ∂      ∂ 
     = ――  y 2 ⋅ ―― (3 x1  + 2 x2 )    ⋯ [A]
       ∂y      ∂x1
     = 2y ⋅  3     ⋯ [B]
     = 2(3 x1  + 2 x2 ) ⋅  3     ⋯ [C]
     = (6x1 + 4x2) ⋅  3 
     = 18x1 + 12x2

[A] ⋯ 右辺のz y 2を代入し、y3 x1 +2 x2 を代入した
[B] ⋯ 微分した
[C]y=3 x1  + 2 x2 を代入した

 念のため、合成関数の計算を行ってから偏微分して結果が一致することを確認しておきます。

 はい、確かに一致していますね。合成関数の微分(連鎖律)も、微分の基本的な計算方法が分かっていればあとは代入、四則演算、べき乗といった単純な計算の積み重ねだけでできることが分かったと思います。

 ところで、今回の例として示したXORの計算を行うニューラルネットワークでは、重みを表す変数が6個、バイアスを表す変数が3個あります。そのような単純な例でも、損失関数を重みやバイアスで偏微分し、逆伝播の計算を行う式を全て書くとあまりにも煩雑になってしまいます。とりあえず、同じ構造の例を図で表したものだけ掲載しておきますが、このまま計算することを考えると気が遠くなりそうですね。

ニューラルネットワークと損失関数の例
図10 ニューラルネットワークと損失関数の例
XORを求めるニューラルネットワークと構造は同じだが、重みなどを文字で表してある(変数名も少し変えてある)。変数名の右肩の(1)(2)(3)は何層目であるかを表すために付けてある。

 損失関数とは正解の値とニューラルネットワークの出力との差(誤差)を表すような関数です。例えば、誤差を二乗した値の総和(二乗和誤差)などが損失関数として使われます。損失関数を使って重みやバイアスを調整するには、損失関数を偏微分し、学習率η(イータと読みます)を掛けて値を更新します。例えば、損失関数をLとし、w(1)11を更新するのであれば、以下のような更新式が考えられます(第7回の記事で簡単に説明した勾配降下法です)。

 ここでは、損失関数が合成関数になっていることを確認しておいていただくだけで十分です。例えば、

となっており、さらに、

となっています。枠で囲んだ部分を見れば合成関数であることが分かります(XORを求めるニューラルネットワークの例で具体的に見たことも思い出してください)。これらの更新式を全ての重みとバイアスについて計算する必要があるわけです。

 これらの式をもっと簡単に表す方法があれば、重みやバイアスを一気に求めることができるのですが、現時点ではまだ道具が足りていません。そういうわけで、今回は「はじめの一歩」ということで合成関数の偏微分の方法を確認するところまでにとどめておくことにします。……で、実は、その道具というのがベクトルや行列、つまり線形代数です。次回からは線形代数を学んでいき、少しずつモヤモヤを晴らしていきたいと思います。

次回は……

 今回は線形代数の基本として、まずはベクトルの考え方と基本的な計算(和・差・内積・ベクトルの長さ)について見ていきます。ベクトルが使えるようになると多数の値や変数をたった1つの文字で表すことができます。また、ベクトルは「類似度」を求める計算などへの応用もできます。

「AI・機械学習の数学入門」のインデックス

AI・機械学習の数学入門

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る