用語「Mish関数」について説明。「0」を基点として、入力値が0以下なら出力値は「ほぼ0」だが(わずかに「負の値」になる)、0より上なら「入力値とほぼ同じ値」を返す、ニューラルネットワークの活性化関数を指す。類似するReLUやSwish関数の代替として使われる。
AI/機械学習のニューラルネットワークにおけるMish関数(ミッシュ関数)とは、関数への入力値が0以下の場合には出力値がほぼ0(※わずかに負の値になり、最小値は約-0.31)、入力値が0より上の場合には出力値が入力値とほぼ同じ値(最大値は∞)となる関数である。
図1を見ると分かるように、ReLUという活性化関数に似ている曲線を描く(=ReLUをMish関数に置き換えやすい)が、その曲線が連続的で滑らか(smooth)かつ非単調(non-monotonic)である点が異なる。基点として(0, 0)を通るが、滑らかであるため、いったん下側(マナイス側)に少し膨らむ点が特徴的だ。この特徴も含めて、Swish関数に酷似しており、Swish関数と比較するとさらにReLUに近い曲線を描くという特徴がある。
現在のディープニューラルネットワークでは、隠れ層(中間層)の活性化関数としては、ReLUを使うのが一般的である。しかし、より良い結果を求めて、ReLU以外にもさまざまな代替の活性化関数が考案されてきている。その中でも最近(2019年後半〜執筆時点=2020年4月時点で)、一部でやや注目されているのが(本稿で解説する)Mish関数である。
まずは、ReLUを試した後、より良い精度を求めてMish関数に置き換えて検証してみる、といった使い方が考えられる。2019年の論文「arXiv:1908.08681v2 [cs.LG]」や「arXiv:1908.08681 [cs.LG]」によると、ReLUやSwish関数よりも良い精度になりやすい、とのことである。その導関数(Derivative function:微分係数の関数)は図2のようになり、ReLUよりも滑らかに変化し、Swish関数よりもさらにReLUに近い曲線を描く、という特徴がある。
この特徴により、ReLUだと0.0と1.0の境界でカクカクとした学習結果となってしまうが、Mish関数ではその問題がなく滑らかな学習結果になる。
主要ライブラリには、執筆時点(2020年4月時点)では標準搭載されていない(※後述の定義と数式を見ると分かるが、数式が他の活性化関数と比べてかなり複雑なので、ライブラリに標準搭載される可能性は低いかもしれない、と筆者は考えている)。Mish関数をTensorFlow/KerasやPyTorchで利用したい場合は、考案者であるDiganta Misra氏のGitHubリポジトリにあるカスタム実装を参考にするとよい。具体的には下記リンク先を参照してほしい。
冒頭では文章により説明したが、厳密に数式で表現しておこう。
Mish関数は、ソフトプラス(Softplus)関数とTanh関数を内部で用いる。ここでは、ソフトプラス関数をσ(x)(=xを入力とする、「シグマ」という名前の関数)、Tanh関数をτ(x)(=xを入力とする、「タウ」という名前の関数)と表現する。数式は以下の通りである。
e(オイラー数=ネイピア数)や、それに対応するnp.exp(x)という後述のコードについては、「シグモイド関数」で説明しているので、詳しくはそちらを参照してほしい。
このσ(x)とτ(x)を用いてMish関数の数式を表現すると、次のようになる。
上記の数式の導関数を求めると、複雑な式となってしまうため、式を次のように3つの式に分割した(※各式の記述は、論文「arXiv:1908.08681v2 [cs.LG]」に倣った)。1つ目はω(x)(=xを入力とする、「オメガ」という名前の関数)で、2つ目はδ(x)(=xを入力とする、「デルタ」という名前の関数)である。その2つの関数を用いて記述した3つ目の関数が、導関数f'(x)である。
上記のMish 関数の数式をPythonコードの関数にするとリスト1のようになる。なお、オイラー数に関して計算するため、ライブラリ「NumPy」をインポートして使った。
import numpy as np
def softplus(x):
return np.log(1.0 + np.exp(x))
def tanh(x):
return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))
def mish(x):
return x * tanh(softplus(x))
Mish 関数の導関数(derivative function)のPythonコードも示しておくと、リスト2のようになる。
import numpy as np
def omega(x):
return 4*(x+1) + 4*np.exp(2*x) + np.exp(3*x) + np.exp(x)*(4*x+6)
def delta(x):
return 2*np.exp(x) + np.exp(2*x) + 2
def der_mish(x):
return np.exp(x) * omega(x) / delta(x)**2
Copyright© Digital Advantage Corp. All Rights Reserved.