はじめに
こんにちは。Labsチームの藤本です。ABEJA Qwen先生にタイトルを作ってもらったら割と強いタイトルになってしまいました。
大規模言語モデル(LLM)の学習を始め、Deep Learningにおいては最適なハイパーパラメータの設定は頭を悩ませる問題の一つです。大規模なモデルの学習には一度学習をするのに大量のGPUを使う上に、数日から数週間という時間がかかります。従来のアプローチでは、グリッドサーチやランダムサーチを使って複数の設定を試すのが一般的でした。しかし、32Bなどの大規模なモデルは学習に必要なリソースが大きいため、グリッドサーチやランダムサーチでは試すことが現実的ではありません。
こうした課題に対して、μP(Maximal Update Parametrization, muP)という非常に興味深い理論があります。これは、モデルのパラメータの初期化や学習率を幅に応じて適切にスケーリングすることで、小さなモデルで見つけた最適なハイパーパラメータを、大きなモデルにもそのまま適用できるようにするというものです [1]。この理論により、大規模モデルを直接チューニングせずに済む可能性が広がり、結果として学習にかかる時間やリソースを大幅に削減できると期待されています。μPの理論を活用し、小さなモデルで得たハイパーパラメータを大きなモデルにそのまま適用する実践手法を μTransfer と呼びます。これらの原理についての解説はこちらのnoteに詳しく紹介されています。実際のモデルに適用した事例が見つからなかったので、本ブログではμP/μTransferの簡単な解説に加え、Qwenモデルを対象に実験した結果についてご紹介します。なお、本実験のソースコードはこちらにあります。
μP(Maximal Update Parametrization)とは
μPが解決する根本的な問題
ニューラルネットワークで学習をする際には、通常はネットワークの幅を変更すると最適なハイパーパラメータも変わってしまいます。例えば、隠れ層が512次元の小さなモデルで学習率0.001が最適だったとします。しかし、同じアーキテクチャで隠れ層を2048次元に拡張すると、学習率0.001では学習が不安定になったり、逆に収束が遅くなったりします。結果として、モデルサイズごとにハイパーパラメータを調整し直す必要があり、特に大規模なモデルでは膨大な計算コストが必要になってしまいます。
μP(Maximal Update Parametrization, muP)は、この問題を解決できる手法の一つです。μP (muP)の理論を利用したパラメータ転移手法であるμTransferによって、小さなモデルで見つけた最適なハイパーパラメータを大きなモデルにそのまま転移できます。
なぜハイパーパラメータ転移が困難なのか:SPの根本問題
従来の標準的なパラメータ化(Standard Parametrization、SP)でハイパーパラメータ転移が困難な理由は、ネットワークの幅を変更すると学習の性質そのものが変わってしまうことにあります。
まず、μPが目指す「理想的な学習」を定義しましょう。深層学習の力を最大限に発揮するには、ネットワークの各層の出力の分散がΘ(1)となることが重要です(要するに、ネットワークの幅に依らない大きすぎたり、小さすぎたりしない値)。この状態においては、表現力がサイズに比例して向上するような理想状態となります。μPは適切なスケーリングにより、小さなモデルと大きなモデルで勾配の大きさ、活性化の分散、重みの更新量を調整し、同一の学習ダイナミクスを実現することで、ハイパーパラメータの転移を可能にします。具体的には、幅をn倍にした場合は、それに合わせて重みの初期化と学習率を適切にスケールすることで、パラメータの更新量の比率が一定になります。
これに対して、SPにおいて学習率や初期値を変えずにネットワークの幅のみ変えた場合、勾配や活性化のスケールが制御されないため、更新量や学習の進み方が変わってしまい、ハイパーパラメータの転移ができません。そのため、小さなモデルで見つけた最適なハイパーパラメータ(学習率、バッチサイズ、重み減衰など)は、大きなモデルでは機能しなくなります。
μPの仕組み
μPの核心は、モデルの幅が変化しても学習の振る舞いが実質的に変わらないように、パラメータの初期値と更新量を巧みにスケーリングすることにあります。
順伝播における「初期状態」の安定化
まず基本として、深層学習では、He初期化やLeCun初期化のように重みの初期分散を 1/fan_in
(fan_in
は入力次元数)でスケールすることで、各層の出力の分散が幅に依存しないようにします。
// He初期化などによる、SPでも行われる基本的な処理 Var[z^(l+1)] = Var[z^(l)] × Var[W] × fan_in = Var[z^(l)] × (c/fan_in) × fan_in = c × Var[z^(l)]
これにより、順伝播において各層の出力の値が爆発・消失するのを防ぎます。μPもこの原則に従いますが、ここでμPの真価は学習中の更新(勾配による変化)がこの安定性を壊さないように制御する点にあります。
学習を支配する「パラメータ更新」のスケーリング
SPでは、学習が進むとパラメータの更新量がネットワークの幅に対して不均衡になり、学習の進み方そのものが変わってしまいます。これは、ネットワークの幅を変えても初期化や学習率のスケーリングが調整されないため、出力の大きさや勾配の大きさ、そしてそれらによって生じるパラメータの更新量が、幅に対して適切なバランスを保てなくなるからです。例えば、ネットワークの幅を大きくすると、同じ学習率であっても1ステップあたりの出力変化量が過剰に大きくなり、逆に幅を小さくすると出力がほとんど変化しなくなるといった問題が発生します。これにより、モデルサイズによって学習の安定性や収束の仕方が大きく変わってしまい、小さなモデルで調整したハイパーパラメータが大きなモデルでは機能しなくなります。
μP(Maximal Update Parametrization)は、このモデル幅に由来する不均衡を理論的に解決する手法です。層の種類(入力層、中間層、出力層)やパラメータの形(行列型かベクトル型か)に応じて、初期化の分散や学習率を非対称に設計することで、パラメータの更新が常に適切なスケールに保たれます。これにより、ネットワークの幅を変えても学習のダイナミクスが保たれ、小さなモデルで見つけたハイパーパラメータをそのまま大きなモデルに移植できるようになります。
では、この「役割に応じたスケーリング」を、μPは具体的にどのように実現するのでしょうか。その答えは「Tensor Programs」という高度な理論から導かれますが、最終的に得られるルールは後述するように驚くほどシンプルです。簡易的に背景を説明すると、学習中にネットワーク各層の出力が幅に依らず安定して保たれるように、パラメータの更新量のスケールを調整することが重要になります。中間層を例にとって考えると、出力は $\boldsymbol{h} = \boldsymbol{W} \boldsymbol{x}$ という形で計算されるとき、初期化を適切に行えば$\boldsymbol{h}$の分散はΘ(1)に保たれます。これに対し、学習中に重みが$\boldsymbol{W} ← \boldsymbol{W} + \boldsymbol{\Delta} \boldsymbol{W}$として更新されると、出力も変化します。ここで、その変化が過度に大きくなったり小さくなったりすると、学習が不安定になります。この問題を防ぐためには、更新後の出力の変化 $\boldsymbol{\Delta} \boldsymbol{h} = \boldsymbol{\Delta} \boldsymbol{W} \boldsymbol{x}$ の分散がやはりΘ(1)に保たれるように、更新量 $\boldsymbol{\Delta} \boldsymbol{W}$ のスケールを選ぶ必要があります。具体的には、幅がnのときに学習率を 1/nにスケーリングすることで、$\boldsymbol{\Delta} \boldsymbol{h}$ の分散もΘ(1)に維持され、出力のスケールがネットワーク幅に依存せず一定に保たれるため、学習が安定します。この状態になると、仮にネットワーク幅が変わっても更新の影響が同程度になるため、結果として同じ学習率を転移できます。これらの詳細についてはGreg Yangらの論文 [1]を参照ください。
このような考えの元、Adam optimizerを例として挙げた場合、基本となるルールは以下のようになります。
μPスケーリング則の基本概念(Adamの場合)
層の種類 | パラメータ形状 | 初期化分散 | 学習率のスケール |
---|---|---|---|
入力層 | (width, d_in) | ~ 1/d_in |
1 (変化なし) |
中間層 | (width, width) | ~ 1/width |
~ 1/width |
出力層 | (d_out, width) | ~ 1/width^2 |
~ 1/width |
(注: これは概念を単純化した一例です。実際には複数の同等な定式化があります)
この非対称なスケーリングにより、幅がどれだけ大きくなっても、各層のパラメータが学習に与える影響のバランスが保たれます。これにより、学習の進み方そのものがモデルサイズに依存しなくなるのです。
ハイパーパラメータの計算例
さて、具体的な数値を利用して、μPの計算をしてみましょう。まず、以下のようなネットワークがあったとします。
- 探索すべき学習パラメータ (
η_base
) - 入力層の次元 (
d_in
):768
- 中間層の次元 (
width
):512
この場合、μPの規則を守った場合の初期化の分散と学習率は通常は以下のようになります。
層の種類 | 初期化分散 σ² | 学習率 η |
---|---|---|
入力層 | 1/768 ≈ 0.0013 |
η_base |
中間層 | 1/512 ≈ 0.00195 |
η_base/512 |
出力層 | 1/512² ≈ 0.000038 |
η_base/512 |
ここでこのモデルを最もよく学習するためのη_base
はどうなるでしょうか?中間層の次元を512から落とし、次元数を128にした場合に、μPを用いて同じ学習ダイナミクスになるような設定を計算してみましょう。
層の種類 | 初期化分散 σ² |
学習率 η |
---|---|---|
入力層 | 1/768 ≈ 0.0013 |
η_base |
中間層 | 1/128 ≈ 0.00781 |
η_base/128 |
出力層 | 1/128² ≈ 0.000061 |
η_base/128 |
このセッティングで最適な学習率η_base
を探索すると、その値を512にした場合の学習率に転移することができます。小さいモデルで探索できるので、大幅にパラメータ探索の時間と短縮できます。
なお、論文[1]では転移可能・不可能なパラメータは以下のように記載されています。μPの規則に従えば、幅以外にも、モデル間で、深さ・バッチサイズ・シーケンス長などを転移できるようですね。
μPの実装
上記スケーリングを行うμPの実装自体はμpのリポジトリにあります。また、mutransformersにてhuggingfaceのtransformersのモデルにμpを適用するサンプルがあります。ただし、サンプルにはBERTとGPT-2、RoBERTaしかありません。本ブログではこれを参考に、Qwen2.5で利用できるようにコードを修正しました。各実験をおこなうコードはこちらに置いておきます。
μPではモデルの初期値と学習率をスケーリングする必要があります。モデルについては、model.pyの中で初期値をスケーリングするようにしています。中間層については、module.weight.data.normal_
をmup.normal_
に置き換えます。mup.normal_
の中で、ネットワークの幅に応じてスケーリングを行っています。
def _init_weights(self, module, readout_zero_init=False, query_zero_init=False): std = self.config.initializer_range if isinstance(module, nn.Linear): if hasattr(module.weight, 'infshape'): normal_(module.weight, mean=0.0, std=self.config.initializer_range) else: module.weight.data.normal_(mean=0.0, std=std)
また、最終層はnn.Linear
をmup.MuReadout
に置き換えてスケーリングを行います。単純に初期の重みを設定するのでも良いですが、MuReadoutを利用することで、入力層との重みのシェアに拡張できるというメリットがあるようです。
self.lm_head = MuReadout(config.hidden_size, config.vocab_size, bias=False)
また、学習率のスケーリングは、optimizerを定義する際に、層ごとにスケールを調整するmup.option.MuAdam
などを利用します。実装では、mup_coord.py
やmup_lr.py
の中で以下のようにすることでSPとの切り替えを同じコードで行えるようにしました。
if mup: from mup.optim import MuAdam as Adam from mup.optim import MuAdamW as AdamW from mup.optim import MuSGD as SGD else: from torch.optim import SGD, Adam, AdamW
Qwen2.5で実験
実験ではQwen2.5モデルを対象として、μPを適用した場合と適用しない場合の学習率のスケーリングの比較を目的とします。小さいモデルでの学習率が、大きいモデルに転移できることを確認し、最適なパラメータを求めていました。
まずは、μPが正しく動作できているかどうかを調べるために、モデルの幅(width)を大きくしていったときに活性化の座標のサイズ(スケール)がどう変化するかを確認しましょう。mup_coord.py
の中で、各層の出力を記録していきます。以下のようにして、各層の出力を記録するhookを登録します。ここでは出力はl1ノルムで計算しました。
module.register_forward_hook( _record_coords(df, width, name, batch_idx, output_fdict=output_fdict, input_fdict=input_fdict, param_fdict=param_fdict))
check_coord.py
でチェックを行い、notebookで結果をプロットした結果を以下に示します。左がSP、右がμPの結果となります。図の読み方としては、x軸がネットワークの幅(隠れ層の次元数)、y軸がその層の出力値になります。各プロットはどの層からの出力であるかを示します。層の数が多いので層名は省略します。SPではネットワークの幅が大きくなるほど、出力が大きくなっていることが確認できます。一方で、μPではネットワークの幅が大きくなっても出力のスケールが変わらないことが確認できます。よって、μPを利用すればネットワークの幅によらず出力のスケールがΘ(1)に保たれていることが確認できました。
出力のスケールが幅に寄らないというμPの基本的な動作を確認できたので、次は学習率のスケーリングを行ってみましょう。check_lr.py
を実行するとネットワーク幅80〜640で、様々な学習率で3000ステップ学習行い、最終ステップ後のlossをプロットしました。結果を以下に示します。x軸が学習率、y軸が学習後のlossの値、各プロットはネットワークの幅(隠れ層の次元数)となります。
まずSPの結果を見てみましょう。各width毎にlossが最小になる学習率が異なっていることがわかります。ここから、ネットワークの幅が異なると最適な学習率が異なることがわかります。一方で、μPの結果ではネットワーク幅によらず学習率に対するlossの変化が似た形になることが確認できます。μPを使うことによって、学習率を転移できることが確認できました。ここより、幅80で学習率の調整を行ってから、最も良い学習率で幅640のモデルに転移することで、幅640のモデルでも安定して学習できることが分かります。具体的には最適な学習率はlog2lr=-9
であることが分かりました。実際のQwen2.5-32Bでは、ネットワーク幅がもっと大きいですが、このロジックを利用すれば同様に学習率のスケーリングを行うことができます。
まとめ
本ブログでは、大規模言語モデルの学習において避けては通れない「ハイパーパラメータ最適化」に対して、μP(Maximal Update Parametrization)がどのように理論的・実践的に解決を提供するのかを紹介しました。特に、実際に自分たちで扱っているQwen2.5モデルに対して適用し、μPによって出力スケールの安定性が保たれ、それを活用したμTransferにより学習率のスケーリングの一貫性が保たれることを確認し、小さなモデルでのチューニング結果がそのまま大きなモデルにスムーズに転移できることを確認しました。これは、大規模モデルの学習コストや試行錯誤の時間を大幅に削減できる非常に強力な手法です。
実際のところは、μP自体は継続事前学習ではなくフルスクラッチのモデルを前提としていることや、実際に最適な学習率を計算してみたら、いつも使っている学習率とはあまり変わらなかったので、結果的にいつも通り学習をしたんですけどね。とはいえ、本理論はフルスクラッチで学習する際には常に利用できる強力な手法ですので、覚えておいて損はないと思います。みなさんもフルスクラッチ学習する際は利用してみてはいかがでしょう。
We Are Hiring!
ABEJAは、テクノロジーの社会実装に取り組んでいます。 技術はもちろん、技術をどのようにして社会やビジネスに組み込んでいくかを考えるのが好きな方は、下記採用ページからエントリーください! (新卒の方やインターンシップのエントリーもお待ちしております!)
[1] G. Yang, E. J. Hu, I. Babuschkin, S. Sidor, X. Liu, D. Farhi, N. Ryder, J. Pachocki, W. Chen, and J. Gao. Tensor programs V: Tuning large neural networks via zero-shot hyperparameter transfer, arXiv preprint arXiv:2203.03466, 2022.