ABEJA Tech Blog

中の人の興味のある情報を発信していきます

Stable Diffusion の仕組みを理解する

この記事は、ABEJAアドベントカレンダー2022 の 19 日目の記事です。

こんにちは!株式会社 ABEJA で ABEJA Platform 開発を行っている坂井です。

世間では Diffusion Model 使った AI による画像生成が流行っているみたいですね。 自分は元々 Computer Vision 系の機械学習エンジニアだったんですが、この1年くらいは AI モデル開発ではなくもっぱらバックエンド開発メインでやっていて完全に乗り遅れた感あるので、この機会に有名な Diffusion Model の1つである Stable Diffusion v1 について調べてみました!*1

では早速本題に入りたいと思います!

Stable Diffusion v1 とは?

Stable Diffusion v1 は、論文「High-Resolution Image Synthesis with Latent Diffusion Models」で提案されたモデル(LDM)をベースにしたテキストから画像を生成する(text-to-image と言ったりします)の Diffusion モデルになってます。

もっと正確に言うと、LDM は text-to-image 以外にも超解像度化や inpainting などのマルチモーダルな画像生成を可能にした Diffusion モデルですが、この内 text-to-image 用のネットワークを「LAION-5B データセット」と呼ばれる50 億枚!!!の{テキスト・画像}ペアデータセットを用いて学習し、条件付け生成のための Encoder として 事前学習済み CLIP ViT-L/14 を使用したモデルが Stable Diffusion v1 になっているようです。

なのでこの記事では、この LDM についての解説がメインになっています。

さてこの LDM ですが、Diffusion ベースのモデルなので、そもそも基本的な Diffusion モデル自体を何も理解してない自分が理解しようと思っても中々厳しいものがありました。前提として、論文「Denoising Diffusion Probabilistic Model」で提案されている DDPM あたりの理解が最低限必要そうだったので、まず DDPM の紹介から入っていきたいと思います。

Denoising Diffusion Probabilistic Model(DDPM)

Denoising Diffusion Probabilistic Model(DDPM)は Diffusion model を最初に提案した論文ではなく、オリジナルの Diffusion model を改良したモデルになっていて、後発の多くの Diffusion 系モデルのベースラインとして採用されているようです。

DDPM は学習時と推論時で挙動が異なるので、それぞれの場合で分けて説明していきたいと思います。

学習時の動作

上図は、DDPM の学習時の処理の流れを示した図になっています。

DDPM は学習時、forward process (diffusion process) と呼ばれる学習用データセットの画像からノイズ画像を生成するステップと reverse process と呼ばれるノイズ画像から元の画像を生成するステップの2つのステップから構成されます。

この内 forward process では学習を行わず、reverse process のみニューラルネットワークで学習を行うことに注意してください。

forward process (diffusion process)

forward process (diffusion process) は、学習用データセットの画像  x_0 からタイムステップ t =0 \sim T 度に徐々にノイズを加えていってノイズ付き画像  x_1, x_2, ..., を生成し、完全なノイズ画像 x_T を生成するプロセスです。

この処理過程は確率過程(特にマルコフ連鎖)とみなすことができるので、以下の式で定義できます。

マルコフ連鎖なので、次のステップでの確率分布は前回の確率分布のみに依存し、各ステップでの確率分布は正規分布に従います。

このノイズ画像を生成するプロセスは、ニューラルネットワークのような学習ありの処理で行うのではなく、学習パラメーターを持たないガウガジアン処理のような通常の画像処理で行っています。そのため forward process は DDPM における学習対象ではありません。*2

この forward process で生成した各ステップ度の一連のノイズ画像は、後述の reverse process におけるニューラルネットワークの学習用データとして使用されます。

reverse process

reverse process は、forward process で生成したノイズ画像から元の画像を生成するプロセスです。 この処理過程は確率過程(特にマルコフ連鎖)とみなすことができるので、以下の式で定義できます。

マルコフ連鎖なので、次のステップでの確率分布は前回の確率分布のみに依存し、各ステップでの確率分布は正規分布に従います。

