LangGraphのチュートリアルをGemini APIで実行してみたので、メモ
事前準備
実行したコード
! pip install langgraph langchain==0.1.9 langchain-google-genai pillow
まずはツールなしのシンプルなグラフ作成
# 環境変数の準備 (左端の鍵アイコンでGOOGLE_API_KEYを設定) import os from google.colab import userdata os.environ["GOOGLE_API_KEY"] = userdata.get("GOOGLE_API_KEY") from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.messages import HumanMessage from langgraph.graph import END, MessageGraph # 利用するモデルをGeminiに変更 model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.9) # model = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0.9) # Gemini1.5を使いたい場合はこちら graph = MessageGraph() graph.add_node("oracle", model) graph.add_edge("oracle", END) graph.set_entry_point("oracle") runnable = graph.compile()
作成したグラフの実行
runnable.invoke(HumanMessage("What is 1 + 1?"))
必要な場合にmultiplyを呼び出す例
import json from langchain_core.messages import ToolMessage,BaseMessage from langchain_core.tools import tool # チュートリアルにはなかったが追加 from typing import List @tool def multiply(first_number: int, second_number: int) -> int: """Multiplies two numbers together.""" return first_number * second_number # Geminiを利用するように変更 model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.9) # model = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0.9) # Gemini1.5を使いたい場合はこちら # https://python.langchain.com/docs/modules/model_io/chat/function_calling/のGoogleの例を参考に変更 model_with_tools = model.bind_tools([multiply]) graph = MessageGraph() def invoke_model(state: List[BaseMessage]): return model_with_tools.invoke(state) graph.add_node("oracle", invoke_model) def invoke_tool(state: List[BaseMessage]): # idが含まれるのがadditional_kwargsではなく、tool_callsだったので変更 tool_calls = state[-1].tool_calls multiply_call = None for tool_call in tool_calls: # tool_callsの形式に合わせて変更 if tool_call.get("name") == "multiply": multiply_call = tool_call if multiply_call is None: raise Exception("No adder input found.") res = multiply.invoke( # 引数名をあわせる multiply_call.get("args") ) return ToolMessage( tool_call_id=multiply_call["id"], content=res ) graph.add_node("multiply", invoke_tool) graph.add_edge("multiply", END) graph.set_entry_point("oracle")
LLMがツール呼び出しをするように指示した場合のみ、ツール呼び出しするように修正
def router(state: List[BaseMessage]): # ここもtool_callsに変更 tool_calls = state[-1].tool_calls if len(tool_calls): return "multiply" else: return "end" graph.add_conditional_edges("oracle", router, { "multiply": "multiply", "end": END, })
グラフをコンパイルし、実行
runnable = graph.compile()
runnable.invoke(HumanMessage("What is 123 * 456?"))
参考
メモ
以下の記事を読む限りだと、bind_toolsを使った場合にtool_callsを使うこと自体はおかしくなさそう blog.langchain.dev