検索
連載

PyTorchでオートエンコーダーによる画像生成をしてみよう作って試そう! ディープラーニング工作室(3/3 ページ)

画像生成の手始めとして「オートエンコーダー」と呼ばれるニューラルネットワークを作って、MNISTの手書き数字を入力して、復元してみましょう。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
前のページへ |       

オートエンコーダーがしていることって?

 ここで784次元だったデータを2次元に次元削減しました。そこで2次元に削減されたデータを少しのぞいてみましょう。これには、Autoencoderクラスのインスタンスであるnetの属性encを使えます。net.enc属性はEncoderクラスのインスタンスで、forwardメソッドを持っていますから、これにデータセットからデータを入力すれば、次元削減されたデータが取り出せるということです。

sampleloader = DataLoader(testset, batch_size=1500)
iterator = iter(sampleloader)
img, label = next(iterator)
z = net.enc(img.reshape(-1, 28 * 28))
z = z.detach().numpy()  # 後から簡単に使えるようにするための処理
print(z.shape)  # (1500, 2)

テストデータから先頭1500個のデータをエンコーダーに入力

 変数zには、テストデータから先頭1500個のデータをエンコーダー(net.enc)に入力して、次元削減を行った結果が代入されます。実行結果は省略しますが、コメントに残した通り、これは1500行2列のテンソルになります。

 ここで変数labelには、MNISTの正解ラベルが代入されていることを思い出してください(DataLoaderクラスがそうした処理をしてくれます)。つまり、変数zに得た1500個の2次元のデータが数字の0〜9のどれに対応しているかが変数labelに記録されているということです。

 そこで、これらを正解ラベルごとに集合にまとめてみることにしました。

set_list = [set() for x in range(10)]
for coord, lbl in zip(z.tolist(), label):
    set_list[lbl].add(tuple(coord))

正解ラベルの値で、変数zの値を分類して、集合にまとめるコード

 「set_list[正解ラベルの値]」には、正解ラベルが0の2次元データ、正解ラベルが1の2次元データ、……、というようにデータがまとめられました。具体的なデータを少し見てみましょう。

for idx in range(10):
    print(f'items in set_list[{idx}]:')
    for cnt, item in enumerate(set_list[idx]):
        print(item)
        if cnt > 5:
            break

各集合から要素を7つずつ取り出して表示してみるコード

 実行結果を以下に示します。

実行結果
実行結果

 さすがに人が目視でこれらのデータを見ても、その傾向はつかめそうにありません(しかも、ここで表示しているのはほんの7個ずつのデータですから、それはそういうものです)。

 ところで、Encoderクラスは元のデータを2次元に削減するものでした。そして、2次元のデータであれば、グラフにプロットできるはずです。そこで、次元削減後のデータをグラフにプロットしてみた方が目視するよりも全然よいような気がします。ここでは、以下のコードで2次元に削減されたデータを散布図としてプロットしてみましょう。

colorlist = ["r", "g", "b", "c", "k", "y", "orange", "lightgreen", "hotpink", "yellow"]
plt.figure(figsize=(10, 10))
for idx in range(10):
    for x, y in set_list[idx]:
        plt.scatter(x, y, c=colorlist[idx])
description = [f"{idx}: {colorlist[idx]}" for idx in range(10)]
print(description)

正解ラベルごとに色を付け、2次元データを座標としてグラフにプロットする

 実際には、これよりも簡単で高速に散布図をプロットできるのですが(「plt.scatter(z[:, 0], z[:, 1], c=label, ……)」のようにできるでしょう)、ここではせっかく正解ラベルごとにデータをまとめたので、for文を二重に使った遅いグラフ描画を選んでいます(ノートブックには上で示した方法でグラフをプロットするコードも掲載してあるので、興味のある方は見てみてください)。それはともかく、実行結果を以下に示します。