後述の損失関数の項目でも説明しますが、reverse process ではこの各ステップでの正規分布の平均値と分散値をニューラルネットワークで学習&推定することになります。

DDM では、このニューラルネットとして Unet ベースのモデルである PixelCNN++ をベースとして、拡散時間 t をネットワークの各残差ブロックに Transformer の self attention 構造で入力できるようにしたネットワークを採用しています。

Transformer の attention 構造を使用することにより、以下の図のように CNN 構造と比べて画像全体の大域的な特徴量を失うことなく後段のネットワークに伝搬できるようになるメリットがあります。

損失関数

DDM における学習の目的は、reverse process におけるノイズ画像 x_T から画像  x_0 を生成する確率分布(=尤度関数)の最大化です。但し、確率分布のままでは扱いづらいため、その対数尤度の最大化を考えます。

この対数尤度は、変分下限 [ELBO : Evidence Lower BOund] なるものと KL ダイバージェンス(2つの確率分布の距離を表す指標の1つ)の和で表現できます。(詳細計算は略します)

完璧に学習成功した場合には、forward process (diffusion process) での確率分布と reverse process での確率分布が完全に一致し、両者の確率分布間の KL ダイバージェンスが0になるので、対数尤度を最大化したければ、(VAE のときと同じように)変分下限を最大化すれば良いことになります。 更に、この変分下限は 負の KL ダイバージェンスと復元誤差の和に変形出来るので、まとめると DDM における学習の目的である対数尤度の最大化をしたければ、変分下限を最大化すればよいが、そのためには KLダイバージェンスを最小化し、復元誤差を最大化すれば良いことになります。

そして、復元誤差の最大化は負の復元誤差の最小化と同値であるので、損失関数を以下の式のように定義できます。*3

上式をそのまま計算するのは依然として難しいので、DDM では、上式で定義した損失関数ではなく、上式を変形して更に近似した以下のような損失関数を使用して学習を行っています。*4

この損失関数は、正解データであるノイズ画像 \epsilon とニューラルネットワークからの出力画像 \epsilon_{\theta}(...) との間の最小2乗誤差を{タイムステップ・学習用データセット}全体に関しての平均値をとったもので定義されているので、計算コストが低くなっています

正解データを正規分布(平均値:0、分散値:1)からランダムサンプリングしたノイズ画像の特徴マップとしているので、reverse process のニューラルネットワークからの出力はノイズ復元した画像そのものではなく、ノイズ画像になります。

但し単なるノイズ画像ではなくて、画像を復元するにあたっての除去すべきノイズ部分を示したノイズ除去画像になります。言い換えると、reverse process のニューラルネットワークでは、この除去ノイズ画像が従う正規分布の平均値と分散値を学習していることになります。

復元画像ではなく除去ノイズ画像を生成するようにしているのは、復元画像を直接生成するよりも除去ノイズ画像を生成し、後段の処理でノイズ除去したほうが復元画像の品質が向上したためです。ニューラルネットワークの出力は 0 ~ 1 の範囲に正規化しているほうがうまく学習できますが、正規分布に従う除去ノイズ画像を出力するようにしたほうがうまく学習できるのはこれと同じ理屈だと思います。

そして、この損失関数の値が収束するまで、ニューラルネットワークの重みパラメーターを更新しながらで誤差逆伝播で学習していく動作になります。

具体的には、以下のようなアルゴリズムになります。

推論時の動作

推論時は、上式のように forward process での処理は行われず、reverse process の学習済みのニューラルネットワークを用いて、正規分布(平均値:0、分散値:1)からランダムサンプリングしたノイズ画像から復元画像を生成する動作になります。 具体的には、以下のようなアルゴリズムになります。

High-Resolution Image Synthesis with Latent Diffusion Models(LDM)

DDMP の解説ができた所で、いよいよ本命の LDM について解説していきます!

