CNNなんて怖くない! その基本を見てみよう作って試そう! ディープラーニング工作室(2/2 ページ)

» 2020年05月29日 05時00分 公開
[かわさきしんじDeep Insider編集部]
前のページへ 1|2       

プーリング

 「プーリング」とは、畳み込みによって得た特徴(特徴マップ)から重要な要素は残しながら、データ量を削減する処理です。通常は入力(特徴マップ)を小さなサイズの区画(2×2、3×3など。これもやはりウィンドウとかカーネルと呼びます)に分けて、その区画内で特徴的な値(最大値、平均値など)を取り出して、それをプーリングの出力とします。

 ここでは、上で得た特徴マップに対して、2×2のサイズでプーリングを行ってみましょう。多くの場合は最大値を取り出すので、ここでもそうしてみます。

 特徴マップは次のようなものでした。

特徴マップ(再掲) 特徴マップ(再掲)

 これに対して、左上から右下に向かって、先ほどと同様な順序で値を取り出していくと次のようになります。

プーリング処理の結果 プーリング処理の結果

 元の画像データにパディングをしていない方では、特徴マップは3×3というサイズだったので、左上の2×2の要素の中で最大値である「9」という要素だけが取り出されました。パディングをしている方では2×2のサイズで最大値が取り出されました。いずれにしても交差していることが強く強調されるデータ(9)がうまく取り出されると同時に、データ量が削減されました。

 プーリングにも、畳み込みと同様な特徴があります。それは入力(この場合は特徴マップ)の中で多少のズレがあっても、もともとの特徴を示すデータをうまく拾い上げられる点です。画像データはピクセル単位での処理をするので、元のデータや重み、バイアスなどによって、特徴マップのどこに特徴といえる値が出てくるかはそのときどきで変わるかもしれません。そんなときでも、プーリングを行うことで必要なデータをうまく取り出せるのがプーリングのメリットといえます。

 PyTorchではこの処理を行うクラスとしてMaxPool2dクラスなどが提供されています。

 畳み込みは元データが持つ微細な特徴を見つけ出す処理、プーリングは畳み込みによって見つかった特徴マップの全体像を大まかな形で表現する処理(大きな特徴だけをより際立たせる処理)と考えることもできるでしょう。

 畳み込みとプーリング(とその間に挟み込む活性化関数)という組み合わせを何層かに重ねることで、入力層に近いところでは今述べた微細な特徴を、入力層から遠い層では全体的な(抽象的な)特徴を表現できるようになります。そうして得られたものを全結合により推測を行う層(全結合層)へと渡して、最終的に分類を行うというのがCNNによる画像認識の手順となります。

畳み込み層とプーリング層と全結合層で構成されるニューラルネットワーク 畳み込み層とプーリング層と全結合層で構成されるニューラルネットワーク

PyTorchでCNNを実装する

 ここまではCNNでどんな処理が行われるかを見てきましたが、実際にこれを使ってMNISTの手書き数字を推測するコードを最後に見ておきましょう。といっても、コードはこれまでのものとほとんど変わりません。

 まずMNISTデータベースからデータセットを読み込むコードです。これについて前回と同じコードです。torchvision.datasets.MNISTクラスとtorch.utils.data.DataLoaderクラスを使って、データセットを読み込んで、それを反復するデータローダーを用意しています。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

BATCH_SIZE = 20

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