実行結果
実行結果

 各色のドットがある範囲にまとまって分布していると、このグラフからはいえるでしょう。つまり、オートエンコーダーとは、多数の次元を持つデータをより少ない次元の空間へとマッピングするものだと考えられます(ここでは2次元でしたが、もちろんもっと高次元の空間にマッピングすることもあるでしょう)。「次元削減後の空間」で元データがどんな分布になるのかが、手書き数字で何が描かれているのか、その特徴を表したものといえるかもしれません。

 そして、エンコーダーにあるニューラルネットワークの各層では、元データの次元を削減したものが取る分布を決定するように重みとバイアスが学習によって調整され、デコーダーにあるニューラルネットワークの各層では、エンコーダーが算出した特徴を基に画像を復元できるように重みとバイアスが学習によって調整されるのだと考えられます。

 ここで、黒いドット(4)とイエローのドット(9)に注目してみてください。先ほどは文字「4」と「9」は、復元時にお互いの文字に間違えられることがあるといったことを述べましたが、このグラフでは黒と黄色のドットは似た位置に描かれています(さらにいえば、文字「7」を表すライトグリーンのドットはそれらの範囲を含んでより広い範囲に分布しています。文字「7」も「4」と間違って復元されることがあったというのが筆者の印象で、この分布はこの印象がある程度は合っていることを示唆しているといえるでしょう)。

 これはきっと、784次元のデータを2次元まで次元削減したときに、「手書き数字の特徴を表す」これら2次元のデータが近い値になったことを意味しています。そのため、これら2次元のデータから784次元のデータへとデコード(復元)する際には、ちょっとしたことで、間違った数字になってしまうのだと考えられます。

 となると、文字の復元をより精度高くしたいのであれば、分布が重ならないようにすればよさそうです。そのためには、例えば、最終的な削減後の次元を2次元よりも大きくするといったことが考えられるでしょう(ここでは、そこまではしませんが、後でこっそり試してみる予定です)。

 なお、このように元のデータを圧縮した後に得られるデータのことを「潜在変数」と呼びます。潜在変数とは、入力されたデータから、そのデータを復元するために不要なデータを取り除いたもの、いわば「手書き数字の特徴を凝縮した」データだといえます。

 ここまでオートエンコーダーがどんなことをしているのかを少し考えました。最後に、MNISTと同様なデータセットであるCIFAR-10を使って、上で作成したAutoencoderクラスのインスタンスを学習させてみたいと思います。どうなるでしょうか。

CIFAR-10の画像を使って学習してみた

 ここからは早足でいきます。といっても、コード量はごくわずかで、これまでのコードとほぼ同様です。まずはCIFAR-10のデータセットの読み込みとデータローダーのセットアップです。

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10('./data', train=True, transform=transform, download=True)
testset = CIFAR10('./data', train=False, transform=transform, download=True)

batch_size = 50
trainloader = DataLoader(trainset, batch_size=50, shuffle=True)
testloader = DataLoader(testset, batch_size=50, shuffle=False)

CIFAR-10データセットの読み込みとデータローダーの設定

 CIFAR-10のデータセットは、MNISTの白黒画像とは異なり、カラー画像で、画素数も32×32ピクセルになります。MNISTは「1×28×28」というサイズの画像でしたが、CIFAR-10は「3×32×32」というサイズの画像となっています。幾つかの画像を表示してみましょう。

iterator = iter(trainloader)
img, _ = next(iterator)
imshow(img)

CIFAR-10の画像を50個表示するコード

 実行結果を以下に示します。

実行結果
実行結果

 これは明らかにMNISTよりもはるかに複雑な画像です。果たして、上で作ったAutoencoderクラスはうまくこれらを処理できるのでしょうか。というわけで、学習してみましょう。

input_size = 3 * 32 * 32
net2 = AutoEncoder(input_size)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net2.parameters(), lr=0.1)
EPOCHS = 100

output_and_label2, losses2 = train(net2, criterion, optimizer, EPOCHS, trainloader)

CIFAR-10を使ってAutoencoderクラスのインスタンスを学習させる

 上でも述べましたが、画像サイズが違うので、変数input_sizeの値が「28 * 28」から「3 * 32 * 32」に変わったこと以外は、基本的には上と変わりません。これを実行すると次のようになります。

実行結果
実行結果

 出力の最初の数行を見た感じでは、損失の減少の度合いが思ったよりもよくなさそうですが、一応、学習は進んでいるようです。そこで、学習が終わった後に最後の学習で得られた出力とその元画像をこれまでと同様にimshow関数で表示してみましょう。

output, org = output_and_label2[-1]
imshow(org.reshape(-1, 3, 32, 32))
imshow(output.reshape(-1, 3, 32, 32))

最後の学習で得られた出力とその元データを表示してみる

 実行結果を以下に示します。

実行結果
実行結果

 これはひどい! 元画像を少しでも復元できているのであれば、まだよいのですが、全く意味が分からない画像となってしまいました。というわけで、次回はCIFAR-10の画像を復元できるようなオートエンコーダーを作ってみることにしましょう。

「作って試そう! ディープラーニング工作室」のインデックス

作って試そう! ディープラーニング工作室

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る