検索
連載

CNNなんて怖くない! コードでその動作を確認しよう作って試そう! ディープラーニング工作室(1/2 ページ)

CNNによる画像認識ではどんなふうに処理が進むのかを、実際に手を動かしながら確認していきましょう。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
「作って試そう! ディープラーニング工作室」のインデックス

連載目次

 前回はCNNで画像を認識する際の基本的な仕組みを理論的な面から説明しました。今回は、PyTorchを使って、実際にコードを動かしながら、その動作を確認していきましょう。

横線と縦線のどちらであるかを推測するCNN

 前回は×と○を例として、カーネル(フィルター、ウィンドウ)を使い、画像の特徴がどこに現れているかを特徴マップに畳み込み、それをプーリングによって強調するという話をしました。

 今回はさらにシンプルに横棒と縦棒を表すデータを例として、畳み込み層(+プーリング層)と全結合層で行われる処理(とは、PyTorchのニューラルネットワーククラスのforwardメソッドで行われる処理です)について実際に見ていきます。

 前回に紹介したMNISTの手書き数字を認識するニューラルネットワークは次のようなものでした。

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 16, 64)
        self.fc2 = torch.nn.Linear(64, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.reshape(-1, 16 * 16)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        return x

MNISTの手書き数字を認識するNetクラス

 今回のコードは、細かなところに関しては上記のコードと異なりますが、基本的には上と同じコードを手で動かしながら、その動作を確認します。「畳み込み→活性化関数→プーリング→畳み込み→活性化関数→プーリング→全結合」という流れを取りあえずは把握しておきましょう。

 ここでは以下のようなデータを用意しました。

今回使用するデータ
今回使用するデータ

 特徴を抜き出す対象の画像(に相当するもの)は4×4のサイズのものが4種類、カーネルは3×3のサイズのものが2種類とします(この後、2つ目の畳み込み層も登場するので、それ用のカーネルもありますが、それについては後述します)。

 横棒を表すデータが2つ、縦棒を表すデータが1つ、どちらともいえないデータが1つあります。カーネルは横線を表すものと、縦線を表すものになっているのはすぐに分かるでしょう。2つのカーネルを使って4つのデータを調査すると、実際にどのような特徴がピックアップされるのかを試してみることにしましょう。なお、今回のコードはこのリンクで公開しています。

データと各層の準備

 まず畳み込みとプーリングを行うためのConv2dクラスとMaxPool2dクラスのインスタンスを用意します。

import torch

conv1 = torch.nn.Conv2d(1, 2, 3, padding=True, bias=False)
pool = torch.nn.MaxPool2d(2, padding=1)

print(conv1.weight.shape)

畳み込み(1回目)とプーリングを行うオブジェクトの用意

 1回目の畳み込みを行うオブジェクトであるconv1の生成では、入力チャネル数に1、出力チャネル数(カーネル数)に2、カーネルのサイズに3を指定しています。入力チャネルの数が1なのは、前回と同様、これがRGB値のように複数のチャネルを持つものではないからです。また、データが小さいのでここではパディングを付加するようにしています。話を簡単にするためにバイアスもここでは使わないことにしました。

 プーリングを行うオブジェクトでは、カーネルのサイズは2×2で、畳み込みと同様にパディングをすることにしました。

 なお、カーネルは実際には畳み込みを行うオブジェクトの重みとなるので、上のコードではその形状(各次元の要素数)を確認しています。これを実行した結果が以下です。

実行結果
実行結果

 この結果が意味するのは、「チャネル数が1で、サイズが3×3のカーネル(重み)が2つある」ということです。カーネルを表すデータはこれに合わせて作成します。

kernels1 = torch.tensor([
    [[[-1., -1., -1.],  # 横線
      [ 1.,  1.,  1.],
      [-1., -1., -1.]]], 

    [[[-1.1., -1.],  # 縦線
      [-1.1., -1.],
      [-1.1., -1.]]]])

print(kernels1.shape)

カーネルを表すデータ

 1つ目のカーネルは第1行(中央の行)の要素が全て「1」です。これで横線という特徴を表しています。対して、2つ目のカーネルでは第1列(中央の列)がそうなっています(縦線)。最後にその形状を確認しています。実際の実行結果は次の通りです。

実行結果
実行結果

 conv1の重みと同じ形状になっているのを確認したところで、今回はこのカーネルをconv1の重みとしてしまいましょう(学習させるのが目的ではなく、カーネルを使って4つのデータの畳み込みとプーリングを試すのが目的なので)。これには次のように、conv1オブジェクトのweight.data属性にカーネルを代入するのが簡単です。

conv1.weight.data = kernels1

カーネルをconv1の重みとする

 これでカーネルの準備ができました。次に画像に相当する4つのデータを用意しましょう。以下のコードでは、-1を背景として、そこに1で横線、縦線、どちらでもない線を表したものを、変数sample_dataに代入しています。

sample_data = torch.tensor(
    [[[-1., -1., -1., -1.],  # 横線
      [ 1.,  1.,  1.,  1.],
      [-1., -1., -1., -1.],
      [-1., -1., -1., -1.]],
     [[-1.1., -1., -1.],  # 縦線
      [-1.1., -1., -1.],
      [-1.1., -1., -1.],
      [-1.1., -1., -1.]],
     [[-1., -1., -1., -1.],  # 横線
      [ 1.,  1., -1., -1.],
      [-1., -1.1.,  1.],
      [-1., -1., -1., -1.]],
     [[-1., -1., -1.1.],  # 左下がりの直線
      [-1., -1.1., -1.],
      [-1.1., -1., -1.],
      [ 1., -1., -1., -1.]]])

print(sample_data.shape)

横線と縦線とどちらでもない画像に相当するデータ

 4つ目のデータは左下がりの直線を表すもので、横線とも縦線ともいえるようないえないようなものになっています。これらをカーネルを使って調査するとどうなるかも後で見てみましょう。

 最後にこのデータの形状も確認しています。実行結果は次の通りです。

実行結果
実行結果

 この通り、4×4のデータ(テンソル)を4つ格納するテンソルとなっていますが、チャネル数の指定がないので、以下のようにして、「チャネル数が1、サイズが4×4のデータ」を4つ格納するテンソルに変換しておきます。

sample_data = sample_data.reshape(4, 1, 4, 4)

サンプルデータの形状を変更

 興味のある方は、このデータを出力して、かっこ「[]」がどこに増えているかを確認しておきましょう。

 以上で元のデータとカーネルの準備ができました。それでは実際に畳み込みとプーリングで、これらがどう処理されるかを確認していきましょう。

畳み込み

 畳み込みを行うには、上で作成したconv1オブジェクトに、変数sample_dataを渡すだけです。その結果は変数f_map1に代入しておきましょう(前回も述べましたが、畳み込みの結果は特徴マップと呼ばれることがあります。そこでここでは変数名を「feature map」を表す「f_map1」としています。「1」はこれが1回目の畳み込みであることを示します)。

f_map1 = conv1(sample_data)
print(f_map1)

畳み込みを行い、その結果を表示

 実行結果を以下に示します。

実行結果
実行結果

 注目してほしいのは、4つのデータを入力したのに対して、出力された特徴マップは8つのデータを含んでいるように見えている点です。これは1つのデータを、2つのカーネルで調べたそれぞれの結果、つまり、4×2=8個の特徴マップが出力されたということです。最初の2つの塊は1つ目のデータ(横線)を2つのカーネルで調査した結果です。1つ目は横方向に「6., 9., 9., 6.」という(他の数値と比べて)大きな数値が出ています。これは横線を、横線を示すカーネルで調査した結果であり、「なるほど」という感じがします。一方、その次の出力では0、1、2という値が存在するだけです。これは縦線の特徴を表すカーネルで調査した結果、そうしたところはあまり見られないことが数値的にも分かる感じがします。この傾向は3つ目のデータの調査結果(特徴マップ)でも同様です。

 その下の2つは縦線を表すデータを調査したもので、これらについては上と逆の結果になっていることが分かります。

 4つ目のデータ(左下がりの直線)から得られた特徴マップからは、これが横線であることを示しているとも、縦線であることを示しているともいえないことが分かります。

 畳み込みを行い特徴マップを得ることで、それぞれの元データの特徴がうまく得られました。これをプーリングするのですが、前回のコードではその前に活性化関数ReLUを適用していました。この関数は0以下の値については0を、正の値については、元の値を返すという関数です。手順に従って、特徴マップにこれを適用してみましょう。

f_map1 = torch.nn.functional.relu(f_map1)
print(f_map1)

活性化関数ReLUを特徴マップに適用

 実行結果を以下に示します。

実行結果
実行結果

 負の値が0に変換され、正の値だけが残りました。実際には、この結果を使ってプーリングを行うことになります。

プーリング

 では、実際にプーリングを行ってみましょう。といっても、これも先ほど作成したpoolオブジェクトに活性化関数を通した結果を渡すだけです。

pooled1 = pool(f_map1)
print(pooled1)

活性化関数を適用した結果に対してプーリングを行う

 poolオブジェクトの作成時には、パディングを指定していたので、ここでは特徴マップ(f_map1)の外側にパディングが付加されて(値は0)、6×6のサイズのデータを2×2のサイズの分割した各区画から最大値がピックアップされて、3×3のサイズのデータが得られます。実行結果は次の通りです。

実行結果
実行結果

 1つ目のデータから3つ目のデータまでは、横線、縦線という特徴がしっかりと出ていることが分かります(4つ目のデータではきちんとどちらともいえない特徴が得られています)。これらが2回目の畳み込みの入力データとなります。

Copyright© Digital Advantage Corp. All Rights Reserved.

       | 次のページへ
[an error occurred while processing this directive]
ページトップに戻る