SEが最近起こったことを書くブログ

ITエンジニアが試したこと、気になったことを書いていきます。

Gemini APIでLangGraphのチュートリアルをGoogle Colabでやってみた

LangGraphのチュートリアルをGemini APIで実行してみたので、メモ

事前準備

実行したコード

! pip install langgraph langchain==0.1.9 langchain-google-genai pillow

まずはツールなしのシンプルなグラフ作成

# 環境変数の準備 (左端の鍵アイコンでGOOGLE_API_KEYを設定)
import os
from google.colab import userdata
os.environ["GOOGLE_API_KEY"] = userdata.get("GOOGLE_API_KEY")

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from langgraph.graph import END, MessageGraph



# 利用するモデルをGeminiに変更
model  = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.9)
# model  = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0.9) # Gemini1.5を使いたい場合はこちら


graph = MessageGraph()

graph.add_node("oracle", model)
graph.add_edge("oracle", END)

graph.set_entry_point("oracle")

runnable = graph.compile()

作成したグラフの実行

runnable.invoke(HumanMessage("What is 1 + 1?"))

必要な場合にmultiplyを呼び出す例

import json
from langchain_core.messages import ToolMessage,BaseMessage
from langchain_core.tools import tool

# チュートリアルにはなかったが追加
from typing import List

@tool
def multiply(first_number: int, second_number: int) -> int:
    """Multiplies two numbers together."""
    return first_number * second_number



# Geminiを利用するように変更
model  = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.9)
# model  = ChatGoogleGenerativeAI(model="gemini-1.5-pro-latest", temperature=0.9) # Gemini1.5を使いたい場合はこちら

# https://python.langchain.com/docs/modules/model_io/chat/function_calling/のGoogleの例を参考に変更
model_with_tools = model.bind_tools([multiply])

graph = MessageGraph()

def invoke_model(state: List[BaseMessage]):
    return model_with_tools.invoke(state)

graph.add_node("oracle", invoke_model)

def invoke_tool(state: List[BaseMessage]):
    # idが含まれるのがadditional_kwargsではなく、tool_callsだったので変更
    tool_calls = state[-1].tool_calls
    multiply_call = None

    for tool_call in tool_calls:
        # tool_callsの形式に合わせて変更
        if tool_call.get("name") == "multiply":
            multiply_call = tool_call

    if multiply_call is None:
        raise Exception("No adder input found.")

    res = multiply.invoke(
        # 引数名をあわせる
        multiply_call.get("args")
    )
    return ToolMessage(
        tool_call_id=multiply_call["id"],
        content=res
    )

graph.add_node("multiply", invoke_tool)

graph.add_edge("multiply", END)

graph.set_entry_point("oracle")

LLMがツール呼び出しをするように指示した場合のみ、ツール呼び出しするように修正

def router(state: List[BaseMessage]):
    # ここもtool_callsに変更
    tool_calls = state[-1].tool_calls
    if len(tool_calls):
        return "multiply"
    else:
        return "end"

graph.add_conditional_edges("oracle", router, {
    "multiply": "multiply",
    "end": END,
})

グラフをコンパイルし、実行

runnable = graph.compile()

runnable.invoke(HumanMessage("What is 123 * 456?"))

参考

python.langchain.com

python.langchain.com

note.com

メモ

以下の記事を読む限りだと、bind_toolsを使った場合にtool_callsを使うこと自体はおかしくなさそう blog.langchain.dev