ABEJA Tech Blog

中の人の興味のある情報を発信していきます

NVIDIA Clara を使用して医療・創薬系の AI モデルを動かしてみた

こんにちは!ABEJA で ABEJA Platform 開発や AI 関連の研究開発業務を行っている坂井(@Yagami360)です。

こちらはABEJAアドベントカレンダー2025の25日目の記事です。

近年のディープラーニングの発展に伴い、マテリアル・創薬・医療分野においても AI の活用が進みつつあるかと思いますが、その市場規模はかなり大きく、どのようなモデルや技術が使用されているのか?とか将来性はどうなのか?とか個人的に興味がありました。

とはいえ自分は、マテリアル・創薬・医療系のドメイン知識に関して素人なので、できるだけさくっと動かせるモデルやフレームワークで試してみたかったのですが、LLM とかと違ってこの分野では簡単に動かせるフレームワークやプラットフォームがそんなに整備されていない印象でした。そんな中で NVIDIA が提供している「NVIDIA Clara」と呼ばれる医療系の AI プラットフォームが、比較的簡単に動かせそうだったので、軽く動かしてみました

NVIDIA Clara の公式ドキュメントは、以下のリンク先にあります。

www.nvidia.com

NVIDIA Clara の概要

NVIDIA Clara は、NVIDIA が提供している医療系の AI プラットフォームで、 現時点で以下の分野のための機能群を提供しているようです。

  • NVIDIA Clara for Biopharma(AI 創薬)

    AI 創薬のためのプラットフォームである BioNeMo フレームワークを提供。BioNeMo フレームワークには、AI創薬のための事前学習済みモデル群も含まれます

    www.nvidia.com

  • NVIDIA Clara for Medical Devices / NVIDIA Holoscan

    医療機器データ(内視鏡動画、超音波などの医療センサーデータなど)のストリーミングに対してAIモデル推論によるリアルタイム解析を行うためのハイブリッド環境用(エッジ AI デバイス・オンプレ環境・クラウド環境)プラットフォームで、NVIDIA Holoscan として提供されています。

    www.nvidia.com

    docs.nvidia.com

  • NVIDIA Clara for Medical Imaging(AI 医療用画像診断)/ NVIDIA MONAI Toolkit

    NVIDIA Clara for Medical Imaging は、医療用画像(レントゲン画像、CT画像など)x AI(AIモデルによる画像セグメンテーションや画像生成など)にターゲットを絞った NVIDIA Clara の機能で、NVIDIA MONAI Toolkit として提供されています。

    www.nvidia.com

    docs.nvidia.com

  • NVIDIA Clara for Genomics(AI ゲノム解析)/ NVIDIA Parabricks

    www.nvidia.com

    次世代シーケンシング(NGS)による膨大なゲノムデータの解析を、GPU を使って劇的に高速化するためのプラットフォーム

    ゲノム解析における二次解析と呼ばれる工程では、細かく断片化された数億個のデータ(リード)をパズルのように正しい位置に並べ直す作業(アライメント)が発生するそうですが、これは行列とベクトル演算といった並列処理になるので、GPU で劇的な高速化ができるようです。 (余談ですが、GPUは、元々3Dゲームにおける行列とベクトル演算の高速化で使用されていましたが、AIモデルでも同じくDNN 内で行列とベクトル演算してるだけなので GPU で高速化できるんじゃんてなったのと同じ流れですね)

今回の記事では、最も簡単に動かせそうだった「NVIDIA Clara for Medical Imaging」と、個人的にAI 創薬の雰囲気だけでも知っておきたかったので 「NVIDIA Clara for Biopharma」も動かしてみます

NVIDIA Clara for Medical Imaging(AI 医療用画像診断)

NVIDIA Clara for Medical Imaging の概要

