MNISTを使ってオートエンコーダーによる異常検知を試してみよう作って試そう! ディープラーニング工作室(2/3 ページ)

» 2020年09月11日 05時00分 公開
[かわさきしんじDeep Insider編集部]

全結合型のオートエンコーダーで試してみる

 ここまでの数回ではオートエンコーダーを実装するAutoEncoder2クラスは、エンコーダーとデコーダーを__init__メソッドのパラメーターに受け取るようにしていました。そのため、ここでも全結合型のオートエンコーダーのエンコーダーとデコーダーを次のように定義して、それをAutoEncoder2クラスのインスタンス生成時にそれを渡すようにしましょう。

enc = torch.nn.Sequential(
    torch.nn.Linear(1 * 28 * 28, 384),
    torch.nn.ReLU(),
    torch.nn.Linear(384, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 8),
    torch.nn.Tanh(),
)

dec = torch.nn.Sequential(
    torch.nn.Linear(8, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 384),
    torch.nn.ReLU(),
    torch.nn.Linear(384, 1 * 28 * 28),
    torch.nn.Tanh()
)

全結合型のオートエンコーダーのエンコーダーとデコーダー

 エンコーダーとデコーダーはPyTorchのLinearクラスを使って、最終的には8次元のデータに次元を削減し、それを28×28(=784)次元に復元するものです。

 これらを使って、学習を行うのが次のコードです。

net_linear = AutoEncoder2(enc, dec)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net_linear.parameters(), lr=0.5)
EPOCHS = 100

output_and_label, losses = train(net_linear, criterion, optimizer, EPOCHS, trainloader, linear=True)

全結合型のオートエンコーダーの学習

 学習が終わったら、最後の学習結果から元画像と復元画像を確認しておきましょう。

out, org = output_and_label[-1]

imshow(org.reshape(-1, 1, 28, 28))
imshow(out.reshape(-1, 1, 28, 28))

最後の学習時の元画像と復元データ

 もちろん、以下のように、復元できています。

実行結果 実行結果

 いい感じで学習ができたのを確認したところで、関数を1つ、定義しておきます。この関数は、「1」と「8」を含んだテストデータの全てを(テストローダー経由で)今作成したニューラルネットワークモデルに入力して、その結果を得るものです。ここでは、元画像、出力画像、元画像と出力画像の差分の3つを要素とするタプルを戻り値とします。差分は単純に元画像のテンソルから出力画像のテンソルを減算して、絶対値を取るだけとしました。

def test(net, testloader, linear=False):
    result = []
    diff = []
    org = []

    for item, _ in testloader:
        if linear:
            out = net(item.reshape(-1, 1 * 28 * 28))
            out = out.reshape(-1, 1, 28, 28)
        else:
            out = net(item)
        org.extend(item)
        result.extend(out)
        diff.extend(abs(item - out))
    
    return (org, result, diff)

テストを実行する関数

 PyTorchではテンソル同士を減算できるので、ここでは単に「item - out」のようにするだけで、各ピクセルの値の差分を得られていることに注意してください。

 この後に畳み込みオートエンコーダーでもこの関数を使い回せるように、全結合型のオートエンコーダーと畳み込みエンコーダーでのニューラルネットワークモデルへの入力の違いを吸収するようにしています。

 この関数を実行するには、先ほど作成したニューラルネットワークモデル(net_linear)とテストローダーを引数に指定するだけです。

(org, result, diff) = test(net_linear, testloader, linear=True)

imshow(org[0:20])
imshow(result[0:20])
imshow(diff[0:20])

test関数の呼び出し

 関数から元画像、出力画像、差分を受け取ったら、imshow関数でそれらを表示しています。その結果は次のようになります。

実行結果 実行結果
上から元画像、出力画像、差分

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。