検索
連載

MNISTを使ってオートエンコーダーによる異常検知を試してみよう作って試そう! ディープラーニング工作室(3/3 ページ)

オートエンコーダーの活用例の一つである異常検知を、MNISTの手書き数字を例に体験してみましょう。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
前のページへ |       

畳み込みオートエンコーダーで試してみる

 次に、全結合型のオートエンコーダーよりも(多分)よい精度で画像を復元できるであろう畳み込みオートエンコーダーで試してみましょう。

 エンコーダーとデコーダーは次のようにしました。

enc = torch.nn.Sequential(
    torch.nn.Conv2d(1, 3, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(3, 5, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2)
)

dec = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(5, 3, kernel_size=2, stride=2),
    torch.nn.ReLU(),
    torch.nn.ConvTranspose2d(3, 1, kernel_size=2, stride=2),
    torch.nn.Tanh()
)

畳み込みオートエンコーダーのエンコーダーとデコーダー

 扱うものがMNISTということもあって、ここではチャネル数もカーネルサイズも比較的小さめにしてあります。

 このエンコーダーとデコーダーと使って、学習を行い、その後でtest関数を呼び出した結果を以下に示します(コードについては公開しているノートブックを参考にしてください。といっても、先ほどのコードとほとんど変わらないものです)。

元画像、出力画像、差分画像の表示結果
元画像、出力画像、差分画像の表示結果

 元画像と出力画像(復元画像)を見ると、思いの外、「8」の方もうまく復元できているように見えます。とはいえ、差分画像を見ると、やはり「1」についてはそれほど白い部分は多くなく(差分が小さい)、「8」については白い部分が多い(差分が大きい)という傾向は先ほどと同様といえるでしょう。

 ということは、「1」と「8」の各画像について、各ピクセルの差分の和を取って、そこから総和を得て、「1」と「8」の画像グループごとに差分の平均値を得ることで、正常(1)か異常(8)かの判断は可能なようです。

 最後に、チャネル数を増やした畳み込みオートエンコーダーについても簡単に見てみましょう。

enc = torch.nn.Sequential(
    torch.nn.Conv2d(1, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(8, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2)
)

dec = torch.nn.Sequential(
    torch.nn.ConvTranspose2d(16, 8, kernel_size=2, stride=2),
    torch.nn.ReLU(),
    torch.nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2),
    torch.nn.Tanh()
)

チャネル数を増やした畳み込みオートエンコーダーのエンコーダーとデコーダー

 こちらのエンコーダーとデコーダーを使って、学習を行い、上と同様にtest関数を呼び出した結果が以下です。

元画像、出力画像、差分画像の表示結果
元画像、出力画像、差分画像の表示結果

 学習していないハズの「8」の復元度合いはさらによくなっています。しかし、実は「1」の復元度合いも高まっているようです。それは差分画像を見ると分かります。「1」の方の差分は枠線がかすかに見えるだけでくらいになりました。「8」の方もずいぶんと白い部分は減りましたが、やはり「1」に比べると差分が大きい傾向に変わりはありません。しきい値を調整することで、これらを区別することはできるのではないでしょうか。

 「生まれたときから一流のモノしか目にしてないから、よくできた贋作(がんさく)でもすぐに分かる」というレベルにはほど遠いのでしょうが、それでもそれなりの観察眼は持っているといえるかもしれません。

 とはいえ、今回はMNISTの手書き数字、しかも形状に大きな差があるであろう「1」と「8」を例としたので、それなりの結果が出たと考えられます。冒頭でも述べましたが、実際にはネジに付いた小さなキズや、絨毯の少しのほつれやシミなどを見つけられるほどの精度が必要になるでしょう。

 MNISTの特定の画像を正常、それ以外を異常とする異常検知と、ほんとうに実践的な異常検知との間には大きな差があることは確かです。とはいえ、その雰囲気はなんとなくつかめたのではないでしょうか。

 ここまで数回に分けて、オートエンコーダーを話題としてきましたが、次回からはまた別のことを模索していく予定です。

いつものコード

 以下はこれまでの回と同様のコードです。ただし、今回は全結合型のオートエンコーダーと畳み込みオートエンコーダーが登場しました。このとき、ニューラルネットワークモデルへの入力形式が異なるため(28×28の2次元画像を全結合型では一次元のデータにする必要があります)、それに対応するパラメーターを持たせて、内部で処理を分けるようにしました。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

def imshow(img):
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5
    npimg = img.detach().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

class AutoEncoder2(torch.nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        return x

def train(net, criterion, optimizer, epochs, trainloader, linear=False):
    losses = []
    output_and_label = []

    for epoch in range(1, epochs+1):
        print(f'epoch: {epoch}, ', end='')
        running_loss = 0.0
        for counter, (img, _) in enumerate(trainloader, 1):
            optimizer.zero_grad()
            if linear:
                img = img.reshape(-1, 1 * 28 * 28)
            output = net(img)
            loss = criterion(output, img)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / counter
        losses.append(avg_loss)
        print('loss:', avg_loss)
        output_and_label.append((output, img))
    print('finished')
    return output_and_label, losses

いつものコードの皆さん


「作って試そう! ディープラーニング工作室」のインデックス

作って試そう! ディープラーニング工作室

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る