連載
» 2020年06月08日 05時00分 公開

第8回 分類問題をディープラーニング(基本のDNN)で解こうTensorFlow 2+Keras(tf.keras)入門(3/3 ページ)

[一色政彦,デジタルアドバンテージ]
前のページへ 1|2|3       

―――【二値分類編】―――

 多くのコードは前ページと重複するので、説明は極力なしで、書き換えたコードのみを太字で示していく。コード自体は全体を掲載するので長いが、太字以外は読み飛ばしていただいて構わない。

(6)データの準備

 二値分類問題では「MNIST」データセットを用いる(図14)。

図14 手書き数字の画像データセット「MNIST」 図14 手書き数字の画像データセット「MNIST」

 データの仕様は同じであるが、分類カテゴリーが次のように変わる。

  • ラベル「0」: 手書き数字「0」
  • ラベル「1」: 手書き数字「1」
  • ラベル「2」: 手書き数字「2」
  • ラベル「3」: 手書き数字「3」
  • ラベル「4」: 手書き数字「4」
  • ラベル「5」: 手書き数字「5」
  • ラベル「6」: 手書き数字「6」
  • ラベル「7」: 手書き数字「7」
  • ラベル「8」: 手書き数字「8」
  • ラベル「9」: 手書き数字「9」

 先ほどとほぼ同じコードでデータを導入できる(リスト6-1)。二値分類なので、2個の分類カテゴリーしか要らない。よって、ラベルが「0」「1」以外はカットするフィルタリング処理を追記している。

# TensorFlowライブラリのtensorflowパッケージを「tf」という別名でインポート
import tensorflow as tf
import matplotlib.pyplot as plt  # グラフ描画ライブラリ(データ画像の表示に使用)
import numpy as np               # 数値計算ライブラリ(データのシャッフルに使用)

# Fashion-MNISTデータ(NumPyの多次元配列型)を取得する
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
# データ分割は自動で、訓練用が6万枚、テスト用が1万枚(ホールドアウト法)。
# さらにそれぞれを「入力データ(X:行列)」と「ラベル(y:ベクトル)」に分ける

# データのフィルタリング
b = np.where(y_train < 2)[0]  # 訓練データから「0」「1」の全インデックスの取得
X_train, y_train = X_train[b], y_train[b]  # そのインデックス行を抽出(=フィルタリング)
c = np.where(y_test < 2)[0]   # テストデータから「0」「1」の全インデックスの取得
X_test, y_test = X_test[c], y_test[c]      # そのインデックス行を抽出(=フィルタリング)

# 訓練データは、学習時のfit関数で訓練用と精度検証用に分割する。
# そのため、あらかじめ訓練データをシャッフルしておく
p = np.random.permutation(len(X_train))    # ランダムなインデックス順の取得
X_train, y_train = X_train[p], y_train[p]  # その順で全行を抽出する(=シャッフル)

# [内容確認]データのうち、最初の10枚だけを表示
classes_name = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
plt.figure(figsize=(10,4))  # 横:10インチ、縦:4インチの図
for i in range(10):
  plt.subplot(2,5,i+1# 図内にある(sub)2行5列の描画領域(plot)の何番目かを指定
  plt.xticks([])        # X軸の目盛りを表示しない
  plt.yticks([])        # y軸の目盛りを表示しない
  plt.grid(False)       # グリッド線を表示しない
  plt.imshow(           # 画像を表示する
    X_train[i],         # 1つの訓練用入力データ(28行×28列)
    cmap=plt.cm.binary) # 白黒(2値:バイナリ)の配色
  plt.xlabel(classes_name[y_train[i]])  # X軸のラベルに分類名を表示
plt.show()

リスト6-1 MNIST(手書き文字)画像データの取得

 先ほどと同様に、訓練データの1つ目の入力データとラベルを、出力して確かめてみる(リスト6-2)。

前のページへ 1|2|3       

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。