2つのニューラルネットワーククラスが定義できたので、次は実際に学習を行うコードの番です。ここではGPUを使って計算を高速に行うことにしました(Google ColabでGPUを有効にする方法については「PyTorchからGPUを使って畳み込みオートエンコーダーの学習を高速化してみよう」を参照してください)。
GPUを使うには、計算に必要なニューラルネットワークモデル、計算対象などを全てGPUに転送しておかなければなりませんでした。というわけで、学習に使用するニューラルネットワークは次のようにして、GPUに転送しています。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'using: {device}')
netD = Discriminator(in_features).to(device)
netG = Generator(zsize, in_features).to(device)
また、損失関数としてはここではtorch.nn.BCELossを、最適化アルゴリズムにはtorch.optim.Adamクラスを使うことにしました。これについては、PyTorchのドキュメント「DCGAN TUTORIAL」などを参考にしてください(Adamはこれまでに使っていたSGDによる最適化をより巧妙に行えるようにしたものだと考えておいてください)。
ここで覚えておきたいのは、次のことです。
これを実現するような損失関数が必要なのですが、それをうまく取り扱ってくれるのがBCELossクラスです。
識別器では、訓練データを入力したときと生成器からのデータを入力したときという2つの条件があることに注意してください。学習時には識別器に訓練データを入力してその結果と正解ラベル1から損失を計算し、次に生成器からのデータを入力してその結果と正解ラベル0から損失を計算し、それら2つをまとめたものが実際の損失となります。
一方、生成器の学習に関しては、生成器から得たデータを識別器に入力して、その出力と1から損失を計算することになります。そのため、学習を行うコードは少しばかり長ったらしくなってしまいます。なるべく数式を使わないように説明することを目指しているので少し分かりにくかったかもしれませんが、だいたいこんなもんだと思ってくれれば十分です。
というわけで、以下は損失関数と最適化アルゴリズムを指定するコードです。間にあるのは、固定の正解ラベルです。one_labelsは正解ラベル1をデータローダーから読み込む数(バッチサイズ)だけ並べたもので、zero_labelsはその正解ラベル0版です。
criterion = torch.nn.BCELoss().to(device)
one_labels = torch.ones(batch_size).to(device)
zero_labels = torch.zeros(batch_size).to(device)
optimizer_netD = optim.Adam(netD.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_netG = optim.Adam(netG.parameters(), lr=0.0002, betas=[0.5, 0.999])
そして、今述べたような損失計算を含んだ学習コードが以下です。
losses_netD = []
losses_netG = []
EPOCHS = 50
for epoch in range(1, EPOCHS+1):
running_loss_netD = 0.0
running_loss_netG = 0.0
for count, (real_imgs, _) in enumerate(trainloader, 1):
netD.zero_grad()
# 識別器の学習
real_imgs = real_imgs.to(device)
# データローダーからデータを読み込み、識別器に入力し、損失を計算
output_real_imgs = netD(real_imgs.reshape(batch_size, -1))
output_real_imgs = output_real_imgs.reshape(batch_size)
loss_real_imgs = criterion(output_real_imgs, one_labels)
loss_real_imgs.backward()
# 生成器から得たデータを、識別器に入力し、損失を計算
z = torch.randn(batch_size, zsize).to(device)
fake_imgs = netG(z)
output_fake_imgs = netD(fake_imgs.detach()).reshape(batch_size)
loss_fake_imgs = criterion(output_fake_imgs, zero_labels)
loss_fake_imgs.backward()
# それらをまとめたものが最終的な損失
loss_netD = loss_real_imgs + loss_fake_imgs
optimizer_netD.step()
running_loss_netD += loss_netD
# 生成器の学習
netG.zero_grad()
z = torch.randn(batch_size, zsize).to(device)
fake_imgs = netG(z)
output_fake_imgs = netD(fake_imgs).reshape(batch_size)
loss_netG = criterion(output_fake_imgs, one_labels)
loss_netG.backward()
optimizer_netG.step()
running_loss_netG += loss_netG
running_loss_netD /= count
running_loss_netG /= count
print(f'epoch: {epoch}, netD loss: {running_loss_netD}, netG loss: {running_loss_netG}')
losses_netD.append(running_loss_netD.cpu())
losses_netG.append(running_loss_netG.cpu())
if epoch % 10 == 0:
z = torch.randn(batch_size, zsize).to(device)
generated_imgs = netG(z).cpu()
imshow(generated_imgs[0:8].reshape(8, 1, 28, 28))
このコードでは、エポック数を50として、10エポックごとにその時点での生成器を使って画像を生成して、表示するようにしています。実際に実行してみると次のようになります。生成された画像は最後の3つだけを掲載します。
どうでしょう。うん、ダメですね。MNIST風といえばMNIST風ですが、これで識別器をだますなんて無理だと思います。もう少しまともな画像にならないものかと考えましたが、一番簡単なのは全結合層をもう一つ増やしてみることな気がします。というわけで、DiscriminatorクラスとGeneratorクラスを次のようにしてみましょう。
Copyright© Digital Advantage Corp. All Rights Reserved.