PyTorchが提供するConv2dクラスとConvTranspose2dクラスを利用して、畳み込みオートエンコーダーを作成してみましょう。
この記事は会員限定です。会員登録(無料)すると全てご覧いただけます。
前回は、CIFAR-10と呼ばれる画像データセットをエンコード/デコードするオートエンコーダーを作成しました。このオートエンコーダーの内部では全結合型のニューラルネットワークを用いていましたが、画像をうまく復元しようとすると、かなりの時間がかかることが難点でした。また、復元後の画像もそれほどキレイなものではありませんでした。
そこで、今回は「畳み込みオートエンコーダー」と呼ばれるオートエンコーダーを作成して、全結合型のオートエンコーダーよりも高い精度で画像を復元できるようにすることを目的とします(学習にかかる時間が短ければさらに好ましいといえるでしょう)。
畳み込みオートエンコーダーとは、本連載の「CNNなんて怖くない! その基本を見てみよう」で取り上げたCNN(Convolutional Neural Network、畳み込みニューラルネットワーク)を使ったオートエンコーダーです。
ここでは、CIFAR-10は2次元の画像データ(かつRGB形式で3つの層を持ちます)なので、PyTorchのConv2dクラスをエンコード部に使用することにしましょう。畳み込みとプーリングという処理は、元画像の特徴をピックアップすると同時に、そのデータ量を削減する処理でした。ということは、これはエンコード(による次元削減)にも使えるということです。
その逆を行うのに、今回はConvTranspose2dクラスを使用します。このクラスを使うと、次元削減された画像を元の次元に復元することが可能です。が、このクラスのオブジェクトがどのように振る舞うのかの詳細については、ここでは省略します。
簡単にまとめると、今回は以下のような構成のオートエンコーダーを作成するということになります。
なお、各種モジュールのインポートや、データセットとデータローダーなどについては、これまでと同様なので、最後にまとめて示します(全てのコードは今回のノートブックを参照してください)。以下では、CNNを使った畳み込みの振る舞いを手作業で確認していきましょう。オートエンコーダーを実装するAutoEncoder2クラスは前回と同様に、エンコーダーとデコーダーをプラグインできるようにしてあるので、手作業で振る舞いを確認した後に、対応するエンコーダーとデコーダーのコードを示すことにします。
畳み込みやプーリングのおさらいも兼ねて、Conv2dクラスやMaxPool2dクラスによりデータがどのように削減されていくのかを見ておきましょう。
conv1 = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
conv2 = torch.nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1)
pool = torch.nn.MaxPool2d(2)
ここではRGB形式(3層)の画像データを入力するので、最初の畳み込み層となるConv2dクラスのインスタンスでは入力チャネル数に3を指定しています。out_channelsパラメーターは出力が何チャネルかを指定するものですが、これは同時に画像の走査に使用するカーネルの数でもありました。つまり、ここでは32×32ピクセルで3層ある画像の走査に16個のカーネルを使用するという意味です。そして、そのカーネルサイズは3×3となり、画像の外周に値0のパディングを行います。
もう一つのConv2dクラスのインスタンスは、1つ目の畳み込み層であるconv1の出力が活性化関数やプーリングによる処理を終えた後に渡されます。よって、入力のチャネル数はconv1の出力チャネル数と同じです。その他については、適当な値を指定しています。
MaxPool2dクラスではカーネルサイズに2×2を指定しています。
というわけで、データローダー(ここでは学習用のtrainloaderオブジェクトを使いましょう)からデータを取ってきて、これらconv1から順に渡していきながら、出力がどうなるかを確認しておきましょう。
iterator = iter(trainloader)
x, _ = next(iterator)
print('init:', x.shape)
x = conv1(x)
print('after conv1:', x.shape)
x = torch.relu(x)
x = pool(x)
print('after 1st pool:', x.shape)
x = conv2(x)
print('after conv2:', x.shape)
x = torch.relu(x)
x = pool(x)
print('after 2nd pool:', x.shape)
実行結果を以下に示します。
実行結果から分かるように、32×32×3=3072次元あったデータが8×8×8=512次元に削減されていることが分かります(変数xの形状を表示した先頭の「50」はデータローダーから一度に取ってくる画像データの数です)。
今見た処理を、Sequentialクラスを使って、表現したものが以下です。
enc = torch.nn.Sequential(
torch.nn.Conv2d(3, 16, 3, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2),
torch.nn.Conv2d(16, 8, 3, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
次にデコードがどのように行われるかを見てみます。
convt1 = torch.nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2)
convt2 = torch.nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2)
ここではConvTranspose2dクラスのインスタンスを2つ定義しています。最初のものは、エンコードされた画像データ(8チャンネル、8×8のデータ)を受け取るので入力チャネル数は8になっています。エンコード時に「3→16→8」とチャネル数を指定していたので、ここでは「8→16→3」となるように出力チャネル数は16としました(2つ目のインスタンスでは出力チャネル数に「3」を指定)。kernel_sizeパラメーターは、Conv2dクラスと同様に2次元データを走査するのに使用するもので、カーネルと2次元データとの演算結果が復元後の画像データとなります。その値が入力値と近い値になるように、これらのオブジェクトの重みとバイアスが学習によって調整されるわけです。
とはいえ、Conv2dクラスと同じように走査/演算を行っても、得られるデータのサイズがいきなり16×16になることはなさそうですね。例えば、上記コード例の最初の行からstrideパラメーターの指定を抜いて、エンコーダーを通ってきたデータを渡すとその形状は「50×16×9×9」となってしまいます。
convt1 = torch.nn.ConvTranspose2d(8, 16, 2)
x = convt1(x)
print(x.shape) # torch.Size([50, 16, 9, 9])
8×8の2次元データに2×2のカーネルを適用して演算を行うと、結果は7×7のデータとなってしまいそうですが、そうなっていないのは恐らく、走査したときにパディングに相当する処理が自動的に行われているからでしょう(ConvTranspose2dクラスの処理で出力サイズがどうなるかについては後で少し話をします)。
これが16×16のデータとなるように、上ではConvTranspose2dクラスのインスタンス生成時にstrideパラメーターを指定しています。strideパラメーターが実際の処理にどんな影響を及ぼすかについては「Convolution arithmetic tutorial」などを参照してください(英語)。
ここで重要なのは、convt1に8×8(×8チャネル)というサイズのデータを入力すると、それが16×16(×16チャネル)というサイズのデータが作成される必要があるということです。
ところで、PyTorchのConvTranspose2dクラスのドキュメントを見てみると、次のような記述があります(以下では、式の全てが表示しきれないため、2つの式のスクロールバーの位置を変えて、おおまかな内容が表示されるようにしています)。
この「Shape:」というところには出力サイズについての情報が書かれています。長い式ですが、現在のコードに関係あるものだけをまとめると次のようになります(ここでは、出力は8×8、カーネルサイズは2×2のように全てが平方となっているので、1つの式にまとめてしまいましたが、そうではない場合には、縦/横というか、行数と列数がどうなるかをきちんと計算する必要があることには注意しましょう)。
上の式で出力サイズは16(×16)です。入力サイズは8(×8)です。paddingパラメーターは指定していませんが、そのデフォルト引数値は0になっています。dilationパラメーターも同様に指定していませんが、こちらのデフォルト引数値は1となっています。kernel_sizeは2(×2)でした。これらを上の式に当てはめると次のようになります。
ここからstrideを求めると、上で指定している「2」が得られます。上でstrideパラメーターに2を指定しているのはそういうわけです。もう1つのConvTranspose2dクラスでも同様に計算できるはずです。
では、先ほどエンコーダーを通ったデータを上の2つのインスタンス(と活性化関数)に通して、そのサイズがどうなるかを見てみましょう。
x = convt1(x)
print('after convt1:', x.shape)
x = torch.relu(x)
x = convt2(x)
print('after convt2:', x.shape)
x = torch.sigmoid(x)
実行結果は次の通りです。
見ての通り、最終的に32×32×3チャネルのデータが復元されるようになりました。そして、これと同じことを行うデコーダーをSequentialクラスを使って定義すると次のようになります。
dec = torch.nn.Sequential(
torch.nn.ConvTranspose2d(8, 16, kernel_size=2, stride=2),
torch.nn.ReLU(),
torch.nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),
torch.nn.Sigmoid()
)
後はこれを使って、Autoencoder2クラスのインスタンスを生成して、学習を行ってみるだけです。
Copyright© Digital Advantage Corp. All Rights Reserved.