PyTorchで全結合型のGANを作ってみよう:作って試そう! ディープラーニング工作室(2/2 ページ)
GANの基本的な構成を見た後に、PyTorchのLinearクラスでMNISTを対象としたGANを実装してみましょう。
学習
2つのニューラルネットワーククラスが定義できたので、次は実際に学習を行うコードの番です。ここではGPUを使って計算を高速に行うことにしました(Google ColabでGPUを有効にする方法については「PyTorchからGPUを使って畳み込みオートエンコーダーの学習を高速化してみよう」を参照してください)。
GPUを使うには、計算に必要なニューラルネットワークモデル、計算対象などを全てGPUに転送しておかなければなりませんでした。というわけで、学習に使用するニューラルネットワークは次のようにして、GPUに転送しています。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'using: {device}')
netD = Discriminator(in_features).to(device)
netG = Generator(zsize, in_features).to(device)
また、損失関数としてはここではtorch.nn.BCELossを、最適化アルゴリズムにはtorch.optim.Adamクラスを使うことにしました。これについては、PyTorchのドキュメント「DCGAN TUTORIAL」などを参考にしてください(Adamはこれまでに使っていたSGDによる最適化をより巧妙に行えるようにしたものだと考えておいてください)。
ここで覚えておきたいのは、次のことです。
- 識別器では、訓練データが入力されたら出力が1になり、生成器によるデータが入力されたら出力が0となることが理想的
- 生成器では、それが生成したデータを識別器に入力したときに、その出力が1となることが理想的
これを実現するような損失関数が必要なのですが、それをうまく取り扱ってくれるのがBCELossクラスです。
識別器では、訓練データを入力したときと生成器からのデータを入力したときという2つの条件があることに注意してください。学習時には識別器に訓練データを入力してその結果と正解ラベル1から損失を計算し、次に生成器からのデータを入力してその結果と正解ラベル0から損失を計算し、それら2つをまとめたものが実際の損失となります。
一方、生成器の学習に関しては、生成器から得たデータを識別器に入力して、その出力と1から損失を計算することになります。そのため、学習を行うコードは少しばかり長ったらしくなってしまいます。なるべく数式を使わないように説明することを目指しているので少し分かりにくかったかもしれませんが、だいたいこんなもんだと思ってくれれば十分です。
というわけで、以下は損失関数と最適化アルゴリズムを指定するコードです。間にあるのは、固定の正解ラベルです。one_labelsは正解ラベル1をデータローダーから読み込む数(バッチサイズ)だけ並べたもので、zero_labelsはその正解ラベル0版です。
criterion = torch.nn.BCELoss().to(device)
one_labels = torch.ones(batch_size).to(device)
zero_labels = torch.zeros(batch_size).to(device)
optimizer_netD = optim.Adam(netD.parameters(), lr=0.0002, betas=[0.5, 0.999])
optimizer_netG = optim.Adam(netG.parameters(), lr=0.0002, betas=[0.5, 0.999])
そして、今述べたような損失計算を含んだ学習コードが以下です。
losses_netD = []
losses_netG = []
EPOCHS = 50
for epoch in range(1, EPOCHS+1):
running_loss_netD = 0.0
running_loss_netG = 0.0
for count, (real_imgs, _) in enumerate(trainloader, 1):
netD.zero_grad()
# 識別器の学習
real_imgs = real_imgs.to(device)
# データローダーからデータを読み込み、識別器に入力し、損失を計算
output_real_imgs = netD(real_imgs.reshape(batch_size, -1))
output_real_imgs = output_real_imgs.reshape(batch_size)
loss_real_imgs = criterion(output_real_imgs, one_labels)
loss_real_imgs.backward()
# 生成器から得たデータを、識別器に入力し、損失を計算
z = torch.randn(batch_size, zsize).to(device)
fake_imgs = netG(z)
output_fake_imgs = netD(fake_imgs.detach()).reshape(batch_size)
loss_fake_imgs = criterion(output_fake_imgs, zero_labels)
loss_fake_imgs.backward()
# それらをまとめたものが最終的な損失
loss_netD = loss_real_imgs + loss_fake_imgs
optimizer_netD.step()
running_loss_netD += loss_netD
# 生成器の学習
netG.zero_grad()
z = torch.randn(batch_size, zsize).to(device)
fake_imgs = netG(z)
output_fake_imgs = netD(fake_imgs).reshape(batch_size)
loss_netG = criterion(output_fake_imgs, one_labels)
loss_netG.backward()
optimizer_netG.step()
running_loss_netG += loss_netG
running_loss_netD /= count
running_loss_netG /= count
print(f'epoch: {epoch}, netD loss: {running_loss_netD}, netG loss: {running_loss_netG}')
losses_netD.append(running_loss_netD.cpu())
losses_netG.append(running_loss_netG.cpu())
if epoch % 10 == 0:
z = torch.randn(batch_size, zsize).to(device)
generated_imgs = netG(z).cpu()
imshow(generated_imgs[0:8].reshape(8, 1, 28, 28))
このコードでは、エポック数を50として、10エポックごとにその時点での生成器を使って画像を生成して、表示するようにしています。実際に実行してみると次のようになります。生成された画像は最後の3つだけを掲載します。
どうでしょう。うん、ダメですね。MNIST風といえばMNIST風ですが、これで識別器をだますなんて無理だと思います。もう少しまともな画像にならないものかと考えましたが、一番簡単なのは全結合層をもう一つ増やしてみることな気がします。というわけで、DiscriminatorクラスとGeneratorクラスを次のようにしてみましょう。
class Discriminator(torch.nn.Module):
def _init_weights(self):
for weight in self.parameters():
torch.nn.init.normal_(weight, 0.0, 0.02)
def __init__(self, in_features):
super().__init__()
self.fc1 = torch.nn.Linear(in_features, 384, bias=False)
self.fc2 = torch.nn.Linear(384, 128, bias=False)
self.fc3 = torch.nn.Linear(128, 32, bias=False)
self.fc4 = torch.nn.Linear(32, 1, bias=False)
self.relu = torch.nn.LeakyReLU()
self._init_weights()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.relu(self.fc3(x))
x = torch.sigmoid(self.fc4(x)) # 0 to 1(0: fake, 1: true)
return x
class Generator(torch.nn.Module):
def _init_weights(self):
for weight in self.parameters():
torch.nn.init.normal_(weight, 0.0, 0.02)
def __init__(self, zsize, in_features):
super().__init__()
self.fc1 = torch.nn.Linear(zsize, 256, bias=False)
self.fc2 = torch.nn.Linear(256, 512, bias=False)
self.fc3 = torch.nn.Linear(512, 1024, bias=False)
self.fc4 = torch.nn.Linear(1024, in_features, bias=False)
self.relu = torch.nn.ReLU()
self._init_weights()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.relu(self.fc3(x))
x = torch.tanh(self.fc4(x)) # -1 to 1
return x
層を増やしたのに加えて、正規分布に従って重みを初期化するコードを追加しています。これは重みを正規分布に従う値に初期化している、という話をDCGANについての論文で見かけたからです(今回行ったのはDCGANではないので、その方法論が有効かどうかは微妙かもしれません。なお、その論文では層が深くなるときには全結合層は取り去る、といったことも書いてあるので、そもそも今回の試みは否定されているような気もします)。
これらのクラスを使って、先ほどと同じコードで学習をした結果を以下に示します。
今度はどうでしょう。少なくとも先ほどのものよりはメリハリが付いたように思えます(その一方で数字はあまりうまく生成できていないような気もします)。本来は損失関数をグラフにしたり、考察をもう少し加えたりするべきなのですが、今回はこの辺で終わりにしましょう。次回はもう少しよい結果を得る方法について考えてみます。では、最後にいつものコードを掲載しておきましょう。
いつものコード
import torch
from torch import optim
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, ))
])
batch_size = 100
trainset = MNIST('.', train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
def imshow(img):
img = torchvision.utils.make_grid(img)
img = img / 2 + 0.5
npimg = img.detach().numpy()
plt.figure(figsize=(12, 12))
plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
plt.show()
Copyright© Digital Advantage Corp. All Rights Reserved.