X

Fashion MNISTをKerasでCNNを使って分類してみた

ファッションアイテムを識別するタスクであるFashion MNISTというデータセットが登場しました。
https://github.com/zalandoresearch/fashion-mnist


(画像は上記githubページより)

このデータセットが登場した目的は、MNISTが簡単すぎる、MNISTは使われすぎ、MNISTは最近のコンピュータビジョンのタスクを表現していない、などの理由からだそうです。

まずは、データセットをダウンロードします。

git clone https://github.com/zalandoresearch/fashion-mnist.git

データの形式などはMNISTと同じで、分類するクラスも10個(Tシャツ、サンダル、バッグなど)です。

このデータをKerasを使って分類してみようと思います。バックエンドはTensorflowを使っています。
ネットワークの構造はLeNetを構築しています。

from keras import backend as K
from utils.mnist_reader import load_mnist
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense
from keras.models import Sequential
from keras.utils import np_utils
from keras.initializers import Constant
from keras.optimizers import Adam
import matplotlib.pyplot as plt

#load_mnistはutilsにある
X_train, y_train = load_mnist('data/fashion', kind='train')
X_test, y_test = load_mnist('data/fashion', kind='t10k')

X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)

y_test = np_utils.to_categorical(y_test, 10)
y_train = np_utils.to_categorical(y_train, 10)

K.set_image_dim_ordering("th")
#LeNetを構築する
model = Sequential()
model.add(Conv2D(20, kernel_size=5, padding="same", input_shape=(1,28,28)))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Conv2D(50, kernel_size=5, border_mode="same"))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Flatten())
model.add(Dense(500))
model.add(Activation("relu"))

model.add(Dense(10))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer=Adam(), metrics=["accuracy"])
history = model.fit(X_train, y_train, batch_size=128, epochs=20, verbose=1, validation_split=0.2)

score = model.evaluate(X_test, y_test, verbose=1)
print("Test score:", score[0])
print("Test accuracy:", score[1])
print(history.history.keys())

#グラフの表示
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

テスト精度は92.16%でした。MNISTだと99%程度の精度が出るネットワーク構造なので、MNISTより難しくなっているというのは本当のようです。

48000/48000 [==============================] - 259s - loss: 0.4973 - acc: 0.8208 - val_loss: 0.3587 - val_acc: 0.8742
Epoch 2/20
48000/48000 [==============================] - 260s - loss: 0.3178 - acc: 0.8861 - val_loss: 0.3043 - val_acc: 0.8893
Epoch 3/20
48000/48000 [==============================] - 258s - loss: 0.2735 - acc: 0.9013 - val_loss: 0.2783 - val_acc: 0.9025
Epoch 4/20
48000/48000 [==============================] - 256s - loss: 0.2397 - acc: 0.9124 - val_loss: 0.2502 - val_acc: 0.9110
Epoch 5/20
48000/48000 [==============================] - 256s - loss: 0.2122 - acc: 0.9229 - val_loss: 0.2716 - val_acc: 0.9058
Epoch 6/20
48000/48000 [==============================] - 266s - loss: 0.1897 - acc: 0.9308 - val_loss: 0.2683 - val_acc: 0.9053
Epoch 7/20
48000/48000 [==============================] - 259s - loss: 0.1679 - acc: 0.9381 - val_loss: 0.2570 - val_acc: 0.9118
Epoch 8/20
48000/48000 [==============================] - 260s - loss: 0.1489 - acc: 0.9460 - val_loss: 0.2557 - val_acc: 0.9114
Epoch 9/20
48000/48000 [==============================] - 260s - loss: 0.1277 - acc: 0.9524 - val_loss: 0.2430 - val_acc: 0.9195
Epoch 10/20
48000/48000 [==============================] - 260s - loss: 0.1156 - acc: 0.9568 - val_loss: 0.2435 - val_acc: 0.9198
Epoch 11/20
48000/48000 [==============================] - 260s - loss: 0.0965 - acc: 0.9639 - val_loss: 0.2452 - val_acc: 0.9183
Epoch 12/20
48000/48000 [==============================] - 259s - loss: 0.0824 - acc: 0.9696 - val_loss: 0.2705 - val_acc: 0.9159
Epoch 13/20
48000/48000 [==============================] - 261s - loss: 0.0689 - acc: 0.9752 - val_loss: 0.2851 - val_acc: 0.9148
Epoch 14/20
48000/48000 [==============================] - 258s - loss: 0.0588 - acc: 0.9790 - val_loss: 0.3054 - val_acc: 0.9178
Epoch 15/20
48000/48000 [==============================] - 275s - loss: 0.0506 - acc: 0.9823 - val_loss: 0.3397 - val_acc: 0.9215
Epoch 16/20
48000/48000 [==============================] - 313s - loss: 0.0423 - acc: 0.9858 - val_loss: 0.3490 - val_acc: 0.9161
Epoch 17/20
48000/48000 [==============================] - 296s - loss: 0.0376 - acc: 0.9866 - val_loss: 0.3412 - val_acc: 0.9215
Epoch 18/20
48000/48000 [==============================] - 289s - loss: 0.0299 - acc: 0.9894 - val_loss: 0.3668 - val_acc: 0.9173
Epoch 19/20
48000/48000 [==============================] - 291s - loss: 0.0313 - acc: 0.9887 - val_loss: 0.3972 - val_acc: 0.9141
Epoch 20/20
48000/48000 [==============================] - 268s - loss: 0.0251 - acc: 0.9911 - val_loss: 0.3806 - val_acc: 0.9194
 9984/10000 [============================>.] - ETA: 0sTest score: 0.362425664179
Test accuracy: 0.9216

出力されたグラフは次の通り。

以上です。MNISTの精度が99%を超えて飽和しつつあるので、今後はこのテストデータが広く使われることになるかもしれません。

Hiro: