オートエンコーダーの活用例の一つである異常検知を、MNISTの手書き数字を例に体験してみましょう。
この記事は会員限定です。会員登録(無料)すると全てご覧いただけます。
前回までに全結合型のオートエンコーダー、CNNを利用した畳み込みオートエンコーダー、それから学習を高速に行えるようにGPUを使用する方法などを見てきました。
ところで、オートエンコーダーが持つエンコーダーで元画像を縮小(次元削減)して、それをデコーダーで復元することはどんなことで役に立つのでしょうか。一般には異常検知やノイズ削減などがオートエンコーダーの用途として挙げられています。そこで、今回は、MNISTの手書き文字を例に異常検知とはどんなものなのか、その本当に表面的な部分だけを見てみることにします。
「異常検知」とは、数多くのデータの中から、他のデータとは異なる特徴を持つデータを見つけ出すことです。例えば、工場で生産されたネジを考えてみましょう。製造されたネジの多くは何の問題もなく、製品として出荷できるものです。しかし、その中のごく一部には、傷があったり、先端が曲がっていたりと、製品としては出荷できないものが含まれているかもしれません。職人的な技術を持つ人の目を介せば、そのような異常を持つものも簡単に見つけられるかもしれませんが、異常検知を使うことで、それらを機械的に抜き出せるようになるでしょう。このためにオートエンコーダーが使えるということです。
皆さんは「一流の鑑定人は生まれたときからホンモノだけを目にしてきているので、ニセモノを見たときにはそれがすぐに分かる」のような意味の言葉を聞いたことがありませんか。筆者はマンガや小説で、そうしたセリフを登場人物が話すところを何度か読んだことがあります。つまり、「正常なモノだけを学習させていれば、異常なモノもすぐに分かる」ようになるというわけです。
マンガや小説を根拠にするなと怒られそうですが、オートエンコーダーによる異常検知はこれに近いものがあります。つまり、オートエンコーダーに正常なデータだけを入力して学習を行えば、正常なデータが入力されれば、それは問題なく復元できるけど、異常なデータが入力されたときには、うまく復元できないものになるだろうというのが基本の考え方といえるでしょう。
とはいえ、実際にはネジに付いた少しのキズを検知するようなニューラルネットワークモデルを作るのは、そうそう簡単なことではありません。この原稿を書く前に「MVTec Anomaly Detection Dataset」と呼ばれるデータセットを使って試していたのですが、そうしたニューラルネットワークとはならなかったため、ここでは基本に立ち返ってMNISTを例として、異常検知がどんなものかを見てみることになったのでした。
ここでは以下のような実験を行ってみます。
これを全結合型のオートエンコーダーと畳み込みオートエンコーダーで試してみます。MNISTなら全結合型のオートエンコーダーでも十分だろうという考えと、より高い精度で画像の復元ができる畳み込みエンコーダーではどうなるかを見てみようと考えたからです。
なぜ、「1」を正常なデータの例として、「8」を異常なデータの例としたかというと、見た目にこれらは大きく異なっているからです。つまり、正常と異常の差を大きく取ることで、異常なデータを(簡素なニューラルネットワークモデルでも)検出しやすくしようと考えたというわけです。
いつものコード(必要なモジュールのインポート、imshow関数/AutoEncoderクラス/train関数の定義など)は本稿末尾にまとめます。また、今回のコードはこのノートブックで公開しているので、必要に応じて参考にしてください。
それでは学習のための準備といきましょう。
既に述べた通り、今回用意するのは、次の2つのデータセットです。
PyTorchが標準で提供しているMNISTクラスは、データセットのダウンロードも行ってくれる便利なクラスですが、その内部には「0」〜「9」の全ての文字が含まれてしまいます。そのままでは使えないので、ここではMNISTクラスを継承する次のようなクラスを用意しました。
class MNISTNumed(MNIST):
def __init__(self, nums=[1], *args, **kwargs):
super().__init__(*args, **kwargs)
tmp = []
for n in nums:
subdata= [d for d, t in zip(self.data, self.targets) if t == n]
tmp.extend(subdata)
self.data = tmp
#self.data = [d for d, t in zip(self.data, self.targets) for n in nums if t == n]
数字を指定するということから「MNISTNumed」という名前にしましたが、あまりよい名前付けではないかもしれません。それはともかく、このクラスは__init__メソッド内で、パラメーターnumsに指定された数値を基に、self.data(学習時などにデータセットが提供するデータ)の内容をフィルタリングするだけで、後は基底クラスであるMNISTに全てを任せているだけです(forwardメソッドすら定義していませんね)。
パラメーターnumsにはリストとして、データセットに含めたい数字(数値)を指定してもらうことを前提としています。そして、for文とリスト内包表記を使って、指定された数字だけがself.dataに含まれるようにしています。コメント行はリスト内包表記だけで、同じことをするコードを例として書いてあります。興味のある方はそちらのコードも試してみてください。
このクラスを使って、訓練セットとテストセット、それらのデータローダーを定義するのが次のコードです。
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5))])
trainset = MNISTNumed(root='./data', nums=[1], train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=16, shuffle=True)
testset = MNISTNumed(root='./data', nums=[1, 8], train=False, transform=transform, download=True)
testloader = DataLoader(testset, batch_size=16, shuffle=True)
MNISTクラスではなく、MNISTNumedクラスを使うところ以外は、これまでと同様なので、説明はしなくてもよいでしょう。
実際に「1」だけ、または「1」と「8」だけが含まれているかも確認しておきましょう。
iterator = iter(trainloader) # 訓練データ(「1」のみを含む)
img, _ = next(iterator)
imshow(img)
iterator = iter(testloader) # テストデータ(「1」と「8」を含む)
img, _ = next(iterator)
imshow(img)
実行結果は以下の通りです。
訓練データには「1」だけが、テストデータには「1」と「8」が含まれていることが確認できました(なお、「1」と「8」がランダムに登場するように、今回はテストローダーでは読み込み順がシャッフルされるようにしてあります)。
準備はこれで完了です。次ページでは、全結合型のオートエンコーダーで異常を検知できるかどうか、試してみることにします。
Copyright© Digital Advantage Corp. All Rights Reserved.