上図は LDM の全体のアーキテクチャの全体像を示した図です。 LDM のアーキテクチャは、以下の4つの主要コンポーネントで構成されています。

  1. 画像 Encoder
    入力画像を特徴マップにエンコードするネットワークです。
    ネットワーク構造としては、GAN ベースの Encoder-Decoder 型の事前学習モデルのエンコーダ部分を使用しています。
    このネットワークで入力画像をエンコードすることで入力画像全体の大域的な特徴量と局所的な細部の特徴量の両方を失うことなくエンコードできるようになっているようです。

  2. Diffusion Process
    DDPM の forward process と同じように、bicubic 補間などの学習パラメーターを持たない手法を用いて、元画像(=学習用データの画像)からノイズ画像を生成するプロセスです。
    ただし DDPM とは異なり、入力画像を Encoder でエンコードした特徴マップを入力とし、ノイズ画像の特徴マップを出力するようにしています。
    画像を直接 Diffusion Process で処理するのではなく、一旦特徴マップにエンコードして処理させることで、計算効率が大幅に向上するメリットがあります。

  3. Denoising Process
    DDMP の reverse process と同じように、diffusion process で生成したノイズ画像から元画像をニューラルネットワーク(LDM では UNet を使用)で生成するプロセスです。
    ただし DDPM とは異なり、入力画像を Encoder でエンコードした特徴マップを入力とし、ノイズ画像の特徴マップを出力するようにしています。
    ノイズ画像を直接 reverse process で処理するのではなくエンコードしたノイズ画像の特徴マップで処理させることで、計算効率が大幅に向上するメリットがあります。また、生成画像の品質向上にも大きく貢献しているようです。
    LDM では更に、このニューラルネットワーク(UNet)に対して、セマンティクスセグメンテーション画像 or テキストデータ or などを Encoder でエンコードした特徴マップをネットワークに入力し、条件つけ生成を行っています。 これにより、text-to-image や超解像度化、inpainting などのマルチモーダルな画像生成を可能にしています。

  4. 画像 Decoder
    Denoising Process で復元した画像の特徴マップを元の画像にデコードするネットワークです。
    ネットワーク構造としては、画像 Encoder と同じく GAN ベースの Encoder-Decoder 型の事前学習モデルのデコーダー部分を使用しています。 このネットワークでデコードすることで入力画像全体の大域的な特徴量と局所的な細部の特徴量の両方を失うことなくデコードできるようになっているようです。

以下それぞれ詳しく見ていきます。

画像 Encoder と Decoder

LDM では、まずはじめに学習用データセットの入力画像を画像 Encoder で特徴マップにエンコードし、最後に画像 Encoder で特徴マップを画像にデコードします。 アーキテクチャ図でいうと上図のピンク枠で囲った箇所が該当します。

画像をエンコードする理由としては、特徴マップは元の画像より画像の幅と高さが小さいので、後段の処理の計算効率が向上することと、生成画像の品質が向上するためです。

この際、生成画像の品質が向上させるために重要なのが元の画像の大域的な特徴量(=画像全体)と局所的な特徴量(=画像のピクセルレベルでの詳細)の両方を失うことなくうまく特徴マップにエンコードすることなのですが、LDM における画像 Encoder と Decoder は、論文「Taming transformers for high-resolution image synthesis」で提案された GAN ベースの Encoder-Decoder モデルをベースにしており、以下の損失関数で事前学習されたモデルになっています。

最も単純な損失関数としては L1 loss や L2 loss がありますが、この L1 loss や L2 loss で学習すると生成画像のぼやけの原因となるので、L1 loss や L2 loss ではなく VGG perceptual loss と GAN の Adversarial loss で学習しています。

LDM では更に、正則化項  L_{reg} も追加し、KL-reg(KLダイバージェンスによる正則化項)と VQ-reg(Decoder 内でのベクトル量子化による正則化項) という2つの異なる種類の正則化項いずれかを加えています。

損失関数に正則化項を追加することにより、特徴マップ z が正規分布(平均値:0、分散値:1)に近い分布で分布するようになり、画像の局所的な詳細の特徴量を失うことなく特徴マップにエンコードすることができるようになります。*5

