PyTorchでDCGANを作ってみよう:作って試そう! ディープラーニング工作室(2/2 ページ)
PyTorchが提供するConv2dクラスとConvTranspose2dクラスを使ってDCGANを実装しながら、その特徴を見ていきましょう。
層を増やしてチャネルも増やしてみる
ここでは識別器と生成器の両者で層とチャネルの数を増やしてみましょう。そうすれば、ニューラルネットワークの表現力は高くなる、つまり、より高精度な画像を生成できそうな気がします。今度は識別器と生成器の両方のコードをまとめて示します。
discriminator = nn.Sequential(
nn.Conv2d(1, 16, 5, 2, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(16, 32, 5, 2, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 3, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 1, 2, bias=False),
nn.Sigmoid()
)
generator = nn.Sequential(
nn.ConvTranspose2d(zsize, feature_maps * 8, 4, 1, 0, bias=False),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 2, 1, 2, 2, 2, bias=False),
nn.Tanh()
)
netD = MakeFrom(discriminator)
netG = MakeFrom(generator)
畳み込み(およびその反対の処理である転置畳み込み)を行う層を3つから4つに増やしてみました。チャネル数も先ほどのコードよりはそれなりに増やしています。こうすることでこのニューラルネットワークが持つ重み(やバイアス)の数は多くなりますから、より高い表現力を手に入れることができるはずです。
というわけで、train関数を実行した結果を以下に示します(生成画像のみ)。
どうでしょうか。先ほどの画像よりはかなりよい感じになりました。取りあえず、層とチャネル数を増やすのはよい方向に作用したようですね。
バッチ正規化を行うように変更してみる
ところで、これまでの2つのコードを実行しても、必ず上で見たような画像が生成できるとは限りません。以下のような画像が生成されることもよくあります。
GANは学習が安定せずに、上のような画像(さらにひどいと全く意味のない画像)を生成してしまうのがよくあることと知られています。これを回避するための方策が上で紹介したDCGANの論文で提示されています。そうした方策の1つが「バッチ正規化」(Batch Normalization)と呼ばれる処理を取り入れることです。
本稿では詳しくは説明をしませんが、バッチ正規化は「Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift」という論文で提案されたものです。今やっているConv2dクラスやConvTranspose2dクラスを利用したGANに関していえば、ミニバッチを学習する際に、ある層からの出力(チャネルごと)を平均が0で分散が1となるように出力を加工(正規化)してから(活性化関数を経由して)、次の層への入力とします。このようにすることで学習を安定的/効率的に進められると論文では述べられています。
PyTorchでこれを行うにはBatchNorm2dクラスなどを使えます。実際にこれを使用した識別器と生成器のコードを以下に示します。バッチ正規化は、畳み込み層と活性化関数の間に挟み込むことには注意してください。また、インスタンス生成時の引数には直前の層(畳み込み層)の出力チャネルの数(第2引数の値)を指定します(識別器の最初の層でバッチ正規化を行っていないのは論文の内容に合わせたものです)。
discriminator = nn.Sequential(
nn.Conv2d(1, 16, 5, 2, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(16, 32, 5, 2, bias=False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 3, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 1, 2, bias=False),
nn.Sigmoid()
)
generator = nn.Sequential(
nn.ConvTranspose2d(zsize, feature_maps * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(feature_maps * 8),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps * 4),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps * 2),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 2, 1, 2, 2, 2, bias=False),
nn.Tanh()
)
netD = MakeFrom(discriminator)
netG = MakeFrom(generator)
今回のようにそれほど層が深くないニューラルネットワークではそれほど意味はないかもしれませんが、DCGANの論文ではそのアプローチの核の一つとしてバッチ正規化が挙げられているほどのものです(その他には、全て畳み込み層で構成される、全結合層を排除することが挙げられていますが、ここまでに作ってきた識別器と生成器ではそうなっていることに注意してください。バッチ正規化までを含めて初めて全部入りのDCGANだといえるでしょう)。
生成された画像は次のようになります。
先ほどのものと比べると、画像周辺部が汚いような気がしますが数字についてはまあ似たようなものです。バッチ正規化には、学習率として(バッチ正規化を行わないときよりも)大きめの値を取れる、ネットワークの収束が早期といった特性がありますが、それらは出力品質が著しく向上することを約束するものではないということでしょう。読者が自分で試したときにはこれよりも汚い画像となる可能性はあります(キレイになるかもしれません)。筆者も何度か試して、よさげなものを選択しています。学習がうまくいきやすくはなっても常にうまくいくとは限りません。
とはいえ、現在では学習を安定して効率的に進めるために、多くのニューラルネットワークでバッチ正規化やそこから派生した技術が使われるようになってもいます。そういうものがあるということは頭に入れておきましょう。
重みの初期値
DCGANの論文では、ニューラルネットワークの重みの初期値についても言及があります。論文では重みは平均0、標準偏差0.02の正規分布となるように初期化しているとのことです。とはいえ、先ほど述べたようなDCGANにおいて核となるアプローチではないようです(あくまでもそのようにしたという実験環境の明記だと筆者には感じられました)。
ですが、ここでは試しに重みを初期化する関数を定義して、それを使って、今述べたように識別器と生成器の両者で重みとバイアスを初期化してみることにします(重みの初期化関数はPyTorchのドキュメント「DCGAN Tutorial」にあるものをそのまま使用しています(以下のコードではBatchNorm2dクラスのインスタンスについては重みは平均1、標準偏差が0.02となるように初期化して、バイアスは0に初期化するようになっています。これだと論文とは話が食い違っているような気もするので、興味のある方は初期化方法を変更してみるのもよいでしょう)。
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
この関数をニューラルネットワークモジュール(この場合はnetDやnetGのインスタンス)が持つapplyメソッドに渡すと、そのネットワークモジュールが格納しているニューラルネットワーククラスの各インスタンスを引数として、それが呼び出されるようになっています。つまり、パラメーターmにはConv2dクラスやBatchNorm2dクラスなどのインスタンスが渡されるので、上の関数では重みとバイアスを初期化したいインスタンスが渡されたかどうかを判断して、そうであれば初期化を行うようにしています。実際にこれを使って初期化を行うコードを以下に示します。
discriminator = nn.Sequential(
nn.Conv2d(1, 16, 5, 2, bias=False),
nn.LeakyReLU(0.2),
nn.Conv2d(16, 32, 5, 2, bias=False),
nn.BatchNorm2d(32),
nn.LeakyReLU(0.2),
nn.Conv2d(32, 64, 3, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 1, 2, bias=False),
nn.Sigmoid()
)
generator = nn.Sequential(
nn.ConvTranspose2d(zsize, feature_maps * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(feature_maps * 8),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps * 4),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(feature_maps * 2),
nn.ReLU(),
nn.ConvTranspose2d(feature_maps * 2, 1, 2, 2, 2, bias=False),
nn.Tanh()
)
netD = MakeFrom(discriminator)
netG = MakeFrom(generator)
netD.apply(weights_init)
netG.apply(weights_init)
上の識別器と生成器を使って学習を行った結果を以下に示します。
悪くない結果が出たようです(ただし、筆者が試しているときには、思ったほどではない結果となることもありました。この辺はエポック数を増やすことで解決できるのかもしれません)。
今回はConv2dクラス、ConvTranspose2dクラス、BatchNorm2dクラスを使ってDCGANを作ってみました。が、これでは作ってみただけであって、それほど深い考察をできていません。次回はその辺について考えてみる予定です。
最後にいつものコードを掲載しておきます。
いつものコード
以下に必要なモジュールをインポートするコード、MNIST画像を読み出すためのデータセットとデータローダーの定義、画像表示に使用している関数の定義を示します。
import torch
from torch import optim
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, ), (0.5, ))
])
batch_size = 100
trainset = MNIST('.', train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
def imshow(img):
img = torchvision.utils.make_grid(img)
img = img / 2 + 0.5
npimg = img.detach().numpy()
plt.figure(figsize=(12, 12))
plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
plt.show()
Copyright© Digital Advantage Corp. All Rights Reserved.