収録日:20210307
発表者:@yohei_kikuta
聞き手:@smochi_pub
内容:
- DALL·E の理解に向けてシリーズ第三弾
- DALL·E の論文がついに公開(第一弾と第二弾を収録したいいタイミングで出てきたのでびっくり)
- 第一弾と第二弾でやった Vision Transformer と CLIP について振り返りつつ、DALL·E の論文を理解するために読んでみた
- コードも公開されているが、ごく一部のみで discreate VAE のモデル定義と学習済みモデルのみ公開
- DALL·E の論文紹介
- モチベーション
- これまで複雑な構成要素(アーキテクチャ等)が必要だった text-to-image の領域において、transformer ベースのシンプルなモデルで挑む
- text-to-image では object が捻れたり変な位置に object が存在したりすることが度々起こるが、それを解決して自然な画像を生成したい
- MS-COCO のような小さめのデータ(O(10^5))ではなく、大量に集めた image-text ペア(O(10^8))で学習をして transformer の力を引き出したい
- モデルの説明
- text 部分の transformer は既知のものとして割愛。image 部分の transformer としては第一弾でやった Vision Transformer を使う
- (第一弾の復習)pixel 単位で画像を token として扱うのはメモリも食うし高周波成分が重視され過ぎてしまう傾向があるので、16x16 とかのパッチ単位で一つの token とするようにしてうまくいっていた。Vision Transformer では pre-training が supervised であったが、それ以外は基本的には text の transformer 型モデルと同様だった
- text 部分と image 部分が同様に transformer 型モデルで扱えるので、モデルとしては入力に text と image の token をそれぞれ並べたものとしている。text が 256 token で image が 32x32 = 1024 token としている(image 部分は 256x256 の画像を 32x32 のパッチにしている)。text-to-image のモデルなので、text は given として image token を auto-regressinve に出力していく transformer decoder モデルとなる。学習は Language Model として unsupervised にやる(入力 token を再現できるように学習)
- Vision Transformer では token をそのまま reproduce するというのは supervised pre-tranining と比べて明確に劣後していた。DALL·E では image token の取りうる値を量子化してそれを回避。8192 個のコードブック(これは text における語彙数に相当する)を有する discrete VAE encoder で image を token を encode し、transformer から出力される token はこの discrete VAE の decoder を用いて画像に戻す。取りうる値を制限することで Language Model の学習がうまく work してそうで、生成される画像もわけわからんものが描写されてて破綻してるみたいなものを防げていそう
- 学習の概要は Variational Autoencoder とかでよく出てくる ELBO を最大化するように進めるという説明がなされている。もともと image x に対して ln p(x) 潜在変数 z を使ってモデル化して、p(z|x) の likelihood と p(z|x) と p(z) の KL divergence で下から押さえた上で z についてうまく学習(よい潜在変数の確率分布を得る)というのが ELBO だった。今回は image だけでなく caption y もあるので x -> x,y になるというのが直感的な理解。最初に discrete VAE で x -> z を作れるようにして、その後 transformer で y,z の分布を学習する
- 学習したモデルを使った画像生成は、text を入力し、transformer decoder で画像 token を順に出力させていけばよい。transformer decoder の出力は 8192 個あるコードブックの logit になっているので、その logit で閾値以上のものに対して logit の値に応じた multinomial 分布からサンプルすることでいろいろな画像を生成することができる。必ずしも text だけを入力にする必要はなく、image の上の部分も与えて続きを生成させたりできる(DALL·E のブログでやってた)
- 生成した画像は入力した text と共に CLIP で reranking することでより良い画像を選ぶことができる。
- (第二弾の復習)CLIP は text と image をそれぞれ別のモデルを使って encode し、その類似度を計算して text 特徴量と image 特徴量を結び付けて学習するモデルであった。zero-shot 予測ができるところが強みだったが、reranking ではこの text と image の特徴量が結びついてるというところが大きい。生成画像で多様なものができたときに、text によくマッチした image を類似度計算によって選出することができるため。CLIP で reranking すると比較した時に明確に良い生成画像が選出できるようになる
- text token と image token の間が空くところは特定の token を入れるとか、attention mask の部分は 3 つの sparce attention mask を使っているとか、他にも色々あるけど、discrete VAE を理解すればあとはこれまで学んだ知識で理解できる
- text 部分の transformer は既知のものとして割愛。image 部分の transformer としては第一弾でやった Vision Transformer を使う
- 実験
- データは後で評価するために MS-COCO を除いて image-text ペアのデータを 250 million ほどウェブから作成。キャプション短すぎるものを省いたり英語以外を弾いたりアスペクト比が極端すぎるものを除いたりしてる
- 学習に関してはかなりいろいろなテクニックが詰め込まれている。discrete VAE の学習のための gumbel softmax relaxation とか loss の重み付けとか mixed-precision training (per-resblock gradient)とか powerSGD を使った分散学習とか。この論文はこの辺の学習テクニックの説明が多い。うまく学習させるのが重要という話なのだとは思うけど、コードが公開されてないのであれこれ書くのもいいけどコード公開してくれ...という感じ
- 生成画像の評価の一つ目は人間によるチェック。MS-COCO のデータを使い、先行研究と比べてどちらが realistic かと caption をよりよくマッチしてるのはどちらかを選んでもらっている。先行研究は DF-GAN でこれは MS-COCO のデータを使って学習してる。5 人に vote してもらって多数決を取ると、realism では 90.0%, マッチ度合いでは 93.3% で先行研究を圧倒している(先行研究では make sense な画像を出すのもかなり難しいというレベルだったので、この圧勝は凄い)
- 他には GAN の評価で使う Inception Score や Frechet Inception Distance でも評価している。zero-shot でも MS-COCO で同程度の性能で、CUB だとやたら悪いけど fine-tuning すればよさそうというコメントだけしている。個人的にはこの辺のスコアが良いことにあまり意味を感じないので詳しくは追わない
- モチベーション
- 改めて DALL·E のブログの生成画像例を眺めてみる。いやぁ本当に凄い。間違いなくここ数年で一番の衝撃だった。特にアニメ的な画像の生成とかは真に驚くべきレベルだよな。先行研究との比較が難しいというところなのか、この辺の凄さが論文ではほとんど語られてないのは寂しいけども(そしてコード公開が discrete VAE だけなのは度し難い...)。
- DALL·E の理解シリーズを三回やってみた。シリーズ物だと知識が活かされるので比較的しっかりした勉強にもなるしなかなか良いのではと思う
参考情報:
- Zero-Shot Text-to-Image Generation:https://arxiv.org/abs/2102.12092
- 公式実装(discrete VAE のみ):https://github.com/openai/dall-e
- PyTorch の非公式実装:https://github.com/lucidrains/DALLE-pytorch
- PowerSGD: Practical Low-Rank Gradient Compression for Distributed Optimization:https://arxiv.org/abs/1905.13727
- DF-GAN: Deep Fusion Generative Adversarial Networks for Text-to-Image Synthesis:https://arxiv.org/abs/2008.05865