【補足】KL ダイバージェンス

KL ダイバージェンスは、2つの確率分布間の距離を表す指標の1つで、以下の式で定義される指標です。

KL ダイバージェンスが0のときは2つの確率分布が完全に一致しており、KL ダイバージェンスの値が大きくなるほど2つの確率分布は離れていることを意味します。

今回の KL-reg(KLダイバージェンスによる正則化)は、潜在変数(=特徴マップ)z が従う確率分布と正規分布(平均値:0、分散値:1)間の KL ダイバージェンスをとっているので、この KL-reg により、Encoder でエンコードされた潜在変数(=特徴マップ)z が正規分布(平均値:0、分散値:1)に近くなるように Encoder の学習が促進される効果があります。

Diffusion Process

Diffusion Process は、DDPM の forward process と同じように、bicubic 補間などの学習パラメーターを持たない手法を用いて、元画像(=学習用データの画像)からノイズ画像を生成するプロセスです。ただし DDPM とは異なり、入力画像を上記画像 Encoder でエンコードした特徴マップを入力とし、ノイズ画像の特徴マップを出力するようにしています。アーキテクチャ図でいうと上図のピンク枠で囲った箇所が該当します。

特徴マップは、元の画像より画像の幅と高さが小さいので、元画像からノイズ画像を生成する DDPM と比べて計算処理が削減されます。また特徴マップを使用することで(後段の Denoising Process にて)画像の大域的な特徴量と局所的な特徴量の両方を学習することが可能になります。

Denoising Process

Denoising Process は、DDMP の reverse process と同じように、diffusion process で生成したノイズ画像の特徴マップから復元画像の特徴マップをニューラルネットワーク(LDM では UNet を使用)で生成するプロセスです。アーキテクチャ図でいうと上図のピンク枠で囲った箇所が該当します。
ただこの図だと少々分かりづらいので、もう少し詳細に書くと以下のような図になります。*6

Denoising Process は先に述べたように、diffusion process で生成したノイズ画像から元画像をニューラルネットワークで生成するプロセスです。

但し Diffusion モデルの場合は、タイムステップ 1~T に度に徐々に画像復元する動作になるので、上図のようにタイムステップ T 個分の連続した UNet を連結する形になります。

UNet を採用しているのは、UNet の skip connection 構造により入力画像の特徴量を失うことなく後段のネットワークに伝搬することができるためです。

更に、LDM では text-to-image・超解像度化・inpainting といったマルチモーダルな画像生成を可能にするために、テキストデータ or セマンティクスセグメンテーション画像 などを Encoder でエンコードした特徴マップを、以下のいづれかで入力(どちらの方法で入力するかは、条件つけ生成のための入力データの種類によって変わる)し、条件付け生成を行えるようにしています。*7

  1. concat で入力層側にそのまま結合
  2. T 個の UNet の各 Encoder 中間層と Decoder 中間層それぞれに cross attention と呼ばれるネットワーク層で入力

ここで、2つ目の cross attention と呼ばれるネットワーク層で入力する方法についてもう少し詳しく見ていきます。

cross attention

LDM でいう cross attention は、Transformer で提案されている attention 構造のことです。

具体的には、以下の図と式のように、UNet の flatten 化された特徴マップと条件づけ入力のための Encoder でエンコードした特徴マップをネットワークの重み行列 W で線形変換したものをクエリ行列 Q・キー行列 K・バリュー行列 V とし、それらの内積を softmax で 確率化(0.0 ~ 1.0)したものをバリュー行列 V に乗算し、出力を得ます(ここまでが Transformer の attention 構造)。そして、この出力を Unet の各 Encoder 中間層と Decoder 中間層それぞれに入力しています。




Transformer の attention 構造を使用することにより、以下の図のように、CNN 構造と比べて画像やテキスト全体の大域的な特徴量を失うことなく後段のネットワークに伝搬できるようになるメリットがあります。

損失関数