NVIDIA Clara for Medical Imaging は、医療用画像(レントゲン画像、CT画像など)x AI にターゲットを絞った NVIDIA Clara の機能で、以下の主要コンポーネントで構成されるようです

  • NVIDIA MONAI [Medical Open Network for AI] Toolkit

    医療用画像(レントゲン画像、CT画像など)の AI プロダクト開発・モデル学習・デプロイのためのフレームワークです

    catalog.ngc.nvidia.com

    • サポートしているモデル例
      • VISTA-3D

        CT画像における内臓領域のセグメンテーション用モデル

      • MAISI [Medical AI for Synthetic Imaging]

        3D CT 画像の生成モデルで、拡散モデルをベースとしている。用途しては、学習用データセットのデータ拡張での利用(合成データ)など

  • NVIDIA FLARE との連携

    developer.nvidia.com

    NVIDIA FLARE は、連合学習 [Federated Learning] のためのフレームワークです。 連合学習というのは分散学習の1種なのですが、複数の組織に属する開発者が学習用データセットなどを共有せずに協力してAIモデルを分散学習する方法のことです。
    プライバシー保護などのセキュリティ上の理由から、異なる組織に学習用データセットなどのデータ共有ができないケースに有用になります。 NVIDIA FLARE 自体は、NVIDIA Clara の医療用とは別のフレームワークになるのですが、MONAI Toolkit と連携することで、医療機関におけるデータ保護などのセキュリティ上の問題をクリアしつつモデルの学習を行うことができます。

blueprint でのデモサイトが公開されているので、以下のサイトから試してみるとイメージしやすいかと思います

  • VISTA-3D(CT画像における内臓領域のセグメンテーション用モデル)のデモ

    build.nvidia.com

  • MAISI(CT画像の生成モデル)のデモ

    build.nvidia.com

MONAI Toolkit を起動する

それでは実際に、NVIDIA Clara for Medical Imaging での AI モデルを動かしていきます。まずは、MONAI Toolkit を起動することから始めていきます。

MONAI Toolkit は、NVIDIA が NGC 上に公開している MONAI Toolkit 用 docker image のコンテナで動かす形になります。

  1. GPU インスタンスを構築する

    今回は少し動かしてみるだけだったので GCP の VMインスタンス(T4:1台)で環境構築しました。

    本格的な学習を行う場合は、クラウドで動かすとインフラコストが膨大になるので、オンプレ環境で動かすことを推奨します。

    公式のシステム要件は、以下に記載されています。

    記載されているスペックは A100, H100 とかの高価な GPU ですが、今回動かしてみるモデル(VISTA-3D)は、生成 AI モデルのような大規模モデルではなくただの画像セグメンテーションモデルなので、T4 でも動かせました。

    docs.nvidia.com

  2. NVIDIA の NGC にログインして SetUp 画面から API キーを作成する

    ngc.nvidia.com

  3. NVIDIA NGC Container Register にログインする

     docker login nvcr.io
    
    • ユーザー名: $oauthtoken
    • パスワード: 上記作成した API トークン
  4. MONAI Toolkit の Docker image を pull する

    以下のコマンドで NVIDIA NGC Container Register から MONAI Toolkit の Docker image を pull してください

     docker pull nvcr.io/nvidia/clara/monai-toolkit:3.0
    
  5. MONAI Toolkit の Docker コンテナを起動する

    • JupyterLab で起動する場合

      今回は、JupyterLab 上のコードを使用するので、以下のコマンドで JupyterLab のコンテナを起動してださい

        docker run --gpus all -it --rm \
        --ipc=host --net=host \
        nvcr.io/nvidia/clara/monai-toolkit:3.0
      

      デフォルトで 8888 ポートが接続ポートとして使用されます。変更したい場合は -e JUPYTER_PORT=8900 のように args を指定してください

    • [Option] bash で起動する場合

      なお自作したスクリプト等を MONAI Toolkit 環境で動かしたい場合は、コンテナを bash 起動した上でスクリプトを実行してください

        docker run --gpus all -it --rm \
            --ipc=host --net=host \
            nvcr.io/nvidia/clara/monai-toolkit:3.0 /bin/bash
      
  6. ブラウザ上で JupyterLab を開く

    上記コマンド実行後に以下のようなログが出力されるので、ログ中の http://localhost:8888/lab?token=dummy をブラウザで開きます

     Welcome to MONAI Toolkit
    
     MONAI toolkit components:
    
     MONAI Core, Version: 1.4.0
     MONAI Label, Version: 0.8.4
     NVFlare, Version: 2.5.0
     Jupyter Port is set to: 8888
     checking port 8888 availability
     [I 2025-12-22 06:19:05.734 ServerApp] jupyter_lsp | extension was successfully linked.
     [I 2025-12-22 06:19:05.739 ServerApp] jupyter_server_terminals | extension was successfully linked.
     ...
     [I 2025-12-22 06:19:07.563 ServerApp] Jupyter Server 2.14.2 is running at:
     [I 2025-12-22 06:19:07.563 ServerApp] http://localhost:8888/lab?token=dummy
     [I 2025-12-22 06:19:07.563 ServerApp]     http://127.0.0.1:8888/lab?token=dummy
     [I 2025-12-22 06:19:07.563 ServerApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).
     [C 2025-12-22 06:19:07.568 ServerApp] 
    
         To access the server, open this file in a browser:
             file:///root/.local/share/jupyter/runtime/jpserver-29-open.html
         Or copy and paste one of these URLs:
             http://localhost:8888/lab?token=dummy
             http://127.0.0.1:8888/lab?token=dummy
    
  7. JupyterLab 上で welcome.md 等の内容に従って処理を行う

    ブラウザ接続に成功すると、上のような JupyterLab でのページが表示されるので、このチュートリアルに従って各種 AI モデルを動かしていきます

