PyTorchでCIFAR-10を基にDCGANで画像生成をしてみよう作って試そう! ディープラーニング工作室(2/2 ページ)

» 2020年11月27日 05時00分 公開
[かわさきしんじDeep Insider編集部]
前のページへ 1|2       

バッチ正規化入りのバージョン

 次に、バッチ正規化入りのDCGANで試してみましょう。実際の構成は次の通りです。

discriminator = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, 4, 2, 1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(128, 256, 4, 2, 1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(256, 1, 4, 1, 0, bias=False),
    nn.Sigmoid()
)

feature_maps = 64
zsize = 100
generator = nn.Sequential(
    nn.ConvTranspose2d(zsize, feature_maps * 8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(feature_maps * 8),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(feature_maps * 4),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(feature_maps * 2),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 2, 3, 4, 2, 1, bias=False),
    nn.Tanh()
)

netD = MakeFrom(discriminator)
netG = MakeFrom(generator)

losses, outs, nets = train(netD, netG, batch_size, zsize, EPOCHS, trainloader)

バッチ正規化入りのDCGAN

 実行結果を以下に示します。

実行結果 実行結果

 まず注目したいのは、エポック10、20、30のD(G(z))の値です。先ほどの例では、0.5近辺の値が続いていたのが、アッという間に低い数値になっています。バッチ正規化のメリットの一つに「効率的に学習を進められる」ことが挙げられます。D(G(z))がいきなり低い値になっているのは、識別器の学習が急速に進んだことを意味しているかもしれません。あるいは、生成器の学習がうまくいっていなくて、よい感じの画像を生成できていないということも考えられます(が、何度か試したところでもこの傾向は見られたので、識別器の学習が進んでいる、あるいは進みすぎているということが考えられます)。

 損失関数の値をグラフ化したものは次のようになりました。

損失関数のグラフ 損失関数のグラフ

 何ともコメントのしようがないグラフになってしまったので、コメントは省略します。D(x)とD(G(z))については以下のようになりました。

D(x)、D(G(z))の値の推移 D(x)、D(G(z))の値の推移

 このグラフからは、D(x)が高値安定、つまり訓練データの識別がうまくできていることが分かります。D(G(z))(オレンジ色と緑色の線)がそれほど高い値となっていないので、識別器から見ると偽データをきちんと識別できているといえるでしょう。逆に生成器にはガンバレヨと伝えたくなるところです。

 最後に、先ほどと同様に、80エポックの学習後に生成された画像を示します。

80エポックの学習後に生成された画像 80エポックの学習後に生成された画像

 生成画像評論家という目線で見ると、これらは左から「船的な何か」「バーバモジャか何か」「テーブルに置かれた花と向こう側に人」「海と崖」「イヌの顔がひしゃげてるところ」「海上を漂う何か」「鳥?」「判別不可能」といった感じです。生成器の学習がうまくいっていないのかもしれませんね(そもそもCIFAR-10にそんなカテゴリあった?)。

重みを独自に初期化するバージョン

 最後に平均0、標準偏差0.02となるように重みを初期化したDCGANで試してみます。注意点としては、weights_init関数内では、BatchNorm2dクラスのインスタンスが持つ重みについても平均1、標準偏差0.02となるように重みを初期化している点です(前回はPyTorchのチュートリアルにある通り、BatchNorm2dクラスのインスタンスについては平均1、標準偏差0.02となるように初期化していました)。

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

discriminator = nn.Sequential(
    nn.Conv2d(3, 64, 4, 2, 1, bias=False),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(64, 128, 4, 2, 1, bias=False),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(128, 256, 4, 2, 1, bias=False),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(256, 1, 4, 1, 0, bias=False),
    nn.Sigmoid()
)

feature_maps = 64
zsize = 100
generator = nn.Sequential(
    nn.ConvTranspose2d(zsize, feature_maps * 8, 4, 1, 0, bias=False),
    nn.BatchNorm2d(feature_maps * 8),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 8, feature_maps * 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(feature_maps * 4),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 4, feature_maps * 2, 4, 2, 1, bias=False),
    nn.BatchNorm2d(feature_maps * 2),
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(feature_maps * 2, 3, 4, 2, 1, bias=False),
    nn.Tanh()
)

netD = MakeFrom(discriminator)
netG = MakeFrom(generator)
netD.apply(weights_init)
netG.apply(weights_init)

重みの初期化も行うDCGANの識別器と生成器

 実行結果を以下に示します。

実行結果 実行結果

 D(G(z))が0.5近辺から次第に低下していくのはBatchNorm2dクラスを使ったバージョンと同様です。ただし、その値はエポック1の最後からエポック2の頭で底を打ち、その後は次第に上昇していきます。これは生成器の学習がある程度進んで、識別器で訓練データとして判定されていることの表れかもしれません。ただし、そうした動きは途中から鈍くなり、最後には再び偽データと判定されることが多くなっているようです(以下のグラフ参照)。

 D(x)は最初こそ訓練データとして判定されることが多いようですが、次第にその値が低下していきます。これは、D(G(z))の値の上昇と相反しているので、生成器の学習がここである程度進んでいると予想できるでしょう。その後は緩やかに上昇していきます(ここでは識別器の学習が進んで、生成器の画像が偽データと判定されていると思われます)。

 損失関数のグラフは次の通りです。

損失関数のグラフ 損失関数のグラフ

 今述べたような状況を反映したグラフになっているといえます。

 D(x)、D(G(z))の値の推移は次のようになっています。

D(x)、D(G(z))の値の推移 D(x)、D(G(z))の値の推移

 先ほどのBatchNorm2dクラスを利用したバージョンよりも振れ幅が広くなっています。その理由はちょっと分かりませんが、D(x)は、序盤に少し低下傾向が見られますが、元から高い値かつ右肩上がりのグラフになっていることから、上で述べたような振る舞いになっているといってもよいでしょう。

 2つのD(G(z))は最初にギュッと上がった後、平衡状態が続いて、その後は低下傾向にあるグラフになっているようです。これもまた、先ほど述べたことを反映するようなグラフといえます。

 最後に、80エポックの学習後に生成された画像を示します。

80エポックの学習後に生成された画像 80エポックの学習後に生成された画像

 何に見えるかは読者におまかせします。とはいえ、これまでのものよりはCIFAR-10にありそうなものになったように思えるのはひいき目でしょうか。

 ここ数回はオートエンコーダーとGAN、DCGANと画像生成に関する話をしてきました。ここいらで画像生成から離れて、また別のトピックに移ってみようかと思ってはいますが、実はまだどんなことをやってみようかは悩んでいるところです。次回以降も乞うご期待ということで。

 最後にいつものコードを掲載しておきましょう。

いつものコード

import torch
from torch import optim
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

batch_size = 100
trainset = CIFAR10('.', train=True, transform=transform, download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

def imshow(img):
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5
    npimg = img.detach().numpy()
    plt.figure(figsize=(12, 12))
    plt.imshow(np.transpose(npimg, (1, 2, 0)), cmap='gray')
    plt.show()

class MakeFrom(nn.Module):
    def __init__(self, s):
        super().__init__()
        self.model = s
    def forward(self, x):
        return self.model(x)

いつものコード


「作って試そう! ディープラーニング工作室」のインデックス

作って試そう! ディープラーニング工作室

前のページへ 1|2       

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。