Diffusion Process は単に bicubic 補間などの学習パラメーターを持たない手法でノイズ画像を生成していただけでしたが、Denoising Process は DDPM と同じくニューラルネットワークで学習を行います。なので損失関数を定義し、その損失関数の値が収束するまで学習を行い、ニューラルネットワークの重みパラメーターを更新する必要があります。

LDM における学習対象のニューラルネットワークは、T 個の UNet と条件つけ付け生成のための Encoder です。画像 ENcoder と 画像 Ecoder は事前学習済みモデルを使用するので学習対象ではありません。

LDM における損失関数は、DDPM と同じように、正解データである正規分布(平均値:0、分散値:1)に従うノイズ画像の特徴マップと UNet からの出力特徴マップとの間の最小2乗誤差を{学習用データセット・条件つけ生成のための学習用データセット・タイムステップ}に関しての平均値をとった値で定義されます。式で書くと下式のようになります。

先のアーキテクチャ図で損失関数をとっている部分を図示すると以下の図のようになります。

DDPM と同じく、正解データを正規分布(平均値:0、分散値:1)からランダムサンプリングしたノイズ画像の特徴マップとしているので、UNet は復元画像そのものではなくノイズ画像の特徴マップを出力するように学習されます。但し単なるノイズ画像ではなくて、画像を復元するにあたっての除去すべきノイズ部分を示したノイズ除去画像の特徴マップになります。言い換えると、Unet ではこの除去ノイズ画像が従う正規分布の平均値と分散値を学習していることになります。

最終的な復元画像は上図のように、この除去ノイズ画像を使って「復元画像」=「正規分布からランダムサンプリングしたノイズ画像」ー「除去ノイズ画像」の式で得られるようにしています。*8

復元画像ではなく除去ノイズ画像を生成するようにしているのは、DDPM と同じく復元画像を直接生成するよりも除去ノイズ画像を生成し、後段の処理でノイズ除去したほうが復元画像の品質が向上したためです。ニューラルネットワークの出力は 0 ~ 1 の範囲に正規化しているほうがうまく学習できますが、正規分布に従う除去ノイズ画像を出力するようにしたほうがうまく学習できるのはこれと同じ理屈だと思います。

そして、この損失関数の値が収束するまで、UNet 及び条件付け生成のための Encoder の重みパラメーターを更新しながらで誤差逆伝播で学習していく動作になります

マルチモーダルな画像生成

LDM の基本的な構造としては、以上になります。 ここからは、LDM のマルチモーダルな画像生成に焦点をあてて解説していきます。 特に今回の主題である Stable Diffusion v1 は text-to-image 用の LDM のことなので、まずはこちらから解説します。

text-to-image

text-to-image 用の LDM では、条件つけ生成のための Encoder として自然言語モデルで特に競争力のある品質を実現できる Transformer を採用しています。また tokenizer としては、大規模自然言語モデルである BERT の BERT-tokenizer を使用しているようです。

学習用データセットとしては、テキストと画像の4億枚のペアデータセットである「LAION-400M データセット」を使用しています。

このモデルを使用することで、上図のように、あるテキストからそれに文章内容を反映した画像を生成することができるようになっています。

最初に述べたように、Stable Diffusion v1 は text-to-image 用の LDM モデルとの違いとしては、学習用データセットとして「LAION-5B データセット」と呼ばれる50 億枚の{テキスト・画像}ペアデータセットを使用し、Transformer で実装される条件付け生成のための Encoder として、事前学習済み CLIP ViT-L/14 モデルを使用したものが Stable Diffusion v1 になっています。

超解像度化



超解像度用の LDM では、条件つけ生成のための入力データとして低解像度画像を入力しています。 また、エンコードされた低解像度画像の特徴マップを Unet の入力層に concat で直接入力するようにしているようです。

学習用データセットとしては、(SR3 と公平な品質比較のために SR3と同じ)「ImageNet データセット」を使用し、画像 Encoder としては「OpenImages データセット」で事前学習したモデル(正則化項は VQ-reg、ダウンサンプリング係数 f =4)を使用しています。