内臓セグメンテーション用モデル(VISTA-3D)のファインチューニングを行う

CT画像セグメンテーション用モデル(VISTA-3D)のチュートリアルコード http://localhost:8888/lab/tree/tutorials/monai/vista_3d/vista3d_spleen_finetune.ipynb の学習部分に従って、学習を行います。

JupyterLab で動かしているので、UI 上から Run ボタン ▶️ をポチポチするだけで良いです。

  • 使用モデル

    VISTA-3D:CT画像セグメンテーション用モデルで、120以上の主要臓器クラスをセグメンテーション可能なモデルです

  • 学習用データセット

    以下で公開されているデータセットで、腹部断面のCT画像と正解ラベル(黄色の領域:脾臓として正しくラベル付けされた領域)を含むデータセットになります。

    このデータセットで VISTA-3D をファインチューニングします。

    medicaldecathlon.com

    msd-for-monai.s3-us-west-2.amazonaws.com

    • 画像解像度:239 x 239
    • 枚数: 150枚
      • かなり少ないが、ファインチューニングなので問題なし(本当はもっと多いほうが良い)
  • 学習設定

    • エポック数:5 epoch
      • かなり少ない気がしますが、チュートリアルのまま 5 epoch とします
    • 使用 GPU メモリ:
      • 8 ~ 9GB 程度なので、T4でも動かせました
  • loss 値のグラフ

    今回はチュートリアルと同じ設定で学習しましたが、epoch 5 ではまだ十分に loss が収束してないので、もっと Epoch 数増やしたほうが良さそうです

ファインチューニングした内臓セグメンテーション用モデル(VISTA-3D)で推論し、CT画像における特定の内臓領域をアノテーションする

同じくCT画像セグメンテーション用モデル(VISTA-3D)のチュートリアルコード http://localhost:8888/lab/tree/tutorials/monai/vista_3d/vista3d_spleen_finetune.ipynb の推論部分に従って、上記ファインチューニングしたモデル(VISTA-3D)で推論を行い、CT画像における特定の内臓領域をアノテーションします。

推論結果例は、以下の通りです。

腹部断面のCT画像(テスト用データを使用)から脾臓部分をうまくアノテーション(セグメンテーション)できていることがわかるかと思います

NVIDIA Clara for Biopharma(AI 創薬)

NVIDIA Clara for Biopharma / BioNeMo の概要

NVIDIA Clara for Biopharma は、NVIDIA Clara における AI 創薬のための機能です。BioNeMo と呼ばれるフレームワークと AI 創薬のための事前学習済みモデル群で構成されます

以下で blueprint でのデモが公開されているので、こちらを試してみるとイメージが掴みやすいと思います。

