はじめに
こちらはABEJAアドベントカレンダー2024 12日目の記事です。
こんにちは、ABEJAでデータサイエンティストをしている坂元です。最近はLLMでアプローチしようとしていたことがよくよく検証してみるとLLMでは難しいことが分かり急遽CVのあらゆるモデルとレガシーな画像処理をこれでもかというくらい詰め込んだパイプラインを実装することになった案件を経験して、LLMでは難しそうなことをLLM以外のアプローチでこなせるだけの引き出しとスキルはDSとしてやはり身に付けておくべきだなと思うなどしています(LLMにやらせようとしていることは大抵難しいことなので切り替えはそこそこ大変)。
とはいうものの、Agentの普及によってより複雑かつ高度な推論も出来るようになってきています。弊社の社内外のプロジェクト状況を見ていても最近では単純なRAG案件は減りつつあり、計画からアクションの実行、結果の集約までを行うAgentやMulti Agentの需要が高まっているように見えます。これまではAgent開発のフレームワークと言えばLangGraphやAutoGen、Langroidといったものがよく使われてきたと思いますが、最近Pydanticが公式でAgentフレームワークの開発を進めているので紹介します。
ただし、基本的な使い方はドキュメントにかなり詳しく書かれていてそれをなぞるだけだとつまらないので、ここではPydanticの利点をどう活かしているのかという点に注目して中身の実装部分を掘り下げて行こうと思います。
- はじめに
- Pydanticの利点
- PydanticAIとは
- 基本的な使い方
- どのように型安全を実現しているのか
- 依存関係の注入
- その他の機能
- PydanticAIを深掘りしてみた感想
- We Are Hiring!
注:本記事執筆時点のpydantic-aiのバージョンは0.0.12です。目下開発進行中のライブラリで更新頻度が高いので、今後のバージョンアップで実装は変わるかもしれません。
Pydanticの利点
Pydanticは高速かつカスタマイズ性の高いデータ検証ライブラリです。Pydanticの利点をざっくりまとめると以下のような感じです。
厳密な型安全性とバリデーション
PydanticはPythonの型ヒント(type hint)を積極的に活用し、データ構造が期待する型通りであることを保証します。モデル定義時に型ヒントを書くだけで自動的に入力値に対してバリデーションが行われます。また、バリデーションのタイミングも細かく制御出来ます。デフォルト値や必須属性の管理
フィールドに対してデフォルト値を簡単に設定でき、またOptionalやFieldを用いることで必須/任意属性、さらには入力値の最大・最小値や正規表現などの追加制約を簡易かつ直感的に指定できます。複雑なネスト構造への対応
リスト、辞書、他のPydanticモデルをフィールドに持つような複雑なネスト構造でも簡単に定義でき、同様にバリデーションされます。有効なデータへの自動変換
文字列を日付型へ、整数文字列を整数へといった型変換を自動的に行います。この機能はあるメンバが複数のPydanticモデルを受け入れる可能性がある場合でも自動で最適な型を推測して変換してくれるので地味に便利です。JSONシリアライズ/デシリアライズの容易化
モデルインスタンスを.model_dump()
でdictに、.model_dump_json()
でJSON文字列へ簡単に変換できます。また、.model_json_schema()
でJSONスキーマへの変換も出来ます。さらにPydantic v2.10.0からpartial validationという機能が実験的に追加され、不完全なJSON文字列に対しても部分的にバリデーションが実行出来るようになっています(ただし機能はまだ限定的)。デフォルトでの読み取り専用フィールド、エイリアス、複数の入力形式サポート
Fieldオプションでaliasを指定することで、異なる名前で送られてくるJSONキーに対応したり、.model_validate()
,.model_validate_json()
などのメソッドを用いてさまざまなフォーマットのデータを柔軟に読み込めます。
PydanticAIとは
Pydanticが公式で開発を進めているAgentフレームワークで、Pydanticの利点やプラクティスを詰め込むことで型安全かつシンプルにAgentを構築することを容易にします。ドキュメントには以下の記載があります。
- Built by the team behind Pydantic (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more)
- Model-agnostic — currently OpenAI, Gemini, and Groq are supported, Anthropic is coming soon. And there is a simple interface to implement support for other models.
- Type-safe
- Control flow and agent composition is done with vanilla Python, allowing you to make use of the same Python development best practices you'd use in any other (non-AI) project
- Structured response validation with Pydantic
- Streamed responses, including validation of streamed structured responses with Pydantic
- Novel, type-safe dependency injection system, useful for testing and eval-driven iterative development
- Logfire integration for debugging and monitoring the performance and general behavior of your LLM-powered application
Logfireはインターネット上にホストされているlofireのプラットフォームにテレメトリを送信し監視をするツールですが、logfire
パッケージをインストールしない限り何もしません。PydanticAIのドキュメントには以下の記載があります。
PydanticAI has built-in (but optional) support for Logfire via the logfire-api no-op package.
That means if the logfire package is installed and configured, detailed information about agent runs is sent to Logfire. But if the logfire package is not installed, there's virtually no overhead and nothing is sent.
基本的な使い方
APIキーを読み込んでいる前提で、APIからのレスポンスを受け取るだけであれば以下のコードで済みます。
from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel model = OpenAIModel("gpt-4o") agent = Agent(model) result = await agent.run("Where does 'hello world' come from?") print(result.data) # > The phrase "Hello, World!" originated as an example of a simple program to demonstrate the basic syntax of a programming language. Its earliest and most famous usage is attributed to the book "The C Programming Language," published in 1978 by Brian Kernighan and Dennis Ritchie. In this book, the authors used "Hello, World!" as the first program example to illustrate how to write and run a simple program in the C language. Since then, it has become a traditional and widely used practice for developers learning a new programming language or environment to write a "Hello, World!" program.
このresultはpydanticモデルではなく、RunResult
というクラスのインスタンスです。
from pydantic import BaseModel type(result), issubclass(type(result), BaseModel) # > (pydantic_ai.result.RunResult, False)
同期的に実行する場合は以下のようにします。(ただ、内部的には非同期処理を行っているのでnotebookで実行する際はイベントループの競合によるエラーを回避するためにnest_asyncio
を実行しておく)
import nest_asyncio nest_asyncio.apply() result = agent.run_sync("Where does 'hello world' come from?")
Azure OpenAIを利用する場合はopenai-pythonでAzureOpenAIクライアントを初期化して渡す必要があります。
from openai import AsyncAzureOpenAI client = AsyncAzureOpenAI( azure_endpoint="...", api_version="2024-07-01-preview", api_key="your-api-key", ) model_azure = OpenAIModel("gpt-4o", openai_client=client) agent = Agent(model_azure)
Gemini on VertexAIの場合、google周りの認証を済ませていれば以下のコードで初期化することが出来ます。
import os from pydantic_ai import Agent from pydantic_ai.models.vertexai import VertexAIModel model = VertexAIModel("gemini-1.5-flash", project_id=os.getenv("PROJECT_ID")) agent_gemini = Agent(model)
どのように型安全を実現しているのか
ここからPydanticAIの中身を掘り下げていきます。
返り値の型を指定する
pydantic-aiではAPIコールの返り値の型を以下のようにはresult_type
で指定することが出来ます。
from pydantic import BaseModel class CityLocation(BaseModel): city: str country: str model = OpenAIModel("gpt-4o") agent = Agent(model, result_type=CityLocation) result = await agent.run("Where the olympics held in 2012?") print(result.data) print(result.cost()) # > city='London' country='United Kingdom' # > Cost(request_tokens=68, response_tokens=19, total_tokens=87, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0})
結果を見るとちゃんとreturn_type
で指定したモデルで出力されています。
result.data
# CityLocation(city='London', country='United Kingdom')
ドキュメントを見ると、
If the result type is a union with multiple members (after remove str from the members), each member is registered as a separate tool with the model in order to reduce the complexity of the tool schemas and maximise the changes a model will respond correctly.
(DeepL訳):結果型が複数のメンバーを持つユニオンの場合(メンバーからstrを取り除いた後)、ツールのスキーマの複雑さを軽減し、モデルが正しく反応する変化を最大化するために、各メンバーは別々のツールとしてモデルに登録される。
と書いてあります。つまり、指定した属性と型の引数を持つようなツールのスキーマを与え、その引数をresult_type
にマップするということのようです。実装の詳細を見てみましょう。Agentクラスをインスタンス化すると内部で以下のようにスキーマ定義されています。
agent._result_schema
# > ResultSchema(tools={'final_result': ResultTool(tool_def=ToolDefinition(name='final_result', description='The final response which ends this conversation', parameters_json_schema={'properties': {'city': {'title': 'City', 'type': 'string'}, 'country': {'title': 'Country', 'type': 'string'}}, 'required': ['city', 'country'], 'title': 'CityLocation', 'type': 'object'}, outer_typed_dict_key=None), type_adapter=TypeAdapter(CityLocation))}, allow_text_result=False)
ツール名はデフォルトでfinal_result
となっていて、descriptionの指定が無い場合はデフォルトでThe final response which ends this conversation
という説明が使われます。parameters_json_schema
属性にはツールを使用するのに必要な情報がありますね。このスキーマが定義される手順を追ってみましょう。result_tool_name
とresult_tool_description
はデフォルト値を使用します。
result_tool_name: str = "final_result" # default result_tool_description: str | None = None # default result_type = CityLocation
Agentのインスタンス化時に以下のメソッドが実行されます。
from pydantic_ai._result import ResultSchema ResultSchema[result_type].build(result_type, result_tool_name, result_tool_description) # > ResultSchema(tools={'final_result': ResultTool(tool_def=ToolDefinition(name='final_result', description='The final response which ends this conversation', parameters_json_schema={'properties': {'city': {'title': 'City', 'type': 'string'}, 'country': {'title': 'Country', 'type': 'string'}}, 'required': ['city', 'country'], 'title': 'CityLocation', 'type': 'object'}, outer_typed_dict_key=None), type_adapter=TypeAdapter(CityLocation))}, allow_text_result=False)
ResultSchemaはGenericsになっていて、入力された型の情報を保持することが出来ます(このような実装が随所に見られます)。このbuild()
メソッドは与えられたresult_type
から返り値がUnionであるかどうかを判定し、Unionの場合は複数のツールを定義します。その際、以下の関数の部分で上記の結果にあるResultTool
クラスのインスタンスを生成しています。a
はツールの引数(もといresult_typeの属性)、multiple
というのは返り値がUnionであればTrue、そうでなければFalseとなるフラグです。
def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultData]: return cast(ResultTool[ResultData], ResultTool(a, tool_name_, description, multiple))
なお、返り値の型はTypedDict、dataclass、Pydanticモデルのいずれかで定義出来ますが、TypedDictとdataclassが指定された場合はPydanticのTypeAdapter
を適用してPydanticモデルと同様にJSONスキーマを生成出来るようにしています。TypeAdapter
はPydanticモデルではないdataclasasやTypedDictのJSONスキーマを生成したりバリデーションするために使用するPydanticの機能です。こうすることで、異なる型に対して同じコードを再利用しつつ型安全な状態を作っていて、Pydanticの良さを最大限に活用しています。
PydanticのTypeAdapterの詳細についてはドキュメントをご参照ください。
ここではTypedDictを使用する場合もdataclassを使用する場合もJSONスキーマが同じになることを確認しましょう。
from pydantic import TypeAdapter from typing_extensions import TypedDict from dataclasses import dataclass # TypedDictの場合 class CityLocationDict(TypedDict): city: str country: str TypeAdapter(CityLocationDict).json_schema() # > {'properties': {'city': {'title': 'City', 'type': 'string'}, # > 'country': {'title': 'Country', 'type': 'string'}}, # > 'required': ['city', 'country'], # > 'title': 'CityLocationDict', # > 'type': 'object'} # dataclassの場合 @dataclass class CityLocationData: city: str country: str TypeAdapter(CityLocationData).json_schema() # > {'properties': {'city': {'title': 'City', 'type': 'string'}, # > 'country': {'title': 'Country', 'type': 'string'}}, # > 'required': ['city', 'country'], # > 'title': 'CityLocationData', # > 'type': 'object'}
次はこのようなResultTool
クラスを使ってそれぞれのクライアントAPI向けにツールのスキーマを変換します。各クライアントのモデルクラスにagent_model
というメソッドを定義し、そのメソッド内で変換します。OpenAIModel
の場合は以下のようになります。
async def agent_model( self, *, function_tools: list[ToolDefinition], allow_text_result: bool, result_tools: list[ToolDefinition], ) -> AgentModel: check_allow_model_requests() tools = [self._map_tool_definition(r) for r in function_tools] if result_tools: tools += [self._map_tool_definition(r) for r in result_tools] return OpenAIAgentModel( self.client, self.model_name, allow_text_result, tools, )
このメソッドの中で使われているself._map_tool_definition()
は以下のようになっており、見慣れたOpenAI APIのツール呼び出しのスキーマを返す関数となっています。
@staticmethod def _map_tool_definition(f: ToolDefinition) -> chat.ChatCompletionToolParam: return { 'type': 'function', 'function': { 'name': f.name, 'description': f.description, 'parameters': f.parameters_json_schema, }, }
上記の操作を行っているのがagent.run()
内の以下の部分になります。
model_used, mode_selection = await agent._get_model(model) ... agent_model = await agent._prepare_model(model_used, deps=None)
このagent_model
インスタンスにはrequest()
メソッドが定義されており、このメソッドの中でクライアントAPIに対してリクエストを飛ばしています。request()
メソッドの引数の型はlist[Message]
となっており、pydantic-ai独自のメッセージクラスのインスタンスを受け付けるようです。
new_message_index, messages = await agent._prepare_messages( deps=None, user_prompt="Where the olympics held in 2012?", message_history=[], ) messages # > [UserPrompt(content='Where the olympics held in 2012?', timestamp=datetime.datetime(2024, 12, 9, 2, 59, 39, 352121, tzinfo=datetime.timezone.utc), role='user')]
メッセージはLLMクライアントに依らずSystemPrompt
, UserPrompt
, ToolReturn
, RetryPrompt
, ModelTextResponse
, ModelStructuredResponse
の6つのプロンプトタイプのリストとなっており、OpenAI APIでもGemini APIでも共通のメッセージを利用することが出来ます。これらのクラスはPydanticモデルではなく単なるdataclassのようです。
ではOpenAIAgentModelのrequest()メソッドを実行してみます。
model_response, request_cost = await agent_model.request(messages) model_response # > ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"city":"London","country":"United Kingdom"}'), tool_id='call_Je21MIkAZ1oQ4BjWc4ikQKHo')], timestamp=datetime.datetime(2024, 12, 9, 3, 1, 47, tzinfo=datetime.timezone.utc), role='model-structured-response')
1つめの返り値はOpenAI APIのツール呼び出しの結果をModelStructuredResponse
というクラスのインスタンスにまとめ直したものです。このレスポンスは以下のメソッド内でバリデーションされ、ラップされます。
final_result, response_messages = await agent._handle_model_response(model_response, deps=None) final_result, response_messages # > (_MarkFinalResult(data=CityLocation(city='London', country='United Kingdom')), # > [ToolReturn(tool_name='final_result', content='Final result processed.', tool_id='call_Je21MIkAZ1oQ4BjWc4ikQKHo', timestamp=datetime.datetime(2024, 12, 9, 3, 2, 33, 172645, tzinfo=datetime.timezone.utc), role='tool-return')])
重要な所だけ抜粋して説明します。まず以下のメソッドでクライアントAPIから返ってきたツール呼び出しのレスポンスにあるツール名とそれに対応するスキーマのResultTool
のペアを取得します。
call, result_tool = agent._result_schema.find_tool(model_response) call, result_tool # > (ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"city":"London","country":"United Kingdom"}'), tool_id='call_Je21MIkAZ1oQ4BjWc4ikQKHo'), # > ResultTool(tool_def=ToolDefinition(name='final_result', description='The final response which ends this conversation', parameters_json_schema={'properties': {'city': {'title': 'City', 'type': 'string'}, 'country': {'title': 'Country', 'type': 'string'}}, 'required': ['city', 'country'], 'title': 'CityLocation', 'type': 'object'}, outer_typed_dict_key=None), type_adapter=TypeAdapter(CityLocation)))
ResultTool
クラスは、ツール呼び出しのレスポンスに対してresult_type
のJSONスキーマに合致しているかどうかをバリデーションするメソッドを持っています。
result_data = result_tool.validate(call)
result_data
# > CityLocation(city='London', country='United Kingdom')
このvalidate()
メソッド内では以下のようにPydanticのvalidate_json()
が実行され、バリデーションが通れば最初に指定した返り値の型が返されます。
PydanticのJSONパースの詳細についてはこちらをご参照ください。
result_tool.type_adapter.validate_json(call.args.args_json)
テキストレスポンスの際はバリデーションが通ればそのまま、result_type
が指定された場合は「対応するツールがあり」かつ「バリデーションが通れば」_MarkFinalResult
クラスのインスタンスとしてラップされ出力されます。_MarkFinalResult
に変換されれば、最終的にこのバリデーションされたレスポンスは以下のようにRunResultクラスのインスタンスになります。
from pydantic_ai.result import RunResult result_data = final_result.data RunResult(messages, new_message_index, result_data, request_cost) # > RunResult(_all_messages=[UserPrompt(content='Where the olympics held in 2012?', timestamp=datetime.datetime(2024, 12, 8, 1, 52, 39, 547497, tzinfo=datetime.timezone.utc), role='user')], _new_message_index=0, data=CityLocation(city='London', country='United Kingdom'), _cost=Cost(request_tokens=68, response_tokens=19, total_tokens=87, details={'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0, 'cached_tokens': 0}))
もしレスポンスのバリデーションが失敗するなどで_MarkFinalResult
とならない場合は、指定した回数だけAPIコールをリトライするようになっています。その際、RetryPrompt
というプロンプトがメッセージの末尾に挿入されます。RetryPrompt
は以下のようになっています。
@dataclass class RetryPrompt: """A message back to a model asking it to try again. This can be sent for a number of reasons: * Pydantic validation of tool arguments failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] * a tool raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception * no tool was found for the tool name * the model returned plain text when a structured response was expected * Pydantic validation of a structured response failed, here content is derived from a Pydantic [`ValidationError`][pydantic_core.ValidationError] * a result validator raised a [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] exception """ content: list[pydantic_core.ErrorDetails] | str """Details of why and how the model should retry. If the retry was triggered by a [`ValidationError`][pydantic_core.ValidationError], this will be a list of error details. """ tool_name: str | None = None """The name of the tool that was called, if any.""" tool_id: str | None = None """The tool identifier, if any.""" timestamp: datetime = field(default_factory=_now_utc) """The timestamp, when the retry was triggered.""" role: Literal['retry-prompt'] = 'retry-prompt' """Message type identifier, this type is available on all message as a discriminator.""" def model_response(self) -> str: if isinstance(self.content, str): description = self.content else: json_errors = ErrorDetailsTa.dump_json(self.content, exclude={'__all__': {'ctx'}}, indent=2) description = f'{len(self.content)} validation errors: {json_errors.decode()}' return f'{description}\n\nFix the errors and try again.'
見たところPydanticのエラートレースに多少のお化粧をしたプロンプトのようです。このプロンプトはリトライ時に以下のようにmodel_response()
の返り値のテキストがクライアントAPIのメッセージに変換されます。OpenAI APIの場合は以下のようにChatCompletionUserMessageParam
としてuserのロールで挿入されます。
chat.ChatCompletionUserMessageParam(role='user', content=message.model_response())
リトライプロンプトが英語なので日本語での会話にどのような影響を与えるのか少し気になるところです。また、今となってはOpenAI APIではStructured Outputで複雑なスキーマのレスポンスを生成出来るようになっているのであえてツール呼び出しを使う理由があるのか気になったのですが、Structured Outputよりもツール呼び出しの方がクライアントによる機能の差分が小さく抽象化しやすいからかもしれません。
ここまでは単一の返り値の場合についての内部的な動作を追ってきましたが、pydantic_aiでは返り値の型を複数指定することが出来ます。例えば、文字列型を含むUnionの場合、result_schema
は以下のようになります。
from typing import Union from pydantic import BaseModel class Box(BaseModel): width: int height: int depth: int units: str result_type = Union[Box, str] ResultSchema[result_type].build(result_type, result_tool_name, result_tool_description) # > ResultSchema(tools={'final_result': ResultTool(tool_def=ToolDefinition(name='final_result', description='The final response which ends this conversation', parameters_json_schema={'properties': {'width': {'title': 'Width', 'type': 'integer'}, 'height': {'title': 'Height', 'type': 'integer'}, 'depth': {'title': 'Depth', 'type': 'integer'}, 'units': {'title': 'Units', 'type': 'string'}}, 'required': ['width', 'height', 'depth', 'units'], 'title': 'Box', 'type': 'object'}, outer_typed_dict_key=None), type_adapter=TypeAdapter(Box))}, allow_text_result=True)
strを含まない複数の型を持つ場合は、それぞれの型についてのツールが定義され、どの型に合致するようにレスポンスを生成するかは各種LLMに任せているようです。
result_type = list[str] | list[int] ResultSchema[result_type].build(result_type, result_tool_name, result_tool_description) # > ResultSchema(tools={'final_result_list': ResultTool(tool_def=ToolDefinition(name='final_result_list', description='list: The final response which ends this conversation', parameters_json_schema={'properties': {'response': {'items': {'type': 'string'}, 'title': 'Response', 'type': 'array'}}, 'required': ['response'], 'type': 'object'}, outer_typed_dict_key='response'), type_adapter=TypeAdapter(response_data_typed_dict)), 'final_result_list_2': ResultTool(tool_def=ToolDefinition(name='final_result_list_2', description='list: The final response which ends this conversation', parameters_json_schema={'properties': {'response': {'items': {'type': 'integer'}, 'title': 'Response', 'type': 'array'}}, 'required': ['response'], 'type': 'object'}, outer_typed_dict_key='response'), type_adapter=TypeAdapter(response_data_typed_dict))}, allow_text_result=False)
ストリーミング出力を型安全にする
ストリーミング出力で型を検証したい場合はPydanticに最近追加された部分検証という機能を使用するそうです。(まだexperimentalの機能です)
ドキュメントには以下のように書かれています。
PydanticAI streams just enough of the response to sniff out if it's a tool call or a result, then streams the whole thing and calls tools, or returns the stream as a StreamedRunResult.
ふむ。まずはここまでと同じ段どりでresult_type
を指定した状態でストリーミングしてみましょう。
from datetime import date from pydantic_ai import Agent from typing_extensions import TypedDict class UserProfile(TypedDict, total=False): name: str dob: date bio: str model = OpenAIModel("gpt-4o") agent = Agent( model=model, result_type=UserProfile, system_prompt="Extract a user profile from the input", ) user_prompt = ( "My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid." ) async with agent.run_stream(user_prompt) as response: async for message in response.stream(): print(message) # > {'name': 'Ben'} # > {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like'} # > {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like the chain the dog and the pyramid.'} # > {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like the chain the dog and the pyramid.'}
では、どのようにバリデーションしているのか実装を追ってみましょう。基本的にメッセージを整える所までは非ストリームの場合と同様なので省略します。
ストリーミングの場合はagent_modelのrequest_stream()
メソッドを使用します。OpenAIAgentModelの場合は以下のようになっています。
@asynccontextmanager async def request_stream(self, messages: list[Message]) -> AsyncIterator[EitherStreamedResponse]: response = await self._completions_create(messages, True) async with response: yield await self._process_streamed_response(response)
_completions_create()
は普通にOpenAI APIのストリームのリクエストをする処理で、_process_streamed_response()
は指定の形式(OpenAIStreamStructuredResponse
またはOpenAIStreamTextResponse
)にレスポンスを整形する処理です。ただ、出力はストリーミングなので構造化されたストリームなのか単なるテキストのストリームなのかはOpenAI APIから返って来るストリームのレスポンスを実際に少し進めてみるまで分かりません。そこで、それを判定するのに有効な最初のチャンクが現れるまでストリームを進めた上で、レスポンスを整形しています。具体的には、ストリームのdeltaにcontent属性があるかどうか、tool_call属性があるかどうかで判定してOpenAIStreamStructuredResponse
またはOpenAIStreamTextResponse
のいずれかを返しています。
response = await agent_model._completions_create(messages, True) async with response: processed_response = await agent_model._process_streamed_response(response) processed_response # > OpenAIStreamStructuredResponse(_response=<openai.AsyncStream object at 0x7fc6704dcbb0>, _delta_tool_calls={0: ChoiceDeltaToolCall(index=0, id='call_7QOFcfPFL5eFQbeErlyJbOfP', function=ChoiceDeltaToolCallFunction(arguments='', name='final_result'), type='function')}, _timestamp=datetime.datetime(2024, 12, 9, 3, 46, 16, tzinfo=datetime.timezone.utc), _cost=Cost(request_tokens=None, response_tokens=None, total_tokens=None, details=None))
ここではtool_callがストリームレスポンスの中に見つかったので、OpenAIStreamStructuredResponse
が返ってきています。この返り値は次に
final_result, response_messages = await self._handle_streamed_model_response(
model_response, deps
)
という処理に回されます。この関数の中では、
if self._result_schema is not None: # if there's a result schema, iterate over the stream until we find at least one tool # NOTE: this means we ignore any other tools called here structured_msg = model_response.get() while not structured_msg.calls: try: await model_response.__anext__() except StopAsyncIteration: break structured_msg = model_response.get() if match := self._result_schema.find_tool(structured_msg): call, _ = match tool_return = _messages.ToolReturn( tool_name=call.tool_name, content='Final result processed.', tool_id=call.tool_id, ) return _MarkFinalResult(model_response), [tool_return]
という処理が走っており、tool_callが見つかったら対応するツールがあるかをチェックし、あればレスポンスを_MarkFinalResutl
クラスでラップして返します。なお、ここでreturnしなかった場合はストリームを最後まで消費して完全なレスポンスを取得した上でクライアントAPIで呼び出されたツールを全て実行して結果をmessagesに追加しています。
# the model is calling a tool function, consume the response to get the next message async for _ in model_response: pass structured_msg = model_response.get() if not structured_msg.calls: raise exceptions.UnexpectedModelBehavior('Received empty tool call message') messages: list[_messages.Message] = [structured_msg] # we now run all tool functions in parallel tasks: list[asyncio.Task[_messages.Message]] = [] for call in structured_msg.calls: if tool := self._function_tools.get(call.tool_name): tasks.append(asyncio.create_task(tool.run(deps, call), name=call.tool_name)) else: messages.append(self._unknown_tool(call.tool_name)) with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]): task_results: Sequence[_messages.Message] = await asyncio.gather(*tasks) messages.extend(task_results) return None, messages
こうして取得されたレスポンスが最終的にStreamedRunResult
としてreturnされます。ここまでの処理がちょうど以下の部分ですね。
user_prompt = ( "My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid." ) async with agent.run_stream(user_prompt) as response:
さて、このStreamdRunResult
はstream_text()
とstream_structured()
というメソッドをもっていて、.stream()
という処理はレスポンスがテキストレスポンスか構造化されたレスポンスかでどちらを実行するかを選択しています。ここではstream_structured()
が実行されます。
async for structured_message, is_last in self.stream_structured(debounce_by=debounce_by): yield await self.validate_structured_result(structured_message, allow_partial=not is_last)
この関数を見るとストリーミングのある時点までの全てのチャンクを結合したレスポンスに対して都度バリデーションをかけていることが分かります。なお、出力が長くなるケースを考慮して内部でチャンクをgroupbyするなど、処理の負荷を軽減する仕組みも入っているようです(それを制御するのがdebounce_by
引き数)。is_lastがFalseの場合はallow_partial
という引数がTrueになるようなのでここをさらに見ていきます。
まずはstream_structured()
を実行して各ストリームの時点までのチャンクを結合した出力を取得します。
user_prompt = ( "My name is Ben, I was born on January 28th 1990, I like the chain the dog and the pyramid." ) async with agent.run_stream(user_prompt) as _response: async for message, is_last in _response.stream_structured(): print("raw: ", message.calls[0].args.args_json) print( "validated: ", await _response.validate_structured_result(message, allow_partial=not is_last), ) print("-" * 50) # > raw: {"name":"Ben","dob":"1990- # > validated: {'name': 'Ben'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"I like the chain the dog and the # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like the chain the dog and the'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"I like the chain the dog and the pyramid."} # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like the chain the dog and the pyramid.'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"I like the chain the dog and the pyramid."} # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'I like the chain the dog and the pyramid.'} # > --------------------------------------------------
このストリームの最初のメッセージを見ると、'{"name":"Ben","dob":"1990-'
が返ってきているものの、バリデーションの出力は{'name': 'Ben'}
となっています。dob
属性はdatetimeであることを期待していますが、出力の"1990-
だけからはdatetimeであるかどうかはまだ判別出来ません。そこで、確実に期待する型を出力しているname
属性に絞ってバリデーションを行って出力する、というのがPydanticの部分検証という機能です。
繰り返しですが、この機能はまだexperimentalです。詳細はこちらをご参照ください。
validate_structured_result()
メソッドの処理としては、ツールの名前からtool_callとresult_toolのペアを取得し前述と同様にresult_toolのvalidate()
メソッド内で、validate_json()
をしています。この時、experimental_allow_partial
引き数を指定することで部分検証を行っています。この部分の処理を簡潔に表すと以下のようになります。
async with agent.run_stream(user_prompt) as _response: async for message, is_last in _response.stream_structured(): arg_stream = message.calls[0].args.args_json validated_args = TypeAdapter(UserProfile).validate_json( arg_stream, experimental_allow_partial="trailing-strings" ) print("raw: ", arg_stream) print("validated: ", validated_args) print("-" * 50) # > raw: {"name":"Ben","dob":"1990-01- # > validated: {'name': 'Ben'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"likes the chain the dog and # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'likes the chain the dog and'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"likes the chain the dog and the pyramid"} # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'likes the chain the dog and the pyramid'} # > -------------------------------------------------- # > raw: {"name":"Ben","dob":"1990-01-28","bio":"likes the chain the dog and the pyramid"} # > validated: {'name': 'Ben', 'dob': datetime.date(1990, 1, 28), 'bio': 'likes the chain the dog and the pyramid'} # > --------------------------------------------------
前述の結果と同じようになりましたね(ストリームする度にチャンクが変わるので厳密に同じではないですが、挙動が同じであることは確認出来ます)。 このように、Pydanticの部分検証を使うことでストリーミング出力であってもチャンクを部分的に検証しながら型安全な構造化出力を得ることが出来ます。
ツール実行を型安全にする
何はともあれまずはツールを与えて実行してみましょう。公式ドキュメントから一部引用したサイコロを振るツールを与えます。
import random from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIModel model = OpenAIModel("gpt-4o") agent = Agent(model) @agent.tool_plain def roll_die() -> str: """Roll a six-sided die and return the result.""" return str(random.randint(1, 6)) dice_result = await agent.run("Please roll") print(dice_result.data) # > The result of the die roll is 4.
メッセージ履歴を見てみます。
dice_result.all_messages() # > [UserPrompt(content='Please roll', timestamp=datetime.datetime(2024, 12, 9, 3, 55, 52, 6355, tzinfo=datetime.timezone.utc), role='user'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='roll_die', args=ArgsJson(args_json='{}'), tool_id='call_1ZVWog5Pz3FOYYjlyyHTlZsB')], timestamp=datetime.datetime(2024, 12, 9, 3, 55, 53, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='roll_die', content='4', tool_id='call_1ZVWog5Pz3FOYYjlyyHTlZsB', timestamp=datetime.datetime(2024, 12, 9, 3, 55, 52, 709649, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelTextResponse(content='The result of the die roll is 4.', timestamp=datetime.datetime(2024, 12, 9, 3, 55, 53, tzinfo=datetime.timezone.utc), role='model-text-response')]
流れとしては、
- ツールとユーザープロンプトを入力する
- OpenAI APIのツール呼び出しが発動してtool_callのレスポンスが返って来る
- ツールの実行結果を取得する
- 再度APIをコールし、OpenAI APIからのレスポンスを最終結果として出力する
という感じです。これ自体はまぁ普通ですね。順を追ってみていきましょう。まず、Agentへのツールの渡し方は
- Agentのコンストラクタに渡す
- agent.toolデコレータで関数をラップする
の2通りがありますが、いずれの場合もツールはTool
というクラスのインスタンスに変換されます。変換されたツールは_function_tools
属性に保持されています。実際に見てみましょう。
agent._function_tools
# > {'roll_die': Tool(function=<function roll_die at 0x7fc671004310>, takes_ctx=False, max_retries=1, name='roll_die', description='Roll a six-sided die and return the result.', prepare=None, _is_async=False, _single_arg_name=None, _positional_fields=[], _var_positional_field=None, _parameters_json_schema={'description': 'Roll a six-sided die and return the result.', 'properties': {}, 'type': 'object', 'additionalProperties': False}, current_retry=0)}
Agentクラスインスタンスの内部ではこのように_function_tools
属性として、dict[str, Tool]
の形でツールが保持されています。このTool
とModelStructuredResponse
からツールを実行する手順は簡素化すると以下のようなコードになります。
tool = agent._function_tools["roll_die"] model_structured_response = dice_result.all_messages()[1] tool_call = model_structured_response.calls[0] # ツールを実行 await tool.run(deps=None, message=tool_call) # > ToolReturn(tool_name='roll_die', content='2', tool_id='call_1ZVWog5Pz3FOYYjlyyHTlZsB', timestamp=datetime.datetime(2024, 12, 9, 3, 59, 26, 325568, tzinfo=datetime.timezone.utc), role='tool-return')
そしてこの結果をメッセージに加えて再度APIコールした結果、テキストレスポンスが返ってくれば最終出力となり、前述のようにModelTextResponse
として吐き出されます。
このTool
クラスは実行メソッドrun()
を持っており、このrun()
メソッドの中で、前述の構造化出力の所と同様にtool_callで生成された引数をバリデーションしています。それがちょうど以下の部分です。
async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message: """Run the tool function asynchronously.""" try: if isinstance(message.args, messages.ArgsJson): args_dict = self._validator.validate_json(message.args.args_json) else: args_dict = self._validator.validate_python(message.args.args_dict) except ValidationError as e: return self._on_error(e, message) ...
基本的な部分はこれだけです。
依存関係の注入
ところでPydanticAIの特徴の一つにDependenciesという概念があります。公式ドキュメントには以下のように書かれています。
PydanticAI uses a dependency injection system to provide data and services to your agent's system prompts, tools and result validators.
(DeepL訳)PydanticAIは、依存性注入システムを使用して、エージェントのシステムプロンプト、ツール、結果バリデーターにデータとサービスを提供します。
ツールは実行のための引数をLLMで生成しますが、depsはそうではなく外部から特定の値をツールやシステムプロンプトに注入する役割を担います。
基本動作の確認と注意点
from pydantic_ai import Agent, RunContext from pydantic_ai.models.openai import OpenAIModel model = OpenAIModel("gpt-4o") agent = Agent(model) @agent.tool def get_player_name(ctx: RunContext[str]) -> str: """Get the player's name.""" return ctx.deps result = await agent.run("Who is a player?", deps="Anne") print(result.data) # > A player is named Anne.
RunContext
というのはGenericを継承しているデータクラスで、実装は以下のようになっています。
@dataclass class RunContext(Generic[AgentDeps]): """Information about the current call.""" deps: AgentDeps """Dependencies for the agent.""" retry: int """Number of retries so far.""" tool_name: str | None = None """Name of the tool being called."""
メッセージ履歴を見てみましょう。
result.all_messages() # > [UserPrompt(content='Who is a player?', timestamp=datetime.datetime(2024, 12, 9, 4, 1, 37, 114535, tzinfo=datetime.timezone.utc), role='user'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_name', args=ArgsJson(args_json='{}'), tool_id='call_zBgR9MImvcEtArMOdMNOGETH')], timestamp=datetime.datetime(2024, 12, 9, 4, 1, 38, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='get_player_name', content='Anne', tool_id='call_zBgR9MImvcEtArMOdMNOGETH', timestamp=datetime.datetime(2024, 12, 9, 4, 1, 37, 924530, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelTextResponse(content='A player is named Anne.', timestamp=datetime.datetime(2024, 12, 9, 4, 1, 39, tzinfo=datetime.timezone.utc), role='model-text-response')]
2つめのメッセージのModelStructuredResponse
を見て分かる通り、ツールは呼ばれているものの引数は生成されていません(定義していないので当然)。しかし後段のToolReturn
にはツールの返り値としてdepsで入力した文字列が出力されています。実装の中身を見ると、キーワード引数や位置引数に加えてdepsがツールの引数として追加されていることが分かります。
# ツールの実行時に引数をまとめている部分のコード async def run(self, deps: AgentDeps, message: messages.ToolCall) -> messages.Message: """Run the tool function asynchronously.""" try: if isinstance(message.args, messages.ArgsJson): args_dict = self._validator.validate_json(message.args.args_json) else: args_dict = self._validator.validate_python(message.args.args_dict) except ValidationError as e: return self._on_error(e, message) args, kwargs = self._call_args(deps, args_dict, message) # _call_args def _call_args( self, deps: AgentDeps, args_dict: dict[str, Any], message: messages.ToolCall ) -> tuple[list[Any], dict[str, Any]]: if self._single_arg_name: args_dict = {self._single_arg_name: args_dict} args = [RunContext(deps, self.current_retry, message.tool_name)] if self.takes_ctx else [] for positional_field in self._positional_fields: args.append(args_dict.pop(positional_field)) if self._var_positional_field: args.extend(args_dict.pop(self._var_positional_field)) return args, args_dict
ちなみにdepsには任意のPythonオブジェクトをRunContext
を介して渡すことが出来ます。
@dataclass class Player: name: str model = OpenAIModel("gpt-4o") agent = Agent(model) @agent.tool def get_player_name(ctx: RunContext[Player]) -> str: """Get the player's name.""" return ctx.deps.name result = await agent.run("Who is a player?", deps=Player(name="Anne")) print(result.data) # > A player is Anne.
注意点として、depsはargsと違って実行時に型のバリデーションが行われません。実際、前述のコードを見てもdepsはargsがバリデーションされた後にツールの引数に追加されています。型違反の場合は静的型チェックの際にエラーになりますが、strやdataclassを指定する場合は実行時には別の型を渡せてしまいます。以下がその例です。
from pydantic_ai import Agent, RunContext from pydantic_ai.models.openai import OpenAIModel model = OpenAIModel("gpt-4o") # depsにstr以外の値を渡す agent1 = Agent(model) @agent1.tool def get_player_name(ctx: RunContext[str]) -> str: """Get the player's name.""" return ctx.deps result1 = await agent1.run("Who is a player?", deps=3) print(result1.data) # > The name of the player is "3". # nameにstr以外の値を渡す agent2 = Agent(model) @agent2.tool def get_player_name(ctx: RunContext[Player]) -> str: """Get the player's name.""" return ctx.deps.name result2 = await agent2.run("Who is a player?", deps=Player(name=3)) print(result2.data) # > The player's name is "3." If you have a specific context or need more details about the player, please let me know!
この依存関係の注入自体にはコードの再利用性を高めたりテストの実行を容易にするという恩恵があるのですが、このように予期せぬ入力値で実行されてしまうのを防ぎたい場合はdepsをPydanticモデルで指定するか、もしくはresult_typeやresult_validatorと組み合わせることになりそうです。
Agentの依存関係としてのAgent
depsには別の用途もあります。前述のようにdepsには任意のPythonオブジェクトを指定出来るので、例えば他のAgentへの依存を定義すればMulti Agentを構築することが出来ます。ただしこの用途でのベストプラクティスはまだ分かりません。公式ドキュメントにはdepsとしてAgentを渡すようなサンプルコードがありましたが、あまり直感的ではない気がします。以下がそのサンプルコードです。
from dataclasses import dataclass from pydantic_ai import Agent, RunContext @dataclass class MyDeps: factory_agent: Agent[None, list[str]] joke_agent = Agent( 'openai:gpt-4o', deps_type=MyDeps, system_prompt=( 'Use the "joke_factory" to generate some jokes, then choose the best. ' 'You must return just a single joke.' ), ) factory_agent = Agent('gemini-1.5-pro', result_type=list[str]) @joke_agent.tool async def joke_factory(ctx: RunContext[MyDeps], count: int) -> str: r = await ctx.deps.factory_agent.run(f'Please generate {count} jokes.') return '\n'.join(r.data) result = joke_agent.run_sync('Tell me a joke.', deps=MyDeps(factory_agent)) print(result.data) #> Did you hear about the toothpaste scandal? They called it Colgate.
素直にagentを分けてmessagesを介してアクセスさせる方がシンプルになる気はします。以下は、Agentのルーティングをしつつ必要な情報を集めて回答するAgentを雑に作るならこんな感じ…?を試した例です。
import random from typing import Any, Callable, Literal from pydantic import BaseModel, ConfigDict from pydantic_ai import Agent, ModelRetry, RunContext from pydantic_ai.messages import Message from pydantic_ai.models import Model from pydantic_ai.models.openai import OpenAIModel #################### # Define the tools #################### def get_player_info() -> str: """ Get information of players who will face off in the finals. Returns: ------- str The information of players who will face off in the finals. """ players = random.choice( [["Anne", "Bob"], ["Charlie", "David"], ["Eve", "Frank"], ["Grace", "Hank"]] ) return f"{players[0]} and {players[1]} will match up." def get_result(players: list[str]) -> str: """ Get the result of the match. Parameters: ---------- players: list[str] The list of players name who will face off in the finals. Returns: ------- str The result of the match. """ idx_winner = random.choice([0, 1]) idx_loser = 1 - idx_winner return f"{players[idx_winner]} wins against {players[idx_loser]}." def get_ticket_info() -> str: """ Get the ticket information including the price and remaining tickets. Returns: ------- str The ticket information including the price and remaining tickets. """ return "The ticket price is $100 and remaining tickets are 10." #################### # Define the context #################### class AgentContext(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) messages: list[Message] = [] game_info_model: Model tools_game_info_agent: list[Callable[[Any], str]] ticket_info_model: Model tools_ticket_info_agent: list[Callable[[Any], str]] #################### # Define the result schema #################### class MatchResult(BaseModel): winner: str loser: str #################### # Define the type of the agent #################### AGENT_NAME = Literal["game_info_agent", "ticket_info_agent"] #################### # Define the tools for the commander #################### def route_agent(agent_name: AGENT_NAME, task_planning: list[str]) -> str: """ Route the agent and generate plan to solve the question. The feature of each agent is as follows: - game_info_agent: get information about the game. - ticket_info_agent: get information about the ticket. The tasks of the game_info_agent: - get_player_info: get information of players who will face off in the finals. - get_result: get the result of the match between the specified two players. The tasks of the ticket_info_agent: - get_ticket_info: get the ticket information including the price and remaining tickets. """ return f"{agent_name} agent will run. Task planning: {', '.join(task_planning)}" async def run_agent( ctx: RunContext[AgentContext], agent_name: AGENT_NAME, task_planning: list[str] ) -> str: """ Run the agent to get information to answer the user's question. Parameters: ---------- agent_name: AGENT_NAME The name of the agent. task_planning: list[str] The list of tasks to run. Returns: ------- str The result of the agent. """ context = ctx.deps # Run the agent agent = Agent( model=getattr(context, agent_name.replace("_agent", "_model")), tools=getattr(context, f"tools_{agent_name}"), ) result = await agent.run( f"Run the following tasks: {', '.join(task_planning)}", message_history=context.messages ) # Update the context context.messages.extend(result.new_messages()) return result.data #################### # Initialize the commander agent #################### commander = Agent( model=OpenAIModel("gpt-4o"), tools=[route_agent, run_agent], result_type=MatchResult, ) #################### # Initialize the context #################### context = AgentContext( game_info_model=OpenAIModel("gpt-4o"), tools_game_info_agent=[get_player_info, get_result], ticket_info_model=OpenAIModel("gpt-4o"), tools_ticket_info_agent=[get_ticket_info], ) #################### # Run the commander agent #################### result = await commander.run( "Tell me the result of the finals.", deps=context, ) print(result.data) # > winner='Charlie' loser='David'
ベストプラクティスかどうかは分かりませんが意外とシンプルに実装出来た気がします。履歴を見てみましょう。
result.all_messages() # > [UserPrompt(content='Tell me the result of the finals.', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 16, 277483, tzinfo=datetime.timezone.utc), role='user'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='route_agent', args=ArgsJson(args_json='{"agent_name":"game_info_agent","task_planning":["get_player_info"]}'), tool_id='call_xSjkApdkzPzPZC27cPIjPGrX')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 17, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='route_agent', content='game_info_agent agent will run. Task planning: get_player_info', tool_id='call_xSjkApdkzPzPZC27cPIjPGrX', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 17, 148880, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='run_agent', args=ArgsJson(args_json='{"agent_name":"game_info_agent","task_planning":["get_player_info"]}'), tool_id='call_XmQbSeelRu4oRqpYKrXiOoRk')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 18, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='run_agent', content='The players who will face off in the finals are Grace and Hank.', tool_id='call_XmQbSeelRu4oRqpYKrXiOoRk', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 19, 158365, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='route_agent', args=ArgsJson(args_json='{"agent_name":"game_info_agent","task_planning":["get_result"]}'), tool_id='call_uYzcAWnskGvQO38SESCMZn2Y')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 20, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='route_agent', content='game_info_agent agent will run. Task planning: get_result', tool_id='call_uYzcAWnskGvQO38SESCMZn2Y', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 19, 957500, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='run_agent', args=ArgsJson(args_json='{"agent_name":"game_info_agent","task_planning":["get_result"]}'), tool_id='call_wu9oL54Yy7yVIa7UF1M5HW3N')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 21, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='run_agent', content='The result of the match is that Hank wins against Grace.', tool_id='call_wu9oL54Yy7yVIa7UF1M5HW3N', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 22, 105317, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='final_result', args=ArgsJson(args_json='{"winner":"Hank","loser":"Grace"}'), tool_id='call_VoEdCIvRVHwwx7k8QpEXCOoF')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 23, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='final_result', content='Final result processed.', tool_id='call_VoEdCIvRVHwwx7k8QpEXCOoF', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 23, 38888, tzinfo=datetime.timezone.utc), role='tool-return')]
コンテキストのメッセージも見てみます。
context.messages # > [UserPrompt(content='Run the following tasks: get_player_info', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 17, 999061, tzinfo=datetime.timezone.utc), role='user'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='get_player_info', args=ArgsJson(args_json='{}'), tool_id='call_prAJ8r0wtJgg8tAnRttOZLLe')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 19, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='get_player_info', content='Grace and Hank will match up.', tool_id='call_prAJ8r0wtJgg8tAnRttOZLLe', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 18, 517051, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelTextResponse(content='The players who will face off in the finals are Grace and Hank.', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 19, tzinfo=datetime.timezone.utc), role='model-text-response'), # > UserPrompt(content='Run the following tasks: get_result', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 20, 911440, tzinfo=datetime.timezone.utc), role='user'), # > ModelStructuredResponse(calls=[ToolCall(tool_name='get_result', args=ArgsJson(args_json='{"players":["Grace","Hank"]}'), tool_id='call_wu9oL54Yy7yVIa7UF1M5HW3N')], timestamp=datetime.datetime(2024, 12, 9, 4, 11, 22, tzinfo=datetime.timezone.utc), role='model-structured-response'), # > ToolReturn(tool_name='get_result', content='Hank wins against Grace.', tool_id='call_wu9oL54Yy7yVIa7UF1M5HW3N', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 21, 511966, tzinfo=datetime.timezone.utc), role='tool-return'), # > ModelTextResponse(content='The result of the match is that Hank wins against Grace.', timestamp=datetime.datetime(2024, 12, 9, 4, 11, 22, tzinfo=datetime.timezone.utc), role='model-text-response')]
game_info_agent
のget_player_info
ツールを実行したもののそれだけでは足りないので再度game_info_agent
のget_result
ツールを実行して最終回答に至っていますね。また、プロンプトに関係のないticket_info_agent
は実行されておらず、適切にAgentのルーティングが出来ていそうです。ツールが増えてきたらルーティングを階層的にするなどするとマルチエージェントな何かが作れそうです。あるAgentは特定の他のAgentからの出力を受け取りたい場合もあると思うのでもうちょっとちゃんと会話履歴とコンテキストのハンドリングをすると良さそうです。
その他の機能
テストの所とかも書こうと思っていたのですが力尽きたのでドキュメントのリンクを貼っておきます。テスト向けにモデルや依存関係をoverrideするデコレータや、ツールのスキーマに沿ったダミーのレスポンスを生成するダミーのモデルTestModel
なんかもあって便利そうです(柔軟にテストしたい場合にはFunctionModel
でカスタマイズ出来ます)。
PydanticAIを深掘りしてみた感想
メリットとしては基本的には「型安全」ということに尽きると思います。型安全なAgentが比較的シンプルに実装出来る点も良いなと思いました。また、マルチエージェントもなんかその気になればわりとシンプルに実装出来そうということが分かりました。デメリットとしては、外側のコードはシンプルに記述出来る一方で、内部のコードはやや複雑でした。サポートされていないモデルをカスタムでサポートしたい場合に何のメソッドを実装すればいいかが自明ではなく、カスタマイズはしにくいなと思いました。また、リトライ時のプロンプトなど英語のプロンプトでハードコーディングされている部分もあるので、日本語を扱う際にその辺りがどう影響するのか未知数ではあります。
PydanticAIは開発が盛んで現状は数日おきにアプデが入っています。v0.0.12でOllamaがサポートされたり、動的にツールの実行可否を制御したりする機能が追加されました。また、AnthropicやMistralのサポート、Model Context Protocolのサポート等も進んでいるようなので、非常に楽しみです。
We Are Hiring!
ABEJAは、テクノロジーの社会実装に取り組んでいます。 技術はもちろん、技術をどのようにして社会やビジネスに組み込んでいくかを考えるのが好きな方は、下記採用ページからエントリーください! (新卒の方やインターンシップのエントリーもお待ちしております!)
特に下記ポジションの募集を強化しています!ぜひ御覧ください!
プラットフォームグループ:シニアソフトウェアエンジニア | 株式会社ABEJA