検索
連載

PyTorchで全結合型のGANを作ってみよう作って試そう! ディープラーニング工作室(1/2 ページ)

GANの基本的な構成を見た後に、PyTorchのLinearクラスでMNISTを対象としたGANを実装してみましょう。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
「作って試そう! ディープラーニング工作室」のインデックス

連載目次

今回の目的

 前回までは数回にわたってオートエンコーダーについて見てきました。今回はそれとはまた別の画像生成フレームワークとしてよく知られているGAN(Generative Adversarial Network。敵対的生成ネットワーク)を作ってみましょう。といっても、今回はみんな大好き「全結合型」のニューラルネットワークとして作ってみます。

 GANの代表的な用途としては今回も取り上げる画像生成が挙げられます。萌え絵を自動生成するようなニューラルネットワークモデルはいろいろなところで話題になったことから、関連してGANという語を覚えている方もたくさんいらっしゃるでしょう。

 GANは2014年に登場したアルゴリズムですが、その後、さまざまな改良が加えられ、その派生となる数多くのアルゴリズムが登場しています。その一つであるDCGAN(Deep Convolutional GAN)については後続の回で取り上げる予定です。「Convolutional」という語が含まれていることからも分かる通り、これはCNNを使ったGANの実装といえます。

 前回までに見てきたオートエンコーダーは、エンコーダーとデコーダーの2つのニューラルネットワークを内包するニューラルネットワークでした。GANもまた2つのニューラルネットワークで構成されます。1つは識別器(discriminator。判別器とも)と、もう1つは生成器(generator)と呼ばれます。これら2つのニューラルネットワークが「敵対」しながら学習が進むのがGANの大きな特徴です。

 ここでいう「敵対」とはどんな意味でしょう。典型的なGANの構成は次のようになっています。

GANの基本構成
GANの基本構成

 識別器は「入力されたデータが、生成器により生成されたものか、リアルなデータ(訓練データ)であるかを識別する」ことを目的とします。一方、生成器は「乱数を入力すると、リアルなデータ(訓練データ)とよく似たデータを生成することで、識別器をだます」ことを目的とします。

 GANを構成する2つのニューラルネットワークが相反する目的を持つことから、このネットワークは「敵対的」と呼ばれるということです。学習が進むにつれて、生成器は訓練データとよく似たデータを生成できるようになり、識別器は訓練データと生成器から得られるニセモノのデータをより高精度に識別できるようになるはずです。このようにして、高精度なデータを生成できる生成器を手に入れることがGANの目的といえます。

 上図から分かるように、識別器には生成器が生成したデータと、訓練データが入力され、それがホンモノ(訓練データ)かニセモノ(生成器によるデータ)かを判断するので、出力は1つだけで、それが1に近いほど識別器はそれをホンモノに近いと判断したことになります。これならPyTorchのLinearクラスを使って、全結合型のネットワークを作れば、簡単に実現できそうですね。

識別器
識別器

 一方、生成器はオートエンコーダーのデコーダーに近い動作をすると考えることができます。デコーダーは、エンコーダーによりエンコードされた結果である潜在変数を入力に取り、それを段階的に拡張しながら元データに近いデータを再現するものでした。これに対して、生成器は潜在変数としてランダムな値を入力に受け取り、その値を段階的に拡張していくことで最終的に識別器へ入力される訓練データと同じ形式のデータを生成します。

生成器
生成器

 GANの基本的な構造は今述べた通りです。説明がまだ足りない点もありますが、それについては必要に応じて話をしていくことにしましょう。ここでは、MNIST(またか!)の手書き数字とよく似た画像データを生成できるような生成器と、その識別を行う識別器を作っていくことにします。

識別器

 まずは実装が簡単な識別器から見ていきます。なお、今回のコードはこのノートブックで公開しています。また、例によって、必要なモジュールをインポートするコードなど、定型的なコードについては最後にまとめます。

 既に述べたように、今回はMNISTの手書き数字を使います。つまり、1×28×28=784次元のデータを識別器に入力して、それらの真贋(しんがん)を識別できるようにすることが目的となります。これもまた既に述べていますが、今回はこれをPyTorchのLinearクラスを使って作成します。

 細かく説明する必要は特にないので、実際のコードを以下に示してしまいましょう。

in_features = 1 * 28 * 28

class Discriminator(torch.nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_features, 384, bias=False)
        self.fc2 = torch.nn.Linear(384, 128, bias=False)
        self.fc3 = torch.nn.Linear(128, 1, bias=False)
        self.relu = torch.nn.LeakyReLU()
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # 0 to 1(0: fake, 1: true)
        return x

全結合型のDiscriminatorクラス

 見ての通り、このDiscriminatorクラスはLinerクラスのインスタンスと活性化関数を組み合わせて、784次元のデータを入力に受け取り、0〜1の範囲の値を出力するだけです(活性化関数の一つであるtorch.nn.LearkyReLUクラスについてはこちらを参照してください)。試しに使ってみましょう。

netD = Discriminator(in_features)

iterator = iter(trainloader)
img, _ = next(iterator)

D_out = netD(img.reshape(batch_size, -1))

print(D_out[0:5])

識別器に画像を入力する

 変数trainloaderはMNISTの手書き数字をバッチごとに読み込むデータローダーです。上のコードでは、これを使って訓練データを読み込み、それをDiscriminatorクラスのインスタンスであるnetDに入力して、その結果(の一部)を表示しています。これを実行すると、次のようになります。

実行結果
実行結果

 何の学習もしていないので、出力結果にはあまり意味はないでしょうが、0でも1でもない=真贋が判断できていないとも解釈できる値となりました。ポイントはこのクラスのインスタンスにMNISTのデータセットから得たデータ(訓練データなど)を入力したときには出力は1(に近い値)に、生成器から得たデータを入力したときには出力は0となるように学習を進めることです。

 次に生成器について見てみましょう。

生成器

 生成器も特に説明の必要はないコードとなっています。

zsize = 100

class Generator(torch.nn.Module):
    def __init__(self, zsize, in_features):
        super().__init__()
        self.fc1 = torch.nn.Linear(zsize, 256, bias=False)
        self.fc2 = torch.nn.Linear(256, 512, bias=False)
        self.fc3 = torch.nn.Linear(512, in_features, bias=False)
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))  # -1 to 1
        return x

全結合型の生成器

 最終的な出力が1×28×28=784次元のデータとなるようにLinearクラスと活性化関数を使用しているだけです。こちらも試しに使ってみましょう。

netG = Generator(zsize, in_features)

z = torch.randn(batch_size, zsize)

output = netG(z)

imshow(output[0:8].reshape(-1, 1, 28, 28))

生成器にランダムなデータを入力して、得られた画像を表示

 実行結果は次のようになりました。

実行結果
実行結果

 こちらも学習は何もしていないので、出力される画像はノッペリとしたグレー画像となっています。これを先ほど作成したDiscriminatorクラスのインスタンスに入力してみると、どうなるでしょう。

G_out = netD(output)

print(G_out[0:5])

得られた画像を識別器に入力して得られた結果を表示

 実行結果を以下に示します。こちらも白黒付かない灰色判定となっています。スタート地点としてはまあよいのではないでしょうか。

実行結果
実行結果

Copyright© Digital Advantage Corp. All Rights Reserved.

       | 次のページへ
[an error occurred while processing this directive]
ページトップに戻る