また Diffusion Process は、ダウンサンプリング係数 4 の bicubic 補間で行っています。

上図の比較検証より、競争力の高い Diffusion model ベースの超解像度化モデルである SR3*9 と比較して、定性的にも定量的にも LDM はより優れた品質を達成できていることが見て取れます。

Inpainting


Inpainting 用 LDM では、条件つけ生成のための入力データとして、画像の一部領域をマスクされた画像データを入力しています。

学習用データセットとしては、10万枚の様々なシーンの画像データセットである「Places データセット」を使用しています。但し実際に画像を入力する際には、LaMa*10 と同様の方法で画像の一部をマスクしています。

競争力のある既存の Inpainting モデルである LaMa と比較して、定性的にも定量的にも LDM はより優れた品質を達成できていることが見て取れます。

その他



LDM ではその他にも、条件つけ生成のための入力データにセマンティクスセグメンテーション画像や物体検出用のバウンディングボックスを入力することで、上図のようにセマンティクス画像からの画像生成やバウンディングボックスからの画像生成なども行えるようになっています。

感想

最後に感想をいくつか。

画像を Encode して後段の処理を行う方法・UNet の skip connection 構造により入力画像の特徴量を失うことなく後段のネットワークに伝搬する方法・条件つけ生成のための UNet の各 Encoder 中間層と Decoder 中間層に入力を追加する方法といったものに関しては、今までの GAN の生成モデルでも広く行われていることだし、正直新規性は感じなかったです。

ただこれらの工夫を Diffusion Model に適用したこととが LDM の新規性でしょうか。あと50 億ペアの大規模学習用データセットで LDM を学習し、誰でも使えるように公開したのが Stable Diffusion v1 の新規性・貢献かと思います。(賛否は起こっているみたいですが、、、)

あと Stable Diffusion v1 (LDM) は Diffusion Model といいつつも、中身よく見てみると GAN も Transformer も使っているし、GAN + Transformer + Diffusion Model の組み合わせという今までの DNN の叡智を集結させたモデル感があってそこに感動しました!

採用情報

株式会社ABEJAでは共に働く仲間を募集しています!

機械学習やバックエンド開発に興味あるエンジニアの方々!こちらの採用ページから是非ご応募くださいませ!

careers.abejainc.com

*1:最近 Stable Diffusion v2 も出たみたいですが、この記事では Stable Diffusion v1 について解説します。

*2:上式の正規分布のパラメーター β を学習パラメーターとして、forward process を学習対象にすることもできるのですが、βを学習パラメーターにしても品質向上にはならなかったので DDMP では β を定数にして、学習を行わないようにしているようです。

*3:ここらへんの議論は VAE [Variational Autoencoder] のときと同じになるので、詳細気になる方は VAE の元論文などをご確認ください。

*4:変形と近似の過程の詳細はかなり長いので略します。気になる方は「Denoising Diffusion Probabilistic Model(DDPM)」の論文を確認してくださいませ。

*5:この KL ダイバージェンスによる正則化項は、SPADE [Semantic Image Synthesis with Spatially-Adaptive Normalization] 等で行われているものと同じものだと理解しています。LDM の論文では書かれてなかったですが、潜在空間(今の場合は特徴マップ)の多様体の滑らかさが向上し、その結果として生成画像の多様性が向上し、GANの学習も安定化する効果もあったはずです

*6:厳密にはニューラルネットワークが出力するのは元画像を復元する際に使用する除去ノイズ画像ですが、ここでは簡単のため元画像をニューラルネットワークの出力に書いてます。

*7:Stable Diffusion v1 は、LDM の条件付け生成のための入力データとしてテキストデータを指定した場合のモデルになっています。

*8:適切な「除去ノイズ画像」を準備できなかったので、この図では「除去ノイズ画像」を適当なノイズ画像で図示しています。「除去ノイズ画像」は実際にはもうちょい異なるノイズ画像になります。

*9:SR3 : 「Image super-resolution via iterative refinement

*10:LaMa:「Resolution-robust Large Mask Inpainting with Fourier Convolutions