Kerasによるニューラルネットワーク本「Deep Learning with Keras」を読んだ

Deep Learningの基本的な仕組みなどについては大体把握してきたと思うので、実際に動くコードを書くにはどうすればよいのかということを学ぶために、Kerasによるニューラルネットワーク本、「Deep Learning with Keras」を読みました。

Kerasはご存知の通り、TensorflowやTheanoなどのDeep Learning基盤を使いやすくするためのフレームワークです。実際にKerasを使ってみると、難しそうなイメージがあるDeep Learningは積み木のように構築出来て、結構簡単じゃないかという自信を得られるので、一般的なユーザーには生のTensorflowを使うよりもお勧めです。

この本の内容は、Kerasのインストール、基本的なフィードフォワードNNから始まり、CNN、RNN、WordEmbedding、GAN、転移学習、強化学習、などと幅広くカバーされています。すべての内容にサンプルコードが付いているので、実際に動かすことができます。段々と後ろの章に進むにつれて、学習に時間がかかるコードが増えてくるので、手元の環境で動かそうとする場合にはGPUがあったほうが良いかと思います。サンプルコードはKeras2.0で書かれていますが、現在の最新版だと微妙にAPIが変わっている部分もあるようで修正が必要な個所もありましたが、調べればすぐに出てくるレベルの違いなので(少なくとも現時点では)特に大きな問題にはならないかと思います。

基本的にはサンプルはそんなに複雑な内容を扱っているわけではなく、コードを部分ごとに解説とともに説明されているので、Kerasの深い知識が無くてもサクサク読むことが出来ました。実際に自分が作りたいアプリケーションに向けて参考にするにはとても良いサンプルが揃っていると思います。

個人的にはGANについて概要レベルで仕組みを知ることが出来たのが大きな収穫でした。何となく理論は知っているんだけど、いざ実際にコードに落とすにはどうしたら良いのか?と思う方にはお勧めです。日本だとChainerがこの手のフレームワークとして人気だと聞きますが、世界的にみるとKerasの人気は強いとも聞きます。今後の開発がどのように進んでいくのかにもよりますが、こういったフレームワークは一つ手を付けておけば他のフレームワークを使う際も似た部分は多いかと思うので、応用が効くかと思います。

次はもう少し数学的な基礎を固めようかと思うので、統計の教科書やMurphy本あたりを読んでみようかと思っています。

Fashion MNISTをKerasでCNNを使って分類してみた

ファッションアイテムを識別するタスクであるFashion MNISTというデータセットが登場しました。
https://github.com/zalandoresearch/fashion-mnist


(画像は上記githubページより)

このデータセットが登場した目的は、MNISTが簡単すぎる、MNISTは使われすぎ、MNISTは最近のコンピュータビジョンのタスクを表現していない、などの理由からだそうです。

まずは、データセットをダウンロードします。

git clone https://github.com/zalandoresearch/fashion-mnist.git

データの形式などはMNISTと同じで、分類するクラスも10個(Tシャツ、サンダル、バッグなど)です。

このデータをKerasを使って分類してみようと思います。バックエンドはTensorflowを使っています。
ネットワークの構造はLeNetを構築しています。

from keras import backend as K
from utils.mnist_reader import load_mnist
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense
from keras.models import Sequential
from keras.utils import np_utils
from keras.initializers import Constant
from keras.optimizers import Adam
import matplotlib.pyplot as plt

#load_mnistはutilsにある
X_train, y_train = load_mnist('data/fashion', kind='train')
X_test, y_test = load_mnist('data/fashion', kind='t10k')

X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)

y_test = np_utils.to_categorical(y_test, 10)
y_train = np_utils.to_categorical(y_train, 10)

K.set_image_dim_ordering("th")
#LeNetを構築する
model = Sequential()
model.add(Conv2D(20, kernel_size=5, padding="same", input_shape=(1,28,28)))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Conv2D(50, kernel_size=5, border_mode="same"))
model.add(Activation("relu"))
model.add(MaxPooling2D())

model.add(Flatten())
model.add(Dense(500))
model.add(Activation("relu"))

model.add(Dense(10))
model.add(Activation("softmax"))

