検索
連載

MNISTの手書き数字を全結合型ニューラルネットワークで処理してみよう作って試そう! ディープラーニング工作室(1/2 ページ)

より高度なニューラルネットワークの作成に移る前に、これまでの知識を使って、MNISTの手書き数字を認識するプログラムを作ってみます。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
「作って試そう! ディープラーニング工作室」のインデックス

連載目次

 前回までは、あやめの品種の推測を題材にニューラルネットワークの基本となる要素について見てきました。今回からは手書き数字の認識を題材にもう少し高度な話題を見ていきましょう。

MNIST

 今回からは0〜9までの手書き数字を集めたMNISTデータベースを使用して、それらの数字を認識するニューラルネットワークモデルを作成します。

MNISTデータベースに含まれている手書き文字(抜粋)
MNISTデータベースに含まれている手書き文字(抜粋)

 MNISTデータベースには、上に示したような手書きの数字(と対応する正解ラベル)が訓練データとして6万個、テストデータとして1万個格納されています。この膨大な数のデータを使用して、手書きの数字を認識してみようというのが目標です。

 今回は、これまでに見てきた全結合型のニューラルネットワークを作成して、これを実際に試してみましょう。今回紹介するコードはここで公開しているので、必要に応じて参照してください。

データセットの準備と探索

 本連載で使用している機械学習フレームワークであるPyTorchには今述べたMNISTを手軽に扱えるようにするためのtorchvisionパッケージが用意されています(「vision」が付いているのは、このパッケージがコンピューターによる視覚の実現=コンピュータービジョンに由来するのでしょう)。このパッケージを使って実際にMNISTデータベースからデータセットを読み込んでみましょう。

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

BATCH_SIZE = 20

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      transform=transform, download=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     transform=transform, download=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

MNISTデータベースからデータセットを読み込むコード

 最初の3行ではtorchパッケージ、torchvisionパッケージ、torchvisionパッケージが提供するtransformsパッケージをインポートしています。最後のtransformsパッケージには、画像(を構成する数値データ)を変換するためのクラスが含まれていて、これを使って、MNISTデータベースに格納されているデータをPyTorchで扱えるように変換作業が行われます。実際、その直後で、tranformsパッケージに含まれるToTensorクラス、Normalizeクラスを組み合わせた変換処理を行うオブジェクトを変数trnasformに代入しています(これについては後で簡単に見ます)。

 定数BATCH_SIZEの値「20」は訓練データ(とテストデータ)から一度に何個のデータを読み込むかを指定する値(バッチサイズ)です。前回まで使用していたあやめのデータセットは150個という極めて少ない量のデータセットでしたが、今回のデータセットには学習用とテスト用に合わせて7万個のデータがあるので、学習/テストを行う際にはそれらを分割して読み込むことにします(その下の変数trainloaderとtestloaderと合わせて、これらについても後述します)。

 変数trainsetとtestsetには、訓練データとテストデータが正解ラベル込みで代入されます。これらから実際にデータを取り出すときに使用するのが変数trainloaderとtestloaderに代入されているDataLoaderクラスのインスタンスです。

 変数trainsetとtestsetには、torchvision.datasetsモジュールが提供するMNISTクラスのインスタンスが代入されています。このインスタンスの生成時には、次のような引数を指定します。

  • root:データセットファイルを置くディレクトリを指定
  • train:訓練データを生成するか、テストデータを生成するかを指定
  • transform:trainset/testsetからデータを取り出す際に、MNISTの生データに対して行う変換処理を指定。ここでは変数transformを指定
  • download:必要に応じてインターネットからMNISTデータセットをダウンロードするかどうかを指定

 また、変数trainloaderとtestloaderに代入される、DataLoaderクラスのインスタンス生成時には、それぞれに対応するデータセットに加えて、次のような引数を指定しています。

  • batch_size:一度に読み込むデータ数(バッチサイズ)を指定
  • shuffle:読み込むデータをシャッフルするかどうかを指定

 変数trainloaderに代入するDataLoaderインスタンスの生成では引数shuffleにTrueを指定しています。これは変数trainsetを使って学習を行う際に、最初にデータをランダムに並べ替えることを意味しています。その学習で6万個のデータを使い切って学習が一区切り付いた後(この区切りのことを「エポック」といいます。つまり、この場合は6万個のデータが1つのエポックとして扱われます)、同じデータセットを使って次のエポックを開始する際には、またデータセット内のデータがランダムに並べ替えられます。これは、同じ並びでデータを取り出すのではなく、6万個のデータから任意の順序でデータをピックアップすることで、学習結果に偏りを生じさせないようにするためです。テストデータについてはshuffleをFalseにしていますが、これは最終的な確認を行うという観点から、シャッフルの必要がないためです。

 次に、読み込んだデータセットから先頭のデータを少し見てみましょう。データベースから読み込んだデータセットは変数trainsetのdata属性にアクセスすることでアクセスできるので、インデックス0を指定すればその先頭要素が得られます。

print(f'image: {len(trainset.data[0])} x {len(trainset.data[0][0])}')
for item1 in trainset.data[0]:
    for item2 in item1:
        print(f'{item2.data:4}', end='')
    print()

