ここではエンコーダーをEncoderクラスとして、デコーダーをDecoderクラスとして、それらを使用するオートエンコーダーをAutoencoderクラスとして定義しましょう。実際のコードは次のようになります。なお、今回の内容はこのノートブックで公開しています。
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10 # CIFAR10もインポートしておく
import numpy as np
import matplotlib.pyplot as plt
class Encoder(torch.nn.Module):
def __init__(self, input_size):
super().__init__()
self.fc1 = torch.nn.Linear(input_size, 512)
self.fc2 = torch.nn.Linear(512, 64)
self.fc3 = torch.nn.Linear(64, 16)
self.fc4 = torch.nn.Linear(16, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = self.fc4(x)
return x
class Decoder(torch.nn.Module):
def __init__(self, output_size):
super().__init__()
self.fc1 = torch.nn.Linear(2, 16)
self.fc2 = torch.nn.Linear(16, 64)
self.fc3 = torch.nn.Linear(64, 512)
self.fc4 = torch.nn.Linear(512, output_size)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.relu(self.fc3(x))
x = torch.tanh(self.fc4(x)) # -1〜1に変換
return x
class AutoEncoder(torch.nn.Module):
def __init__(self, org_size):
super().__init__()
self.enc = Encoder(org_size)
self.dec = Decoder(org_size)
def forward(self, x):
x = self.enc(x) # エンコード
x = self.dec(x) # デコード
return x
EncoderクラスとDecoderクラスは全結合型のニューラルネットワークとなっていること、各層の出力が段階的に減っていく/増えていく点には注意してください。これが上で述べた段階亭に次元削減(次元増加)を行う部分に合致します。上のコードだと、エンコーダーでは「784(MNISTの手書き数字の次元数)→512→64→16→2」のように次元が削減され、デコーダーでは「2→16→64→512」のように次元が増加していきます。果たして、784次元のデータを2次元のデータに削減できるかどうかは、この後の学習結果で分かるでしょう。
なお、MNISTデータセットの手書き数字だけではなく、CIFAR-10の画像も扱えるようにすることを狙って、Encoderクラスへの入力数とDecoderクラスからの出力数を指定できるようにしてあります(Autoencoderクラスのインスタンス生成時に指定する値が、Encoderクラスへの入力/Decoderクラスからの出力の数となります)。冒頭のimport文でもそのためにCIFAR10モジュールをインポートしています。
各クラスでやっていることは一目瞭然なので、ここでは説明は省略します。
次にMNISTデータセットの読み込みと、データローダーを定義しましょう。これはPyTorchのいつものコードです。テスト用のデータセット/データローダーについてもここでは定義しておきます。バッチサイズは50として、ミニバッチで学習する際に読み込むデータ数が50個となるように設定しています。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = MNIST('./data', train=True, transform=transform, download=True)
testset = MNIST('./data', train=False, transform=transform, download=True)
batch_size = 50
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
画像が復元できているかを確認するには、MNISTの手書き数字を表示できる必要もあります。ここでは手抜きをして、PyTorchの公式サイトで提供されている「TRAINING A CLASSIFIER」ページにあるimshow関数を拝借(して、少し改変)することにしました。以下にそのコードと、訓練データから50個を取り出して、それを表示するコードを示します。
def imshow(img):
img = torchvision.utils.make_grid(img)
img = img / 2 + 0.5
npimg = img.detach().numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
iterator = iter(trainloader)
x, _ = next(iterator)
imshow(x)
実行結果を以下に示します。
これでもともとのMNISTの手書き数字と、ニューラルネットワークモデルからの出力を画像として表示する準備ができました。次は、実際に学習を行うコードです。このコードも関数化して、この後、CIFAR-10データセットを使った学習でこの関数を呼び出すだけで済むようにしました。
def train(net, criterion, optimizer, epochs, trainloader):
losses = []
output_and_label = []
for epoch in range(1, epochs+1):
print(f'epoch: {epoch}, ', end='')
running_loss = 0.0
for counter, (img, _) in enumerate(trainloader, 1):
optimizer.zero_grad()
img = img.reshape(-1, input_size)
output = net(img)
loss = criterion(output, img)
loss.backward()
optimizer.step()
running_loss += loss.item()
avg_loss = running_loss / counter
losses.append(avg_loss)
print('loss:', avg_loss)
output_and_label.append((output, img))
print('finished')
return output_and_label, losses
input_size = 28 * 28
net = AutoEncoder(input_size)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
EPOCHS = 100
output_and_label, losses = train(net, criterion, optimizer, EPOCHS, trainloader)
ここでは1エポック(6万個のデータを使った学習)が終わるたびに、その時点でニューラルネットワークへ入力された手書き数字(ミニバッチで読み込まれる50個)と、そこから得られた出力とをリストoutput_and_labelに保存するようにしました。こうしておけば、後からエポックが終わった時点でどんな手書き数字が出力されたかを確認できるからです。
これを実行すると、次のようにズラズラとエポックごとの平均損失(今回は、そのエポックで得られた損失の総和をミニバッチの回数で除算しています)が表示されます。なお、実際にこのコードを実行すると、学習が終わるまでには20〜30分ほどかかることには注意してください。
「finished」と表示されたら、学習完了です。いつものように損失を表示してみましょう。
plt.plot(losses)
実行結果を以下に示します。
緩やかではありますが、学習は進んだようですね。とはいえ、実際にMNISTの手書き数字を復元できたかは、実際にモノを見てみないと分かりません。そこで、リストoutput_and_labelに保存しておいた出力と元画像のうち、最後のものを実際に画面に表示してみましょう。これには先ほど定義したimshow関数を使います。
output, org = output_and_label[-1]
imshow(org.reshape(-1, 1, 28, 28))
imshow(output.reshape(-1, 1, 28, 28))
実行結果を以下に示します。
Copyright© Digital Advantage Corp. All Rights Reserved.