ABEJA Tech Blog

中の人の興味のある情報を発信していきます

小型LLM「ABEJA Qwen2.5-7B Model」学習のための蒸留のパイプライン並列化

はじめに

こんにちは、Labsチームの藤本です。

弊社は、経済産業省とNEDOが実施する、国内の生成AIの開発力強化を目的としたプロジェクト「GENIAC(Generative AI Accelerator Challenge)」の1期に続き、2期にも採択され、そこで大規模言語モデルの開発を進めています。今回は、そのプロジェクトの中で実施した大規模言語モデルの蒸留(Knowledge Distillation)に関する技術的な取り組みをご紹介します。

本蒸留の成果については、以前の記事(https://tech-blog.abeja.asia/entry/geniac2-qwen25-7b-v0.1)で既に紹介しております。本記事では、特にNeMoフレームワークにおける蒸留の仕組みと、NeMoで大規模なモデルを効率的に蒸留する際の技術的課題およびABEJAではどのように実装したかについて紹介していきます。

蒸留とは

はじめに蒸留について簡単にご紹介します。蒸留とは、大規模な教師モデル(teacher)の知識を、小規模な生徒モデル(student)に転移する技術です。一般的には、学習済みの高性能モデルの知識を小さいモデルに引き継ぐことで高い精度を保ちつつ軽量化が可能というメリットがあります。

蒸留を行うことで、通常通りのSFT(Supervised Fine-Tuning)を行うよりも、高い精度を達成できることが多いという報告が多くされており、我々も蒸留を実施することにしました。蒸留は教師モデルの知識を生徒モデルに転移する仕組みであるため、学習前の生徒モデルが教師モデルと近いほど学習が容易になります。そのため、教師モデルの枝刈りをして小型化をしたものに対して蒸留を行う場合が多いです。一方で、枝刈りを利用すると蒸留後の小型モデルに対して差分マージなどを適用することができなくなってしまう可能性があると考えたので、今回は別途学習済みの小型モデルをベースとして蒸留を行い、その上でChatVectorを適用する方針としました。

蒸留では、教師モデルと生徒モデルの出力の差分をlossにします。一般的な手法としては、出力のsoft targetをloss関数として使う方法や、中間層の出力を合わせる方法(intermediate layer matching)などがあります。今回は教師モデルと生徒モデルのベースモデルが異なるものであるため、中間層を合わせる方法は採用せず、最終レイヤーのsoft targetをlossとして利用することとしました。

NeMoとMegatron-Coreの紹介

NeMo(NVIDIA NeMo)は、NVIDIAが提供する大規模言語モデルの学習・推論・蒸留を支援するライブラリであり、内部でMegatron-LMに含まれるMegatron-Coreの技術を活用しています。

Megatron-Coreは大規模モデルの分散学習のためのライブラリで、以下のような並列学習をサポートしています:

  • Data Parallel
  • Tensor Parallel
  • Pipeline Parallel
  • Sequencial Parallel
  • Expert Parallel

これらの仕組みをうまく組み合わせることで、効率的に大規模モデルを学習することが可能です。例えばTensor Parallelでは、テンソルを特定の次元に沿って分割し、各GPUは分割されたテンソルのみ処理することで計算量・メモリ使用量を抑えることができます。Pipeline Parallelでは、レイヤーを複数のステージに分けて、それぞれを並列に処理します。例えばGPUが4枚あり、24層のネットワークを、6レイヤーずつ4つのステージに分割したとします。1個目のデータが入力されたら最初のデバイスで6レイヤー分のforward計算を行い、次のデバイスにデータを送信します。次のタイミングでは、先ほど処理した1個目のデータを2個目のデバイスで続きの6レイヤー分のforward計算を行い、それと同時に2個目のデータを新たに最初のデバイスに投入しforward計算を行います。このようにデバイスごとにステージを割り当てて並列処理を行います。

https://developer.nvidia.com/blog/scaling-language-model-training-to-a-trillion-parameters-using-megatron/

NeMoは、このMegatron-Coreをベースに、大規模モデルの学習・推論・蒸留などを支援するライブラリです。自然言語だけではなく、画像・動画、音声など様々なモデルをサポートしています。また、モデルのフルスクラッチの学習だけではなく継続事前学習やParameter-Efficient Fine-Tuning (PEFT) などのFinetuning手法もサポートしています。さらに、開発や実験を効率的に行うため、PyTorch LightningやHydra、WandBを統合しています。NeMoはそれを取り巻く様々なフレームワーク・ライブラリがあり、NeMo-Alignerのような様々なFinetuningの手法が実装されたライブラリやNeMo Runという実験の設定・実行・管理を簡素化・構造化するツールなどとの連携が可能です。

NeMoにおける蒸留の実装

GENIACで開発を進めていた2025年2月時点では、NeMoによる蒸留の実装はexampleディレクトリにありました。ただし、当時はpipeline parallelに対応しておらず、大規模なモデルへの対応ができませんでした。そこで、ABEJAでは独自で蒸留を実装することにしました。ちなみに、もう少し細かい事情ですが、ABEJAでの蒸留実装がスタートしたタイミングで、NeMo本家でもpipeline parallelの実装がスタートしておりプルリクがありました。しかし、そのコードを確認した所、次節で述べるメモリ問題があったため、我々は独自での実装を進めることとしました。ここでは、まずは当時のexampleにあった実装について紹介します。

NeMoでは蒸留を行うために、NVIDIAが提供するnvidia-modelopt(以下modelopt)というライブラリを内部で使います。このライブラリは蒸留の他にモデルの量子化、枝刈り、投機的デコーディングなど様々な手法を利用することができます。NeMoで蒸留を行うために、GPTModelを継承した蒸留用のクラスを作成し、これを用います。通常、modeloptを用いて蒸留する場合、生徒モデルを先に用意しておき、以下のように教師モデルと生徒モデルを内部に持った蒸留用のモデルに変換できます。

model = mtd.convert(student_model, *mode*=[("kd_loss", kd_config)])

なお、kd_configにはロスの情報や、どのレイヤーの結果を蒸留の計算に利用するかなどを指示します。NeMoモデルの場合は、以下のようにoutput_layer同士をLogitsKLLossで比較するように設定しています。なお、_teacher_providerは教師モデルを返す関数です。

logit_pair = ("output_layer", "output_layer")
loss = LogitsKLLoss(*tensor_parallel*=tp_enabled)
kd_config = {
    "teacher_model": (_teacher_provider, [self.cfg, copy.deepcopy(self.trainer)], {}),
    "criterion": {logit_pair: loss},
    "loss_balancer": None,
}

NeMoのexampleでは、モデルのforwardはget_forward_output_and_loss_funcを継承して実装しています。その中で、get_batchでデータを取得、output_tensor = model(**forward_args)でモデルのforwardを行い、内部で定義されたloss_funcで先程のロスを計算します。ここまではmodeloptの蒸留の機能をそのままインテグするだけなので、比較的素直な実装になっています。

pipeline parallelに対応する際の課題

Megatron-Coreのpipeline parallelによる並列化の難しさ

NeMoによる蒸留は、2月時点ではpipeline parallelには対応していませんでした。pipeline parallelが難しい理由として、NeMoがバックエンドに利用しているMegatron-Coreのforwardとbackwardのpipeline parallelをする際のプロセスが原因の一つと考えます。Megatron-Coreのpipeline parallelの処理はmegatron/core/pipeline_parallel/schedules.pyにあり、先程の図に示したようなglobal batch全体をpipeline parallelにするという流れになっています。例えばforwardでは、以下のようにして、前のステージからのデータを受け取り、該当ステージでのforwardを行い、その後、現在のステージの出力を次のステージに送信します。

input_tensor = recv_forward(recv_tensor_shapes, config, parallel_state.is_pipeline_first_stage())
output_tensor, num_tokens = forward_step(forward_step_func, data_iterator, model, ...)
send_forward(output_tensor, send_tensor_shapes, config, parallel_state.is_pipeline_last_stage())

前のステージから受け取った中間データをrecv_forwardでinput_tensorに格納し、上記forward_stepの中で以下のようにモデルのset_input_tensorを通し、モデルへ入力します。

set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
set_input_tensor(input_tensor)

ところが、pipeline parallelをする際にステージ間で転送されるデータの形式が、先程のソースコードにあったようにrecv_tensor_shapesという各レイヤーが出力するテンソルの出力のshape(B x S x H)に固定されています。送受信データの形が上記コードのようにinput_tensorという一種類のテンソルになってしまっているため、ここには教師モデルと生徒モデルの両方の出力を格納できません。そのため、exampleのコードをそのままpipeline parallelにしようとしても、生徒モデルのデータしか送受信してくれず、蒸留を並列化できませんでした。

現在(2025年4月末時点)のNeMo2での実装と課題

その後、2025年2月よりNeMo自体はpipeline parallelに対応しました。これまでexampleはそのまま維持され、NeMo LLM Collectionの中で別途蒸留が実装されています。NeMo LLM Collectionは、NeMo 2.0で再設計された学習・運用を簡単かつ柔軟に行える統合ライブラリです。蒸留の実装では、上記pipeline parallelの送受信の仕様を引き継ぎ、教師モデルと生徒モデルを別々に並列化して実行する実装になっています。具体的にはnemo/lightning/megatron_parallel.py内で、以下のように教師モデルと生徒モデルを交互に動かすようになっています。すると、teacher_stepの中で教師モデルを一通り処理してから、続いてstudent_stepの中で生徒モデルの処理を行うことになります。teacher_step、student_stepでは、先ほど図示したpipeline parallelの一連の処理が行われるもので、下図のようにそれぞれの内部でglobal batch size回のforward/backwardが行われます。図のケースではglobal batch sizeを4としています。

  with self.unwrapped_model.only_teacher_forward():
      with self.unwrapped_model.swap_teacher_config(self.module):
          teacher_step()
  with self.unwrapped_model.only_student_forward():
      microbatch_outputs = student_step()

NeMoの実装

ここで、modeloptの実装の中身を見ていきましょう。教師モデルをforwardすると、以下のようなhookの仕組みを利用してoutput_capture_fwd_hook関数の中で所定の中間レイヤーの結果をキャッシュします(先程の図におけるO1〜O4)。キャッシュした中間レイヤーの結果をlossを計算する際にpopし、対応する生徒モデルの結果と比較します。ところが、global batch分の中間結果をteacher_layer._intermediate_outputに全てキャッシュする必要があるため、global batch sizeが大きいとOOMになりやすくなります。

for student_layer, teacher_layer in self._layers_to_loss:
    (省略)
    teacher_layer.register_forward_hook(output_capture_fwd_hook)

def output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):  # pylint: disable=redefined-builtin  # noqa
    (省略)
    # Teacher
    if len(module._intermediate_output) > 0:
        warnings.warn(
            f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored."
            " This is undesired behavior unless Pipeline Parallelism is in use."
        )
    module._intermediate_output.append(output)
    (省略)