build.nvidia.com

  • RFdiffusion

    • 拡散モデルベースのモデルで、ノイズ入力から徐々にタンパク質構造を生成
    • 入力

      • 標的となるタンパク質の構造データ
      • 生成するタンパク質の長さや、標的タンパク質のどの範囲を固定するかを指定するパラメータ
      • ノイズ入力
    • 出力
      • タンパク質の3D構造の原型を PDB [Protein Data Bank] 形式で出力
        • この段階ではアミノ酸の種類は決まっておらず、グリシン(GLY)だけで構成されるいった原型だけのモデルが出力
  • ProteinMPNN

    • グラフ畳み込み(GNN)ベースのモデルで、タンパク質の3D構造からアミノ酸配列に逆変換するモデル
    • 入力

      • RFdiffusion で予想したタンパク質の3D構造
    • 出力
      • 入力されたタンパク質の3D構造に折りたたまる可能性が高いアミノ酸配列
      • 信頼度やスコア
  • AlphaFold2 (OpenFold)

    • Transformer ベースのモデルで、タンパク質の3次元構造を非常に高い精度で予測するAIモデル
    • 入力

      • ProteinMPNN で予想したアミノ酸配列
      • 既存の構造データ(例:ACE2)を参照用として使用
    • 出力
      • 最終的なタンパク質の3D構造を PDB [Protein Data Bank] 形式で出力
      • 信頼度スコア
  • 最終出力

    • json データ例

        {
            "result": {
                "alphafold2":
                    {
                        "all_predicted_pdbs": [
                            "ATOM      1  N   SER A  19      96.155  70.201  45.493  1.00100.66           N  \nATOM      2  CA  SER A  19      94.696  70.434  45.702  1.00101.02           C  \nATOM      3  C   SER A  19      93.987  69.087  45.880  1.00100.43           C  \nATOM      4  O   SER A  19      94.494  68.054  45.439  1.00 99.56           O  \nATOM      5  CB  SER A  19      94.116  71.194  44.499  1.00102.01           C  \nATOM      6  OG  SER A  19      92.778  71.609  44.730  1.00102.75           O  \nATOM      7  N   THR A  20      92.825  69.102  46.535  1.00100.27           N  \nATOM      8  CA  THR A  20      92.051  67.879  46.776  1.00 99.38           C  \nATOM      9  C   THR A  20      91.073  67.584  45.641  1.00 98.31           C  \nATOM     10  O   THR A  20      90.444  68.499  45.100  1.00 98.40           O  \nATOM     11  CB  THR A  20      91.236  67.973  48.092  1.00 98.99           C  \nATOM     12  OG1 THR A  20      92.126  68.161  49.199  1.00100.97           O  \nATOM     13  CG2 THR A  20      90.435  66.697  48.320  1.00 98.68           C  \nATOM     14  N   ILE A  21      90.946  66.305  45.289  1.00 97.23           N  \nATOM     15  CA  ILE A  21      90.028  65.874  44.232  1.00 96.91           C  \nATOM     16  C   ILE A  21      88.577  66.107  44.659  1.00 95.29           C  \nATOM     17  O   ILE A  21      87.687  65.320  44.350  1.00 95.10           O  \nATOM     18  CB  ILE A  21      90.209  64.370  43.890  1.00 97.49           C  \nATOM     19  CG1 ILE A  21      89.957  63.500  45.130  1.00 98.25  
                            ...
                        },
                        "proteinmpnn": {
                        "mfasta": ">input, score=3.0430, global_score=2.2795, fixed_chains=['B'], designed_chains=['A'], CA_model_name=v_48_002, git_hash=unknown, seed=667\nGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG\n>T=0.1, sample=1, score=1.3999, global_score=1.9260, seq_recovery=0.0469\nGPEEEEKKREIEALVAEMAKQFEEVKPLIELLEELKKKIGEGEKEAEKELKKKIEEEEKRKEEA\n",
                        "scores": [
                            1.3998868465423584
                        ],
                        "probabilities": [
                            [
                            [
                                2.9388796951579366e-10,
                                6.684248006694004e-19,
                                3.878053780881352e-11,
                                1.8682724638174886e-12,
                                2.496218801052632e-15,
                                0.9949813485145569,
                                8.619456505049819e-16,
                                2.2613524466467754e-15,
                                2.8787888739501e-10,
                                1.5028966569505253e-12,
                                0.004999660421162844,
                                5.45019411546388e-12,
                                1.129599669011383e-12,
                                4.379279345433822e-15,
                                8.465267549440103e-12,
                                0.000018977581930812448,
                                2.3997567821787413e-10,
                                3.865102394363202e-15,
                                1.6532286080523171e-18,
                                6.183128094975089e-17,
                                0
                            ],
                            [
                                0.00005282574420562014,
                                7.163673617923613e-18,
                                ...
                            ]
                            ],
                "rfdiffusion": {
                    "output_pdb": "ATOM      1  N   GLY A   1     -41.231  -9.202  10.946  1.00  0.00\nATOM      2  CA  GLY A   1     -40.410  -8.027  11.210  1.00  0.00\nATOM      3  C   GLY A   1     -39.456  -8.271  12.373  1.00  0.00\nATOM      4  O   GLY A   1     -38.329  -7.777  12.377  1.00  0.00\nATOM      5  N   GLY A   2     -39.762  -9.223  13.182  1.00  0.00\nATOM      6  CA  GLY A   2     -38.935  -9.554  14.336  1.00  0.00\nATOM      7  C   GLY A   2     -37.653 -10.259  13.912  1.00  0.00\nATOM      8  O   GLY A   2     -36.582 -10.007  14.465  1.00  0.00\nATOM      9  N   GLY A   3     -37.722 -11.118  12.941  1.00  0.00\nATOM     10  CA  GLY A   3     -36.533 -11.809  12.458  1.00  0.00\nATOM     11  C   GLY A   3     -35.546 -10.833  11.830  1.00  0.00\nATOM     12  O   GLY A   3     -34.337 -10.939  12.036  1.00  0.00\nATOM     13  N   GLY A   4     -36.039  -9.888  11.050  1.00  0.00\nATOM     14  CA  GLY A   4   
                    },
                    ...
        }
      
    • 上記 json データを 3D レンダリングした結果

    上記のモデル群を以下の手順でパイプライン的に推論することで、タンパク質の3D構造を出力することができます

    1. RFdiffusion::こういう機能に最適なタンパク質3D構造の原型をゼロ(ノイズ)から生成
    2. ProteinMPNN:上記生成されたタンパク質構造を安定させるためのアミノ酸配列を予想
    3. AlphaFold2:ProteinMPNN が予想したアミノ酸配列が、本当に意図したタンパク質構造に折りたたまれるかを検証

BioNeMo Framework を起動する

それでは実際に、NVIDIA Clara for Biopharma での AI モデルを動かしていきます。まずは、BioNeMo Framework を起動することから始めていきます。

BioNeMo Framework は、NVIDIA が NGC 上に公開している BioNeMo Framework 用 docker image のコンテナで動かす形になります。

  1. GPU インスタンスを構築する

    GCP の VMインスタンス(A100:1台)で環境構築しました。GPU メモリ的には、V100 でも問題なさそうでしたが、自分の環境では V100 だと以下のエラーが発生したので A100 にしました

     [rank0]: RuntimeError: CUDA error: no kernel image is available for execution on the device
    

    今回は少し動かしてみるだけだったのでクラウド環境で環境構築しましたが、本格的な学習を行う場合は、クラウドで動かすとインフラコストが膨大になるので、オンプレ環境で動かすことを推奨します。

  2. NVIDIA の NGC にログインして SetUp 画面から API キーを作成する

    GPU-optimized AI, Machine Learning, & HPC Software | NVIDIA NGC

  3. NVIDIA NGC Container Register にログインする

     docker login nvcr.io
    
    • ユーザー名: $oauthtoken
    • パスワード: 上記作成した API トークン
  4. BioNeMo Framework 用の docker image を pull する

     docker pull nvcr.io/nvidia/clara/bionemo-framework:2.7
    
  5. 各種環境変数を設定する

     export WANDB_API_KEY='dummy'    # wandb の API キー
     ...
    

    .env ファイルで定義するのでも OK です

  6. BioNeMo Framework の docker コンテナへの接続を行う

     mkdir -p results
     docker run --rm -it --gpus all \
         --network host \
         --shm-size=4g \
         -e WANDB_API_KEY \
         -v ${PWD}/results:/workspace/bionemo2/results \
         nvcr.io/nvidia/clara/bionemo-framework:2.7 \
         /bin/bash
    

    https://github.com/NVIDIA/bionemo-framework レポジトリのコードが存在するコンテナになっており、今回の記事ではこのレポジトリのコードを動かすことになります。

     root@011d2a38b429:/workspace/bionemo2# ls -al
     total 40
     drwxrwxrwx 1 root root 4096 Sep 30 02:17 .
     drwxrwxrwx 1 root root 4096 Sep 26 10:03 ..
     drwxrwxrwx 1 root root 4096 Sep 30 02:17 .cache
     drwxrwxrwx 1 root root 4096 Sep 26 19:23 LICENSE
     -rwxrwxrwx 1 root root 6358 Sep 30 02:13 README.md
     -rwxrwxrwx 1 root root    7 Sep 30 02:13 VERSION
     drwxrwxrwx 1 root root 4096 Sep 30 02:17 ci
     drwxrwxrwx 1 root root 4096 Sep 30 02:13 docs
     drwxrwxrwx 1 root root 4096 Sep 30 02:13 sub-packages
    
     root@011d2a38b429:/workspace/bionemo2# cd ..
     root@011d2a38b429:/workspace# ls -al
     total 28
     drwxrwxrwx 1 root root 4096 Sep 26 10:03 .
     drwxr-xr-x 1 root root 4096 Nov  6 08:31 ..
     -rw-rw-rw- 1 root root 2048 Jun 12 15:29 README.md
     drwxrwxrwx 1 root root 4096 Sep 30 02:17 bionemo2
     drwxrwxrwx 1 root root 4096 Jun 12 15:29 docker-examples
     -rw-rw-rw- 1 root root  467 Jun 12 06:42 license.txt
     drwxrwxrwx 1 root root 4096 Jun 12 15:33 tutorials
    

BioNeMo Framework を使用してタンパク質言語モデル(ESM-2)の学習を行う

では次に、タンパク質言語モデル(pLM: Protein Language Model)の1つである ESM-2 の学習を行う方法を紹介します。

タンパク質言語モデルとは、タンパク質のアミノ酸配列(20種類のアミノ酸:A, C, D, E...)を言語として捉え、そのアミノ酸配列の並びを学習することで、タンパク質の3次元構造予測のための特徴量(埋め込みベクトル)を出力できるようにした LLM です。

タンパク質は、20種類のアミノ酸の組み合わせで構成されますが、20種類のアミノ酸の組み合わせ数は天文学的数になります。またタンパク質の3D構造が薬の作用にとって重要になるので、AI モデルによってタンパク質の3D構造を予想することは創薬において有益になります。

BioNeMo Framework の docker コンテナへの接続後、コンテナ内にて以下の train_esm2 コマンドを実行してください

# wandb 上で loss 値を確認したい場合
wandb login

# ESM-2 用の環境変数
export MY_DATA_SOURCE="ngc"

# The fastest transformer engine environment variables in testing were the following two
TEST_DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source $MY_DATA_SOURCE); \
ESM2_650M_CKPT=$(download_bionemo_data esm2/650m:2.0 --source $MY_DATA_SOURCE); \

train_esm2 \
    --train-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/train_clusters_sanity.parquet \
    --train-database-path ${TEST_DATA_DIR}/2024_03_sanity/train_sanity.db \
    --valid-cluster-path ${TEST_DATA_DIR}/2024_03_sanity/valid_clusters.parquet \
    --valid-database-path ${TEST_DATA_DIR}/2024_03_sanity/validation.db \
    --result-dir ./results \
    --experiment-name exper_esm2_20251107 \
    --wandb-project ai-material-exercises \
    --num-gpus 1 \
    --num-nodes 1 \
    --val-check-interval 100 \
    --num-dataset-workers 4 \
    --num-steps 1000 \
    --max-seq-length 1024 \
    --limit-val-batches 4 \
    --micro-batch-size 4 \
    --restore-from-checkpoint-path ${ESM2_650M_CKPT}

loss 値のグラフ(--num-steps 1000 の場合)は、以下のようになりました

学習完了後、results ディレクトリ以下に、学習済みチェックポイント等が保存されます

BioNeMo Framework を使用してタンパク質言語モデル(ESM-2)の推論を行う

次に、ESM-2 で推論する方法を紹介します

  1. ESM-2 の学習済みチェックポイントをダウンロードします

    今回は上記工程で学習した ESM-2 ではなく、以下のスクリプトでダウンロードした学習済み ESM-2 モデルで推論します

     from bionemo.core.data.load import load
    
     checkpoint_path = load("esm2/650m:2.0")
     print("checkpoint_path: ", checkpoint_path)
    
  2. 推論時に入力する以下のようなアミノ酸配列(タンパク質一次構造)データのファイル(sequences.csv)を作成します

     sequences
     TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI
     LYSGDHSTQGARFLRDLAENTGRAEYELLSLF
     GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN
     DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF
     KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI
     LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP
     LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF
     LYSGDHSTQGARFLRDLAENTGRAEYELLSLF
     ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP
     SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT
    
  3. ESM-2 での推論を行う

    BioNeMo Framework の docker コンテナへの接続後、コンテナ内にて以下の infer_esm2 コマンドを実行してください

     infer_esm2 \
         --checkpoint-path /root/.cache/bionemo/0798767e843e3d54315aef91934d28ae7d8e93c2849d5fcfbdf5fac242013997-esm2_650M_nemo2.tar.gz.untar \
         --data-path ./datasets/sequences.csv \
         --results-path ./results \
         --micro-batch-size 3 \
         --num-gpus 1 \
         --precision "bf16-mixed" \
         --include-hiddens \
         --include-embeddings \
         --include-logits \
         --include-input-ids
    
     ...
     LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
     [NeMo W 2025-11-07 07:50:52 nemo_logging:405] Could not copy Trainer's 'max_steps' to LR scheduler's 'max_steps'. If you are not using an LR scheduler, this warning can safely be ignored.
     [NeMo I 2025-11-07 07:50:52 nemo_logging:393]  > number of parameters on (tensor, pipeline) model parallel rank (0 ,0): 651164288
     [NeMo I 2025-11-07 07:51:01 nemo_logging:393] Inference predictions are stored in results/predictions__rank_0__dp_rank_0.pt
         dict_keys(['token_logits', 'binary_logits', 'hidden_states', 'embeddings'])
    

    推論結果は、--results-path で指定したディレクトリ以下に保存されます。

     root@sakai-gpu-dev:/workspace/bionemo2# cd results/
     root@sakai-gpu-dev:/workspace/bionemo2/results# ls -al
     total 28196
     drwxr-xr-x 2 1005 1006     4096 Nov  7 07:51 .
     drwxrwxrwx 1 root root     4096 Nov  7 07:50 ..
     -rw-r--r-- 1 root root 28863920 Nov  7 07:51 predictions__rank_0__dp_rank_0.pt
    
  4. 推論時の出力データを確認する

    上記推論後に出力される predictions__rank_0__dp_rank_0.pt 内に、dict_keys(['token_logits', 'binary_logits', 'hidden_states', 'embeddings']) の形式で各推論データが存在します

     token_logits    torch.Size([1024, 10, 128])     # トークンの予測スコア(配列長, バッチ, 隠れ層)
     hidden_states   torch.Size([10, 1024, 1280])    # 内部特徴量(バッチ, 配列長, 埋め込み次元)
     input_ids       torch.Size([10, 1024])          # 入力トークンID(バッチ, 配列長)
     embeddings      torch.Size([10, 1280])          # 配列全体の埋め込みベクトル(バッチ, 埋め込み次元)
    

    token_logits は、トークンの予測スコアで、最後の次元の 128 内の最初の33位置がアミノ酸語彙に対応し、その後に95個のパディングが続きます。

    また hidden_states, embeddings が、タンパク質言語モデルが出力するタンパク質の3次元構造予測のための特徴量(埋め込みベクトル)になり、この特徴量を別の構造予測モデル(ESMFold など)に渡して、3D構造を推定することになります。

まとめ

今回は初めて AI x 医療の分野のモデルやフレームワークについて軽く調べてみました。

AI モデルによる医療用画像診断や創薬の雰囲気だけも掴むことで、技術の方向性や将来性への解像度が少し上げられて良かったです

We Are Hiring!

ABEJAは、テクノロジーの社会実装に取り組んでいます。 技術はもちろん、技術をどのようにして社会やビジネスに組み込んでいくかを考えるのが好きな方は、下記採用ページからエントリーください! (新卒の方やインターンシップのエントリーもお待ちしております!) careers.abejainc.com

特に下記ポジションの募集を強化しています!ぜひ御覧ください!

トランスフォーメーション領域:データサイエンティスト

トランスフォーメーション領域:データサイエンティスト(ミドル)

トランスフォーメーション領域:データサイエンティスト(シニア)