model.compile(loss="categorical_crossentropy", optimizer=Adam(), metrics=["accuracy"])
history = model.fit(X_train, y_train, batch_size=128, epochs=20, verbose=1, validation_split=0.2)

score = model.evaluate(X_test, y_test, verbose=1)
print("Test score:", score[0])
print("Test accuracy:", score[1])
print(history.history.keys())

#グラフの表示
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

テスト精度は92.16%でした。MNISTだと99%程度の精度が出るネットワーク構造なので、MNISTより難しくなっているというのは本当のようです。

48000/48000 [==============================] - 259s - loss: 0.4973 - acc: 0.8208 - val_loss: 0.3587 - val_acc: 0.8742
Epoch 2/20
48000/48000 [==============================] - 260s - loss: 0.3178 - acc: 0.8861 - val_loss: 0.3043 - val_acc: 0.8893
Epoch 3/20
48000/48000 [==============================] - 258s - loss: 0.2735 - acc: 0.9013 - val_loss: 0.2783 - val_acc: 0.9025
Epoch 4/20
48000/48000 [==============================] - 256s - loss: 0.2397 - acc: 0.9124 - val_loss: 0.2502 - val_acc: 0.9110
Epoch 5/20
48000/48000 [==============================] - 256s - loss: 0.2122 - acc: 0.9229 - val_loss: 0.2716 - val_acc: 0.9058
Epoch 6/20
48000/48000 [==============================] - 266s - loss: 0.1897 - acc: 0.9308 - val_loss: 0.2683 - val_acc: 0.9053
Epoch 7/20
48000/48000 [==============================] - 259s - loss: 0.1679 - acc: 0.9381 - val_loss: 0.2570 - val_acc: 0.9118
Epoch 8/20
48000/48000 [==============================] - 260s - loss: 0.1489 - acc: 0.9460 - val_loss: 0.2557 - val_acc: 0.9114
Epoch 9/20
48000/48000 [==============================] - 260s - loss: 0.1277 - acc: 0.9524 - val_loss: 0.2430 - val_acc: 0.9195
Epoch 10/20
48000/48000 [==============================] - 260s - loss: 0.1156 - acc: 0.9568 - val_loss: 0.2435 - val_acc: 0.9198
Epoch 11/20
48000/48000 [==============================] - 260s - loss: 0.0965 - acc: 0.9639 - val_loss: 0.2452 - val_acc: 0.9183
Epoch 12/20
48000/48000 [==============================] - 259s - loss: 0.0824 - acc: 0.9696 - val_loss: 0.2705 - val_acc: 0.9159
Epoch 13/20
48000/48000 [==============================] - 261s - loss: 0.0689 - acc: 0.9752 - val_loss: 0.2851 - val_acc: 0.9148
Epoch 14/20
48000/48000 [==============================] - 258s - loss: 0.0588 - acc: 0.9790 - val_loss: 0.3054 - val_acc: 0.9178
Epoch 15/20
48000/48000 [==============================] - 275s - loss: 0.0506 - acc: 0.9823 - val_loss: 0.3397 - val_acc: 0.9215
Epoch 16/20
48000/48000 [==============================] - 313s - loss: 0.0423 - acc: 0.9858 - val_loss: 0.3490 - val_acc: 0.9161
Epoch 17/20
48000/48000 [==============================] - 296s - loss: 0.0376 - acc: 0.9866 - val_loss: 0.3412 - val_acc: 0.9215
Epoch 18/20
48000/48000 [==============================] - 289s - loss: 0.0299 - acc: 0.9894 - val_loss: 0.3668 - val_acc: 0.9173
Epoch 19/20
48000/48000 [==============================] - 291s - loss: 0.0313 - acc: 0.9887 - val_loss: 0.3972 - val_acc: 0.9141
Epoch 20/20
48000/48000 [==============================] - 268s - loss: 0.0251 - acc: 0.9911 - val_loss: 0.3806 - val_acc: 0.9194
 9984/10000 [============================>.] - ETA: 0sTest score: 0.362425664179
Test accuracy: 0.9216

出力されたグラフは次の通り。

以上です。MNISTの精度が99%を超えて飽和しつつあるので、今後はこのテストデータが広く使われることになるかもしれません。