def compute_kd_loss(
    self,
    student_loss: Optional[torch.Tensor] = None,
    loss_reduction_fn: Callable = None,
    skip_balancer: bool = False,
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    (省略)
    for (student_layer, teacher_layer), loss_fn in self._layers_to_loss.items():
        out_s = student_layer._intermediate_output
        out_t = teacher_layer._intermediate_output.pop(0)  # can store multiple in special cases
        student_layer._intermediate_output = None

        loss = loss_fn(out_s, out_t)  # Student is pred, Teacher is target
    (省略)

理想的には教師モデルの結果をキャッシュせずに、教師モデルと生徒モデルをmicrobach単位で交互に動かしたいのですが、Megatron-Coreではステージ間でのデータの送受信の形がレイヤーの出力のshapeに固定されており、教師と生徒の両方の出力を同時に送受信することが難しいです。そのため、NeMoの実装では教師と生徒をそれぞれglobal batchの単位で動かしていました。結果的に(global batch x seq_len x hidden_dim)のメモリが必要になっていました。

ABEJAでの実装方法

ABEJAではこの解決のため、megatron/core/pipeline_parallel/schedules.pyにあるMegatron-Coreの並列の実装を作り直し、蒸留のプロセスに合わせて修正しました。具体的には以下の図の処理になるように並列処理のロジックを書き直しました。micro bach毎に(1)教師モデルのforward、(2)生徒モデルのforward、(3)ロスの計算、(4)生徒モデルのbackwardという順番で処理を行うようにし、先程の中間結果のキャッシュを不要にしてメモリを節約することとしました。これにより、(micro batch x seq_len x hidden_dim)までメモリ消費に抑えることに成功しました。

ABEJAでの実装

本実装では、microbatch毎に、教師モデルのforwardを行い、その後すぐに生徒モデルのforwardを行います。教師モデルはbackwardは不要なため、最後のステージでは生徒モデルのみbackwardを行い、前のステージに結果を送ります。backwardのフェーズでは生徒モデルのみ、後ろのステージからデータを受け取り、そのデータを用いてbackwardを行い、その結果を更に前のステージに送ります。例えば、forwardの処理の一部は以下のようになります。送信と受信の順番がズレるとハングするため丁寧に作る必要がありますが、並列プログラミングはデバッグしづらいので中々大変ですね。なお、実際の処理の中ではvalidationの処理はforwardのみであったり、またステージ毎に処理を分けていたりと、非常に長くなってしまうので、ここでは簡易的なコードのみ示します。以下のように、ステージの中でmicrobatch毎に教師モデルの処理と生徒モデルの処理をそれぞれ交互に行うこととしています。

# (1) 前のステージのデータを受け取り
teacher_input_tensor = recv_forward(teacher_recv_tensor_shapes, teacher_config)
input_tensor = recv_forward(recv_tensor_shapes, config)

for i in range(num_microbatches):
    # (2) それぞれのモデルのforwardを行う
    teacher_output_tensor, _ = forward_step(teacher_forward_step_func, teacher_data_iterator, teacher_model, ...)
    output_tensor, num_tokens = forward_step(forward_step_func, data_iterator, model,. ..)

    # (3) 次のステージに結果を送りつつ、studentは後段処理の結果を待つ
    send_forward(teacher_output_tensor, teacher_send_tensor_shapes, teacher_config)
    output_tensor_grad = send_forward_recv_backward(
        output_tensor, send_tensor_shapes, config
    )

    # (4) studentのbackwardを行う
    input_tensor_grad = backward_step(
        input_tensor, output_tensor, output_tensor_grad, model_type, config
    )

    # (1-2) 前のステージのデータを受け取る
    teacher_input_tensor = recv_forward(teacher_recv_tensor_shapes, teacher_config)

    # (2-2) 前のステージにbackward結果を送りつつ、前のステージから次のデータを受け取る
    input_tensor = send_backward_recv_forward(
        input_tensor_grad, recv_tensor_shapes, config
    )

実験

実験では、V100のGPU2基を搭載した計算環境で、Qwen2.5-1.5B-Instructを教師モデル、Qwen2.5-0.5B-Instructを生徒モデルとして、global batch sizeを変えながら、どこでOOMが出るかをチェックしました。なお、メモリに関連するパラメータとしては、sequence lengthは1024、micro batch sizeは1で固定しました。ちなみに、詳細なメモリの使用量で見たかったのですが、内部でpush/popを繰り返しているからか、nvidia-smiはメモリの使用量は変わらなかったので、ここではOOMが出るかどうかのチェックに留めました。実験の結果、元のコードではglobal batch sizeが64以上でOOMが出てしまったのに対し、ABEJAの実装ではglobal batch sizeが1024でも動作することを確認しました。global batch sizeを増やすことで、小さい場合よりも安定した学習が期待できます。本実装を用いて、先のブログの実験結果を出すことができました。

モデル 8 16 32 48 64 128 256 512 1024
original
ABEJA

まとめ

本実験では、NeMoのpipeline parallelにおける蒸留の実装において、ABEJAで取った実装方法によって、メモリの使用量を抑えることができることを確認しました。NeMoの中身を把握することで、取りうる手段は増えるので、どんどんソースコードを読んでいきましょう。

なお、本成果は、経済産業省とNEDOが実施するGENIACでのモデル開発によって得られたものです。

追記

本記事を書き終えてNeMoのソースコードを改めて見ていたら、なんと2025年5月10日にNeMoの蒸留の実装の修正が行われたようで、動作チェックはしていませんが、ソースコードを眺めた限りでは上記問題を改善できているように見えます。Megatron-Coreのpipeline parallelの実装に、adjust_tensor_shapes_fnという仕組みを導入し、これまでは各ステージのネットワークの出力サイズに固定されていた送受信のshapeを外部から変えられるようにしたようです。これにより、nemo/collections/llm/modelopt/distill/utils.pyで教師モデルと生徒モデルの両方を送受信するように定義することで、教師と生徒を同時に処理できるようになっているようですね。俺でなきゃ見逃しちゃうね。

We Are Hiring!

ABEJAは、テクノロジーの社会実装に取り組んでいます。 技術はもちろん、技術をどのようにして社会やビジネスに組み込んでいくかを考えるのが好きな方は、下記採用ページからエントリーください! (新卒の方やインターンシップのエントリーもお待ちしております!)

careers.abejainc.com