収録日:20210425
発表者:@yohei_kikuta
内容:
- Twitter か何かで prompt tuning の話を目にして面白そうだと思って読んでみることにした
- そもそも GPT3/T5 のような一つの汎用モデルで text prompt を準備して多様なタスクを解くというのは面白いと思ってたので、text prompt をいい感じにするというのは人間の手が介在する部分なのでここがどう発展していくかは気になっていた
- 気になっていたと言いつつ全然追えてなかったが、Twitter でこの論文を目にしたので、ちゃんと読んでみるかと思って読んでみた
- 論文のモチベーションは、大きな汎用言語モデルには手を加えずに prompt 部分を学習可能にすることで、モデル全体を再学習する手法と同等の精度を達成できないかを目指すというもの
- prompt design (入力文章の前にタスクに関する記述や例文を入れる方法) は、downstream タスクをなかなか良い精度で解くことができるが、モデル全体を downstream タスクで再学習する場合と比べるとかなり落ちる (モデルやタスクにも依るが SuperGLUE で 18 point 程度劣る)
- モデル全体を再学習するということは、タスク毎にモデル全体のパラメタを持たなければないことを意味し、かなり大変なのでそれを回避しつつでも精度はもっと高めたいという欲求がある
- ベースとなるモデルは基本的に freeze させるが、一部を soft (continuous) にして学習可能にするというアプローチが出てきていて、この論文では入力 (しかも pre-training で使う部分とは別の要素) だけを学習可能にしてどこまでいけるかを試している
- 提案手法を prompt tuning と名付けていて、これは基本的にタスク毎に元の入力文章の語彙とは別の語彙を 10~100 個当てがって (語彙といっても単語的な意味はなく、単に既存 vocabulary 以外の ID を付与するということ)、その語彙の embedding のみを downstream タスクで学習し、タスク記述語彙 embedding + 元の入力テキスト embedding を freeze している汎用言語モデルの入力にするというもの
- タスク毎に k 個の prompt token を準備 (合計ではタスク数 x k) し、これをベースとなる汎用言語モデルの token embedding 次元 e に embed する新たな k x e の embedding matrix のみが downstream で learnable となる
- ナイーブにはこれを学習するだけということなのだが、この embedding matrix の初期化の仕方には工夫が必要。embed した結果は freeze している汎用言語モデルに投入することになるので、元の入力テキストの意味で token 相当の意味を持つ embedding が親和性が高いように思われる。そこで、既存の vocabulary の embedding からサンプリングしたものを初期値とする。他にも classification の問題であれば出力がクラスを表す単語になるので、それを guide するようにクラスを表す単語の embedding を初期値にするという方法もある
- それ以外に重要なことは pre-training の仕方。T5 の構成を使うが、T5 では学習効率化のため出力は corrupt させた部分のみを対象としている。pre-training 後に downstream で再学習すればこれは気にしなくていいけど、提案手法ではこの部分は完全に freeze させるので、自然で完全な文章を扱えるようにしておきたい。corrupt 部分のみを出力させる普通の T5 の学習の後に、これも T5 で出てきた prefix LM で追加学習することで解決 (元の入力文章を二つに分けて、途中までを入力して残りを出力する)。
- 各種パラメタは実験で調査するが、最終的には prompt token 数が 100 で、embedding 初期化は既存のものからサンプリングもしくはクラス単語で初期化し、prefix LM で 100K ステップ追加学習する、というのが prompt tuning の完成形となる
- 実験は T5 と同じで SuperGLUE を T5 の各モデルサイズで試し、10^10 パラメタという一番大きなモデルでは prompt tuning がモデル全体を再学習したものと同程度の結果を出すことを示した
- モデルサイズをめちゃ大きくするまではモデル全体を再学習する方法の方が優れているが、10^10 パラメタくらいになるとなんと同程度になる。驚き。モデルサイズがめちゃでかいとそこに必要な情報がほぼ含まれていて、タスク固有の情報は prompt 部分での処理だけで十分賄えるということなのだろうか
- 他にも prompt 部分だけでタスクに対応するので overfit しにくいだろうということで、zero shot の domain shift での実験もしている。二つの入力が同じようなものか否かを判断するタスクで、一つを Quora からのデータ、もう一つを news からのデータとして、一方で学習したモデルをもう一方のデータで検証するということをして、モデル全体再学習よりも明確に良い結果を得た
- prompt ensembling という手法も提案していて、これは prompt tuning を複数作ってそれの majority vote を取るというシンプルなもの。巨大モデルを複製する必要がなくお手軽にできるが、確かに効果はあってよさそう
- まとめ
- 入力テキストの前段に学習可能な prompt tuning という機構を入れるだけでモデル全体を再学習する場合と同程度の精度を発揮するというのはなかなかの驚き。これも巨大モデルの成せる業か
- overfit しにくいだろうという期待のもと domain shift に強いことを実験で示したり、比較的低コストで ensemble ができることも示したりで、実用方面でも効きそうな話があるのもよい
- アイデアを実装しみたら大きな成果を達成したという類の論文なので、こういうのは (こういうのじゃなくてもだけど) 実装を公開して欲しい... pre-training モデルに手を加えることなく比較的お手軽に試せるんだしさ...
参考情報:
- The Power of Scale for Parameter-Efficient Prompt Tuning:https://arxiv.org/abs/2104.08691
- 菊田が論文を読んだときのメモ:yoheikikuta/paper-reading#59
- T5 の論文読みメモ:yoheikikuta/paper-reading#55
- CLIP の論文読みメモ:yoheikikuta/paper-reading#57