検索
連載

第8回 分類問題をディープラーニング(基本のDNN)で解こうTensorFlow 2+Keras(tf.keras)入門(3/3 ページ)

回帰問題の次は、分類問題の基礎をマスターしよう。二値分類/多クラス分類の場合で一般的に使われる活性化関数や損失関数をしっかりと押さえる。また過学習問題の対処方法について言及する。

PC用表示 関連情報
Share
Tweet
LINE
Hatena
前のページへ |       

―――【二値分類編】―――

 多くのコードは前ページと重複するので、説明は極力なしで、書き換えたコードのみを太字で示していく。コード自体は全体を掲載するので長いが、太字以外は読み飛ばしていただいて構わない。

(6)データの準備

 二値分類問題では「MNIST」データセットを用いる(図14)。

図14 手書き数字の画像データセット「MNIST」
図14 手書き数字の画像データセット「MNIST」

 データの仕様は同じであるが、分類カテゴリーが次のように変わる。

  • ラベル「0」: 手書き数字「0」
  • ラベル「1」: 手書き数字「1」
  • ラベル「2」: 手書き数字「2」
  • ラベル「3」: 手書き数字「3」
  • ラベル「4」: 手書き数字「4」
  • ラベル「5」: 手書き数字「5」
  • ラベル「6」: 手書き数字「6」
  • ラベル「7」: 手書き数字「7」
  • ラベル「8」: 手書き数字「8」
  • ラベル「9」: 手書き数字「9」

 先ほどとほぼ同じコードでデータを導入できる(リスト6-1)。二値分類なので、2個の分類カテゴリーしか要らない。よって、ラベルが「0」「1」以外はカットするフィルタリング処理を追記している。

# TensorFlowライブラリのtensorflowパッケージを「tf」という別名でインポート
import tensorflow as tf
import matplotlib.pyplot as plt  # グラフ描画ライブラリ(データ画像の表示に使用)
import numpy as np               # 数値計算ライブラリ(データのシャッフルに使用)

# Fashion-MNISTデータ(NumPyの多次元配列型)を取得する
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
# データ分割は自動で、訓練用が6万枚、テスト用が1万枚(ホールドアウト法)。
# さらにそれぞれを「入力データ(X:行列)」と「ラベル(y:ベクトル)」に分ける

# データのフィルタリング
b = np.where(y_train < 2)[0]  # 訓練データから「0」「1」の全インデックスの取得
X_train, y_train = X_train[b], y_train[b]  # そのインデックス行を抽出(=フィルタリング)
c = np.where(y_test < 2)[0]   # テストデータから「0」「1」の全インデックスの取得
X_test, y_test = X_test[c], y_test[c]      # そのインデックス行を抽出(=フィルタリング)

# 訓練データは、学習時のfit関数で訓練用と精度検証用に分割する。
# そのため、あらかじめ訓練データをシャッフルしておく
p = np.random.permutation(len(X_train))    # ランダムなインデックス順の取得
X_train, y_train = X_train[p], y_train[p]  # その順で全行を抽出する(=シャッフル)

# [内容確認]データのうち、最初の10枚だけを表示
classes_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
plt.figure(figsize=(10,4))  # 横:10インチ、縦:4インチの図
for i in range(10):
  plt.subplot(2,5,i+1# 図内にある(sub)2行5列の描画領域(plot)の何番目かを指定
  plt.xticks([])        # X軸の目盛りを表示しない
  plt.yticks([])        # y軸の目盛りを表示しない
  plt.grid(False)       # グリッド線を表示しない
  plt.imshow(           # 画像を表示する
    X_train[i],         # 1つの訓練用入力データ(28行×28列)
    cmap=plt.cm.binary) # 白黒(2値:バイナリ)の配色
  plt.xlabel(classes_name[y_train[i]])  # X軸のラベルに分類名を表示
plt.show()

リスト6-1 MNIST(手書き文字)画像データの取得

 先ほどと同様に、訓練データの1つ目の入力データとラベルを、出力して確かめてみる(リスト6-2)。

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る