先頭にある手書き文字の値を表示するコード

 このコードは上で述べた通り、変数trainsetのdata属性にアクセスして、その先頭にある手書き文字を構成する数値を二重ループで表示するものです(画像は2次元のデータとなっているので、このように二重ループで処理をしています)。

 このコードを実行すると次のような実行結果になります(環境によっては表示が崩れるかもしれません)。

実行結果
実行結果

 最初にある「image: 28 x 28」というのはMNISTデータベースに含まれている手書き文字が28×28ピクセルのサイズになっていることを示すものです。その後には0〜255の範囲の整数値が表示されています。つまり、1つの手書き数字は28×28のサイズで、各ピクセルの値は0〜255の範囲にあるということです。さらに、全体をぼんやり眺めてみると、数字らしいものが浮かび上がっていることも分かります。このデータは何の数字を表しているのでしょうか。何となく想像は付きますが、実際の画像も見てみましょう。

import matplotlib.pyplot as plt

plt.imshow(trainset.data[0], cmap='gray')

Matplotlibを使って、先頭の手書き数字を表示する

 実行結果は次の通りです。

実行結果
実行結果

 これは数字の「5」のように見えますが、本当にそうか、実際に対応する正解ラベルを表示してみましょう。正解ラベルはtrainsetのtargets属性に格納されているので、「print(trainset.targets[0])」としてもよいのですが、ここではちょっと違う方法でデータを取り出してみます。

image, label = trainset[0]
print(label)
#print(trainset.targets[0])

先頭の手書き数字に対応する正解ラベルを表示する

 先ほどとの違いは、変数trainsetに対して直接インデックスを指定しているところと、その戻り値が2つある(2つの要素で構成されるタプル)ということです。torchvision.datasets.MNISTクラスはtorch.utils.data.Datasetクラスを(間接的に)継承したクラスで、インデックス指定を行うことで、そのインデックスに対応するデータと正解ラベルを取得できるようになっています。ここでは先頭の手書き文字の正解ラベルを取得するのに、これを使ってみました。実行結果を以下に示します。

実行結果
実行結果

 予想通りに「5」と表示されました。先ほどの数字は5を描いたものだったということです。ここで、上の手順で変数imageに取り出したデータについても見てみましょう。

image = image.reshape(28, 28)
print()
for item1 in image:
    for item2 in item1:
        print(f' {float(item2.data):+1.2f} ', end='')
    print()

MNISTクラスのインスタンスに対してインデックス指定を行って得た画像データを表示

 最初の行で行っているのは、変数imageに取り出したデータが「28×28の画像データをただ一つの要素とする配列」となっているので、これを「28×28の画像データ」に変換する処理です。その後は、trainset.data[0]の値を表示するのと同様な処理です(ただし、浮動小数点数を表示するようになっている点には注意してください)。

 実行結果を以下に示します(これも環境によっては表示が崩れるかもしれません)。

実行結果
実行結果

 注目してほしいのは、先ほどのMNISTデータベースに格納されていた0〜255の値が、今度は-1.0〜1.0の値に変換されている点です。MNISTクラスのインスタンス生成では次のようなコードを書いていました。

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# …… 省略 ……

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      transform=transform, download=True)
# …… 省略 ……

MNISTクラスのインスタンスを生成するコード

 先ほどのコードでは、trainsetに対してインデックス指定などの手段で値を得るときには、0〜255の値が-1.0〜1.0の値に変換されていました。そして、上記のコードを見ると、transformでは、ToTensorクラスとNormalizeクラスのインスタンスが指定されています。実は前者は0〜255の範囲の数値を0〜1.0の範囲の浮動小数点数値に変換するためのものです。そして、後者はインスタンス生成時に第1引数に指定した値をm、第2引数に指定した値をsとしたときに、おおざっぱにいうと「出力=(入力−m)/s」という計算を行うものです(mとsがタプルになっているのは、RGB値など複数のチャネルで画像が構成されている場合に、チャネルごとにそれらを指定できるようにするため)。

 ここではどちらも0.5なので、0.0〜1.0の範囲の浮動小数点数値が-1.0〜1.0の範囲の数値へと変換されます。例えば、入力(ToTensorクラスで変換された値)が0であれば、Normalizeによる変換の結果は「(0.0−0.5)/0.5=-1.0」となります。入力が0.5なら「(0.5−0.5)/0.5=0.0」に、入力が1なら「(1.0−0.5)/0.5=1.0」となります。

 このような変換を自動的に行うのが、transform引数の役割です。実際に学習を行う段階では、ループ処理の中で繰り返しtrainsetからデータを取り出して、それをニューラルネットワークに入力していきますが、このときに今述べたような変換処理が自動的に行われます。なお、このようなある範囲の値を別の一定範囲の値へと変換することを正規化(normalize)と呼びます。

 ここまでの話をまとめると次のようになります。

  • 手書き数字のサイズは28×28ピクセルで、各ピクセルが持つ値は0〜255
  • 訓練データとしては上記の手書き数字が6万個、テストデータとしては1万個用意されている
  • 学習やテストでデータを取り出すときに自動的に0〜255の範囲の値が-1.0〜1.0の範囲の値に変換される

 では、1枚の手書き数字を構成する28×28(=784)個の値を(複数)受け取り、その数字が何であるかを推測するニューラルネットワーククラスを定義してみましょう。

Copyright© Digital Advantage Corp. All Rights Reserved.

       | 次のページへ
[an error occurred while processing this directive]
ページトップに戻る