対話生成におけるマルチカリキュラム学習の活用論文:Learning from Easy to Complex Adaptive Multi-curricula Learning for Neural Dialogue Generation

AAAI 2020の論文「Learning from Easy to Complex: Adaptive Multi-curricula Learning for Neural Dialogue Generation」より。

概要

カリキュラム学習を用いた対話生成の手法に関する論文。一般的なカリキュラム学習では一つの指標に応じてComplexityを決定して、易しい順に学習を進めていくが、対話生成においては単一の指標で複雑さを測れるものではない。

そこで、複数の複雑度指標を定義して、対話の複雑度を測定できるようにした。作成された複数(この論文では5つ)の複雑度指標を用いて multi-curricula学習によって、学習状況に応じて異なったcurriculaを自動的に選択することで学習を進めていく。

背景

例えば、既存のデータセットである、OpenSubtitlesには学習が難しい受け答えデータが含まれている。そのようなデータに対していきなりモデルが学習を行うことは難しい。

そこで、人間の子供のように簡単なデータから難しいデータへと学習していく方針を考える。これはいわゆるカリキュラム学習と呼ばれる分野であるが、対話システムのデータの場合、他のタスクと違って、単一の複雑度というものを決めづらい。そこでこの論文では5つの側面から複雑度を設定して、5つのcurriculaを作成している。

5つの指標

Specificity

対話システムは得てして、一般的な回答を返しがちである。できるだけ特定の会話内容に対しての返答を行ってほしいのでSpecificityという指標を考える。Normalized IDFを回答の中の単語に対して計算して平均を取る。それをSpecificityとして考える。

Repetitiveness

同じ単語ばかり使って回答を生成するよりも、いろいろな単語を使って回答を生成している方が文章の複雑度は高いと考えられる。そこで、過去に使った単語をどれだけ繰り返しているかというのをRepetitivenessという複雑度指標として考えることができる。

Query-relatedness

質問に対して関係する回答をしているかどうかというのは複雑度指標として使える。具体的な計算方法としては、質問と応答の文章の類似度を埋め込み表現のコサイン類似度を取ることで計算する。

Continuity

Query-relatednessと近い概念だが、応答に対する次の文章がどれくらい類似しているかを計算することで、会話が一貫してつながっているかを考えることができる。これも同じようにコサイン類似度を取ることで複雑度指標と考えることができる。

Model Confidence

モデルの出力の確信度も、回答についての難易度を示すものになると考えられるので、複雑度の指標として用いることができる。

これら5つの属性は相関を計算すると、ほとんど相関が無いことが分かるのでこれらの指標の取り方は良さそうだと考えられる。

5つの属性を活用するために、Adaptive Multi-curricula Learningを提唱。各カリキュラムからのデータ取得はバリデーションデータに対するモデルのパフォーマンスに応じて決定される。 決定方法は強化学習の考え方を使って決める。

実験

実験は3タスクについて、各5モデルを実行しているが、ほぼ全てのケースでこのカリキュラム学習方式を導入した方が性能向上している。人間による主観評価でも、4割以上は良いと評価され、4割程度は優劣つかなかったので、全体としては今回の手法が主観的にも良い結果を出しているといえる。

Ablation Studyとして、5つの属性のうち1つだけ使ってみた場合もそれぞれ検討されているが、全体的に5つすべての指標を使ってカリキュラム学習した方が性能が良い。そのほかにも強化学習ベースの代わりにランダムポリシーを使った場合と、難しいものから先に学習する場合が実験されているが、どちらも今回の提案手法には及ばない。

所感

カリキュラム学習の考え方は人間の学習に関する方法からアイデアを受けているが、確かに一つの指標からデータの難易度は決められないと思うので、今回の複数指標を用いたカリキュラム学習で性能向上するのは直感にあっている。

今回の複雑度指標は、データセットから計算できるものであるため、人手による難易度付けが不要なため、その他のタスクについても活用できる部分があるのではないかと感じた。


コメントする

メールアドレスが公開されることはありません。 が付いている欄は必須項目です