MNISTデータベースからデータセットを読み込むコード

 これまでに見てきたような畳み込みとプーリングを行うクラスの定義を以下に示します。

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 16, 64)
        self.fc2 = torch.nn.Linear(64, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = torch.nn.functional.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.reshape(-1, 16 * 16)
        x = self.fc1(x)
        x = torch.nn.functional.relu(x)
        x = self.fc2(x)
        return x

CNNを利用してMNISTの手書き数字を認識するNetクラス

 既に述べたように、PyTorchには2次元データの畳み込みを行うためのクラスとしてConv2dクラスが、最大値のプーリングを行うクラスとしてMaxPool2dクラスが用意されています。ここでは、それらのインスタンスを作成しています。Conv2dクラスのインスタンスは2つ作成して、インスタンス変数conv1とconv2に代入しています。forwardメソッドを見ると分かりますが、畳み込みは2回行うということです。

 インスタンス変数conv1のインスタンス生成では、「Conv2d(1, 6, 5)」のように引数を指定しています。第1引数の「1」は「入力チャネルの数」です。MNISTはRGB値のようなカラー画像ではなく、各ピクセルが0〜255(の値を-1〜1の範囲の浮動小数点数に変換したもの)だけのデータなので、ここでは1を指定しています。第2引数は「出力チャネルの数」ですが、これが実質的にはカーネルの数を表します。ここでは6個のカーネルを作成するということです。第3引数は「カーネルのサイズ」です。ここでは「5×5」のサイズのカーネルを作成するということになります。

 インスタンス変数conv2の生成では、「Conv2d(6, 16, 5)」のように引数を指定しています。第2引数と第3引数の指定は上と同様なので説明は不要でしょう。しかし、第1引数の「6」については少し説明が必要です。インスタンス変数conv1では出力チャネルの数を「6」としていました。これが活性化関数とプーリングによる処理を経て、インスタンス変数conv2へと渡されます。そのため、データの入力元となるインスタンス変数conv1の出力チャネルの数と、データを受け取るインスタンス変数conv2の入力チャネルの数を一致させておく必要があります。そのため、ここでは「6」を指定しています(なお、これらの引数の値は筆者が適当に定めたもので、特に理由はありません。もっとよい値があるかもしれません)。

 MaxPool2dクラスのインスタンスは1つだけ作成して、それをインスタンス変数poolに代入しています。2回の畳み込みの(結果を活性化関数で処理した)結果は、このインスタンスで処理してプーリングを行っています。引数は「MaxPool2d(2, 2)」となっているので、2×2のサイズでプーリングを行うことを意味しています。

 最後に、それらを、全結合を行うインスタンス変数fc1とfc2で処理するだけです(インスタンス変数fc1のノード数は、16×16=256個となっているのは、畳み込みとプーリングによって得た、全結合層への入力の数を筆者が確認して、その数を指定しました)。

 最後に学習とテストを行うコードを以下に示します。

import torch.optim as optim

net = Net()

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

EPOCHS = 2

for epoch in range(1, EPOCHS + 1):
    running_loss = 0.0
    for count, item in enumerate(trainloader, 1):
        inputs, labels = item

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if count % 500 == 0:
            print(f'#{epoch}, data: {count * 20}, running_loss: {running_loss / 500:1.3f}')
            running_loss = 0.0

print('Finished')

correct = 0
total = 0

with torch.no_grad():
    for data in testloader:
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        total += len(outputs)
        correct += (predicted == labels).sum().item()

print(f'correct: {correct}, accuracy: {correct} / {total} = {correct / total}')

学習を行うコード

 これも基本的には、前回と同じコードなので説明は省略します。

 これらのコードを実行すると、結果は次のようになります。

実行結果 実行結果

 前回の全結合型のニューラルネットワークでは92%程度の精度でしたが、今回はそれよりも高い精度で認識できていることが分かります。


 今回はCNNによる画像認識の基礎知識とそれを実際に行うコードを見ました。コードについては少し駆け足になってしまいましたが、次回は手を動かしながら、実際のコードでどんなことが行われているかを見ていくことにします。

「作って試そう! ディープラーニング工作室」のインデックス

作って試そう! ディープラーニング工作室

前のページへ 1|2       

Copyright© Digital Advantage Corp. All Rights Reserved.

スポンサーからのお知らせPR

注目のテーマ

Microsoft & Windows最前線2025
AI for エンジニアリング
ローコード/ノーコード セントラル by @IT - ITエンジニアがビジネスの中心で活躍する組織へ
Cloud Native Central by @IT - スケーラブルな能力を組織に
システム開発ノウハウ 【発注ナビ】PR
あなたにおすすめの記事PR

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。