CIFAR-10の画像データのエンコード/デコードをうまく行うために、圧縮後の次元数とエポック数を変化させながら学習させてみましょう。
前回はMNISTの手書き数字を入力すると、その情報(元は28×28=784次元のデータ)を段階的に次元削減していきながら、2次元の情報とした後で、それを今度は784次元のデータへと復元する「オートエンコーダー」を作成してみました。その中で、オートエンコーダーのエンコーダー部では、元のデータをより次元数の少ない空間(前回は最終的には2次元空間)へとマッピングすること、デコーダー部ではマッピングされたデータを基に元の情報へと復元することを見てきました。そして、エンコーダー部で圧縮されたデータは「元データを復元するために不要な情報を削除したもの」「潜在変数と呼ばれる」であるといったことも見ました。
最後に、MNISTの手書き数字ではそれなりの結果を出した「全結合型のオートエンコーダー」を使って、CIFAR-10と呼ばれるMNISTよりも情報量が多い画像データのエンコード/デコードがうまくいくかを試してみたところ、見事に失敗をしたのでした。
そこで、今回は幾つかの手段で、前回はダメだったCIFAR-10の画像データのエンコード/デコードが成功するように実験してみましょう。
どうすればCIFAR-10の画像データをうまくエンコード/デコードできるかについてはいろいろと考える点がありそうです。まず考えておきたいのは、元画像のデータ量の違いについてです。MNISTとCIFAR-10の画像データの大ざっぱなフォーマットは次のようになっています。
MNISTの画像1枚当たりのデータ量は28×28×1バイト=784バイトです。一方、CIFAR-10の画像1枚当たりのデータ量は32×32×3バイト=3072バイトです。つまり、CIFAR-10とMNISTとを比べると、画像1枚当たり、約4倍の情報量があるということです。
それから、MNISTの手書き数字はグレイスケールで背景部分も多かったことも思い出してください。これに対して、CIFAR-10は32×32ピクセルのあらゆる部分に情報が詰まっている点も大きな違いといえそうです。
上の画像を見れば、MNISTのようなデータなら2次元のデータまで圧縮してもうまくいきそうですが、CIFAR-10の画像データを2次元まで圧縮するというのはさすがに無理がありそうです。
というわけで、まずはどこまで次元を削減すればよいかが問題となりそうです。
それから学習をどこまで進めるかという話もあります。簡単にいえば、「圧縮後のデータの次元数が実は十分かもしれないけれど、学習が足りていないから、うまくエンコード/デコードできない」という事態に陥る可能性もありそうです。MNISTの手書き数字を処理するオートエンコーダーではエポック数を100としていましたが、CIFAR-10ではより多くのエポックを学習する必要があるかもしれません。
以上の2点について考えながら、どうにかCIFAR-10の画像データを処理できるオートエンコーダーを完成させることにしましょう。
今回も以下のような関数やクラスを定義します。
この他にも必要なパッケージやモジュールのインポート、データセット読み込みとデータローダーの定義を行うコードなどがあります。これらのコードとimshow関数とtrain関数については以下にまとめて示します(train関数は前回から些細な修正点がありますが、これについては説明を省略します)。また、今回のコードはこのノートブックで公開しているので必要に応じて参照してください。
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import numpy as np
import matplotlib.pyplot as plt
def imshow(img):
img = torchvision.utils.make_grid(img)
img = img / 2 + 0.5
npimg = img.detach().numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
def train(net, criterion, optimizer, epochs, trainloader, input_size):
losses = []
output_and_label = []
for epoch in range(1, epochs+1):
print(f'epoch: {epoch}, ', end='')
running_loss = 0.0
for counter, (img, _) in enumerate(trainloader, 1):
optimizer.zero_grad()
img = img.reshape(-1, input_size)
output = net(img)
loss = criterion(output, img)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / counter
losses.append(avg_loss)
print('loss:', avg_loss)
output_and_label.append((output, img))
print('finished')
return output_and_label, losses
次にオートエンコーダーを実装するAutoEncoder2クラスを以下に示します。
class AutoEncoder2(torch.nn.Module):
def __init__(self, enc, dec):
super().__init__()
self.enc = enc
self.dec = dec
def forward(self, x):
x = self.enc(x)
x = self.dec(x)
return x
前回は__init__メソッド内で、別途定義したEncoderクラスとDecoderクラスのインスタンスを生成していましたが、今回はインスタンス生成時にパラメーターencにエンコーダーとなるオブジェクトを、パラメーターdecにデコーダーとなるオブジェクトを受け取り、それらをインスタンス変数に保存して、forwardメソッドではそれらを呼び出すだけとしています。こうすれば、外部でエンコーダーとデコーダーのインスタンスを自由に生成して、それらをオートエンコーダーに渡せるようになります。
エンコーダーやデコーダーは例えば、次のようにPyTorchのSequentialクラス(torch.nn.Sequentialクラス)を利用して作成することにしました。Sequentialクラスは、そのインスタンス生成時に引数に指定したニューラルネットワークモジュールを格納するコンテナとなります(いってみれば、独自のクラスを定義して、その__init__メソッドで必要なものをインスタンス化するのではなく、その場その場で必要なものを使って、エンコーダーとデコーダーを組み立てる感じです)。
input_size = 3 * 32 * 32
encoder = torch.nn.Sequential(
torch.nn.Linear(input_size, input_size // 4),
torch.nn.ReLU(),
torch.nn.Linear(input_size // 4, input_size // 12),
torch.nn.ReLU(),
torch.nn.Linear(input_size // 12, input_size // 24)
)
decoder = torch.nn.Sequential(
torch.nn.Linear(input_size // 24, input_size // 12),
torch.nn.ReLU(),
torch.nn.Linear(input_size // 12, input_size // 4),
torch.nn.ReLU(),
torch.nn.Linear(input_size // 4, input_size),
torch.nn.Tanh()
)
net = AutoEncoder2(encoder, decoder)
変数input_sizeには上で述べたCIFAR-10の画像データのサイズ(次元数)を代入しています。その後がエンコーダーとデコーダーとなるオブジェクトをSequentialクラスを使って生成しているコードです。encoderオブジェクトには、input_sizeの4分の1、12分の1、24分の1という具合に段階を踏んで次元を削減しています。decoderはその逆を行っています(活性化関数のインスタンスもSequentialクラスのインスタンスには含まれている点にも注意してください。こうすることで、AutoEncoder2クラスのforwardメソッドでは、それぞれのインスタンスに入力を引き渡すだけで済むようになっています)。
削減後のデータは3072÷24=128次元になるということです。MNISTの手書き数字を処理したときには、最終的に2次元に圧縮していたことを考えると、次元がかなり多いような気もしますが、はてさてうまくいくのでしょうか。
Copyright© Digital Advantage Corp. All Rights Reserved.