ファッションアイテムを識別するタスクであるFashion MNISTというデータセットが登場しました。
https://github.com/zalandoresearch/fashion-mnist
(画像は上記githubページより)
このデータセットが登場した目的は、MNISTが簡単すぎる、MNISTは使われすぎ、MNISTは最近のコンピュータビジョンのタスクを表現していない、などの理由からだそうです。
まずは、データセットをダウンロードします。
git clone https://github.com/zalandoresearch/fashion-mnist.git
データの形式などはMNISTと同じで、分類するクラスも10個(Tシャツ、サンダル、バッグなど)です。
このデータをKerasを使って分類してみようと思います。バックエンドはTensorflowを使っています。
ネットワークの構造はLeNetを構築しています。
from keras import backend as K from utils.mnist_reader import load_mnist from keras.layers.convolutional import Conv2D, MaxPooling2D from keras.layers.core import Activation, Flatten, Dense from keras.models import Sequential from keras.utils import np_utils from keras.initializers import Constant from keras.optimizers import Adam import matplotlib.pyplot as plt #load_mnistはutilsにある X_train, y_train = load_mnist('data/fashion', kind='train') X_test, y_test = load_mnist('data/fashion', kind='t10k') X_train = X_train.astype('float32') / 255 X_test = X_test.astype('float32') / 255 X_test = X_test.reshape(X_test.shape[0], 1, 28, 28) X_train = X_train.reshape(X_train.shape[0], 1, 28, 28) y_test = np_utils.to_categorical(y_test, 10) y_train = np_utils.to_categorical(y_train, 10) K.set_image_dim_ordering("th") #LeNetを構築する model = Sequential() model.add(Conv2D(20, kernel_size=5, padding="same", input_shape=(1,28,28))) model.add(Activation("relu")) model.add(MaxPooling2D()) model.add(Conv2D(50, kernel_size=5, border_mode="same")) model.add(Activation("relu")) model.add(MaxPooling2D()) model.add(Flatten()) model.add(Dense(500)) model.add(Activation("relu")) model.add(Dense(10)) model.add(Activation("softmax")) model.compile(loss="categorical_crossentropy", optimizer=Adam(), metrics=["accuracy"]) history = model.fit(X_train, y_train, batch_size=128, epochs=20, verbose=1, validation_split=0.2) score = model.evaluate(X_test, y_test, verbose=1) print("Test score:", score[0]) print("Test accuracy:", score[1]) print(history.history.keys()) #グラフの表示 plt.plot(history.history['acc']) plt.plot(history.history['val_acc']) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') plt.legend(['train', 'test'], loc='upper left') plt.show() plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.title('model loss') plt.ylabel('loss') plt.xlabel('epoch') plt.legend(['train', 'test'], loc='upper left') plt.show()
テスト精度は92.16%でした。MNISTだと99%程度の精度が出るネットワーク構造なので、MNISTより難しくなっているというのは本当のようです。
48000/48000 [==============================] - 259s - loss: 0.4973 - acc: 0.8208 - val_loss: 0.3587 - val_acc: 0.8742 Epoch 2/20 48000/48000 [==============================] - 260s - loss: 0.3178 - acc: 0.8861 - val_loss: 0.3043 - val_acc: 0.8893 Epoch 3/20 48000/48000 [==============================] - 258s - loss: 0.2735 - acc: 0.9013 - val_loss: 0.2783 - val_acc: 0.9025 Epoch 4/20 48000/48000 [==============================] - 256s - loss: 0.2397 - acc: 0.9124 - val_loss: 0.2502 - val_acc: 0.9110 Epoch 5/20 48000/48000 [==============================] - 256s - loss: 0.2122 - acc: 0.9229 - val_loss: 0.2716 - val_acc: 0.9058 Epoch 6/20 48000/48000 [==============================] - 266s - loss: 0.1897 - acc: 0.9308 - val_loss: 0.2683 - val_acc: 0.9053 Epoch 7/20 48000/48000 [==============================] - 259s - loss: 0.1679 - acc: 0.9381 - val_loss: 0.2570 - val_acc: 0.9118 Epoch 8/20 48000/48000 [==============================] - 260s - loss: 0.1489 - acc: 0.9460 - val_loss: 0.2557 - val_acc: 0.9114 Epoch 9/20 48000/48000 [==============================] - 260s - loss: 0.1277 - acc: 0.9524 - val_loss: 0.2430 - val_acc: 0.9195 Epoch 10/20 48000/48000 [==============================] - 260s - loss: 0.1156 - acc: 0.9568 - val_loss: 0.2435 - val_acc: 0.9198 Epoch 11/20 48000/48000 [==============================] - 260s - loss: 0.0965 - acc: 0.9639 - val_loss: 0.2452 - val_acc: 0.9183 Epoch 12/20 48000/48000 [==============================] - 259s - loss: 0.0824 - acc: 0.9696 - val_loss: 0.2705 - val_acc: 0.9159 Epoch 13/20 48000/48000 [==============================] - 261s - loss: 0.0689 - acc: 0.9752 - val_loss: 0.2851 - val_acc: 0.9148 Epoch 14/20 48000/48000 [==============================] - 258s - loss: 0.0588 - acc: 0.9790 - val_loss: 0.3054 - val_acc: 0.9178 Epoch 15/20 48000/48000 [==============================] - 275s - loss: 0.0506 - acc: 0.9823 - val_loss: 0.3397 - val_acc: 0.9215 Epoch 16/20 48000/48000 [==============================] - 313s - loss: 0.0423 - acc: 0.9858 - val_loss: 0.3490 - val_acc: 0.9161 Epoch 17/20 48000/48000 [==============================] - 296s - loss: 0.0376 - acc: 0.9866 - val_loss: 0.3412 - val_acc: 0.9215 Epoch 18/20 48000/48000 [==============================] - 289s - loss: 0.0299 - acc: 0.9894 - val_loss: 0.3668 - val_acc: 0.9173 Epoch 19/20 48000/48000 [==============================] - 291s - loss: 0.0313 - acc: 0.9887 - val_loss: 0.3972 - val_acc: 0.9141 Epoch 20/20 48000/48000 [==============================] - 268s - loss: 0.0251 - acc: 0.9911 - val_loss: 0.3806 - val_acc: 0.9194 9984/10000 [============================>.] - ETA: 0sTest score: 0.362425664179 Test accuracy: 0.9216
出力されたグラフは次の通り。
以上です。MNISTの精度が99%を超えて飽和しつつあるので、今後はこのテストデータが広く使われることになるかもしれません。