{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculator agent: `AgentExecutor` vs. LangGraph\n",
    "\n",
    "Build a tool-calling agent with a `Calculator` tool (backed by `LLMMathChain`), first with LangChain's `AgentExecutor` and then as an explicit LangGraph graph.\n",
    "\n",
    "The LLM endpoint and credentials are read from `hellocomputer.config.settings`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Type\n",
    "\n",
    "from langchain import hub\n",
    "from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
    "from langchain.chains import LLMMathChain\n",
    "from langchain.tools import BaseTool\n",
    "from langchain_openai import ChatOpenAI\n",
    "from pydantic import BaseModel, Field\n",
    "\n",
    "from hellocomputer.config import settings\n",
    "from hellocomputer.models import AvailableModels\n",
    "\n",
    "# Prompt for an OpenAI-style tools agent, pulled from the LangChain Hub - you can modify this!\n",
    "prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
    "\n",
    "# Tool-calling chat model that drives the agent.\n",
    "llm = ChatOpenAI(\n",
    "    base_url=settings.llm_base_url,\n",
    "    api_key=settings.llm_api_key,\n",
    "    model=AvailableModels.firefunction_2,\n",
    "    temperature=0.5,\n",
    "    max_tokens=256,\n",
    ")\n",
    "\n",
    "# Chat model used inside the Calculator tool to turn a query into an expression.\n",
    "math_llm = ChatOpenAI(\n",
    "    base_url=settings.llm_base_url,\n",
    "    api_key=settings.llm_api_key,\n",
    "    model=AvailableModels.mixtral_8x7b,\n",
    "    temperature=0.5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CalculatorInput(BaseModel):\n",
    "    \"\"\"Arguments accepted by the Calculator tool.\"\"\"\n",
    "\n",
    "    query: str = Field(description=\"should be a math expression\")\n",
    "\n",
    "\n",
    "# Built once and shared by the sync and async paths of the tool.\n",
    "math_chain = LLMMathChain.from_llm(llm=math_llm, verbose=True)\n",
    "\n",
    "\n",
    "class CustomCalculatorTool(BaseTool):\n",
    "    \"\"\"Evaluate a math expression with `LLMMathChain` (LLM -> numexpr).\"\"\"\n",
    "\n",
    "    name: str = \"Calculator\"\n",
    "    description: str = \"Tool to evaluate mathematical expressions\"\n",
    "    args_schema: Type[BaseModel] = CalculatorInput\n",
    "\n",
    "    def _run(self, query: str) -> dict:\n",
    "        \"\"\"Use the tool; returns a dict with 'question' and 'answer' keys.\"\"\"\n",
    "        return math_chain.invoke(query)\n",
    "\n",
    "    async def _arun(self, query: str) -> dict:\n",
    "        \"\"\"Use the tool asynchronously; same result shape as `_run`.\"\"\"\n",
    "        # `ainvoke` returns a coroutine: await it so callers get the result\n",
    "        # rather than an un-awaited coroutine object.\n",
    "        return await math_chain.ainvoke(query)\n",
    "\n",
    "\n",
    "tools = [CustomCalculatorTool()]\n",
    "\n",
    "agent = AgentExecutor(\n",
    "    agent=create_openai_tools_agent(llm, tools, prompt),\n",
    "    tools=tools,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
"metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n" ] } ], "source": [ "print(agent.invoke({\"input\": \"What is the capital of USA?\"}))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", "\u001b[32;1m\u001b[1;3m\n", "Invoking: `Calculator` with `{'query': '100/25'}`\n", "\n", "\n", "\u001b[0m\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "100/25\u001b[32;1m\u001b[1;3m```text\n", "100 / 25\n", "```\n", "...numexpr.evaluate(\"100 / 25\")...\n", "\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n", "\n", "\u001b[1m> Finished chain.\u001b[0m\n", "{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n" ] } ], "source": [ "print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Now let's modify this code and make it a Graph" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from langgraph.prebuilt import ToolNode\n", "from langchain_core.messages import AIMessage\n", "\n", "tool_node = ToolNode(tools=tools)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "50 / 2\u001b[32;1m\u001b[1;3m```text\n", "50 / 
2\n", "```\n", "...numexpr.evaluate(\"50 / 2\")...\n", "\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n" ] }, { "data": { "text/plain": [ "{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Call the tools node manually\n", "\n", "message_with_single_tool_call = AIMessage(\n", " content=\"\",\n", " tool_calls=[\n", " {\n", " \"name\": \"Calculator\",\n", " \"args\": {\"query\": \"50 / 2\"},\n", " \"id\": \"tool_call_id\",\n", " \"type\": \"tool_call\",\n", " }\n", " ],\n", ")\n", "\n", "tool_node.invoke({\"messages\": [message_with_single_tool_call]})" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Bind the tools to the agent so it knows what to call\n", "\n", "agent = ChatOpenAI(\n", " base_url=\"https://api.fireworks.ai/inference/v1\",\n", " api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n", " model=\"accounts/fireworks/models/firefunction-v2\",\n", " temperature=0.5,\n", " max_tokens=256,\n", ").bind_tools(tools)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'name': 'Calculator',\n", " 'args': {'query': '234/7'},\n", " 'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n", " 'type': 'tool_call'}]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "agent.invoke(\"234/7\").tool_calls" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from typing import Literal\n", "\n", "from langgraph.graph import StateGraph, MessagesState\n", "\n", "\n", "def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n", " messages = state[\"messages\"]\n", " last_message = messages[-1]\n", " if last_message.tool_calls:\n", " 
return \"tools\"\n", " return \"__end__\"\n", "\n", "\n", "def call_model(state: MessagesState):\n", " messages = state[\"messages\"]\n", " response = agent.invoke(messages)\n", " return {\"messages\": [response]}\n", "\n", "\n", "workflow = StateGraph(MessagesState)\n", "\n", "# Define the two nodes we will cycle between\n", "workflow.add_node(\"agent\", call_model)\n", "workflow.add_node(\"tools\", tool_node)\n", "\n", "workflow.add_edge(\"__start__\", \"agent\")\n", "workflow.add_conditional_edges(\n", " \"agent\",\n", " should_continue,\n", ")\n", "workflow.add_edge(\"tools\", \"agent\")\n", "\n", "app = workflow.compile()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "================================\u001b[1m Human Message \u001b[0m=================================\n", "\n", "What's twelve times twelve?\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "Tool Calls:\n", " Calculator (call_dEdudo8txWMpLKMek070TIWd)\n", " Call ID: call_dEdudo8txWMpLKMek070TIWd\n", " Args:\n", " query: 12 * 12\n", "\n", "\n", "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", "12 * 12\u001b[32;1m\u001b[1;3m```text\n", "12 * 12\n", "```\n", "...numexpr.evaluate(\"12 * 12\")...\n", "\u001b[0m\n", "Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n", "\u001b[1m> Finished chain.\u001b[0m\n", "=================================\u001b[1m Tool Message \u001b[0m=================================\n", "Name: Calculator\n", "\n", "{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n", "==================================\u001b[1m Ai Message \u001b[0m==================================\n", "\n", "The answer is 144.\n" ] } ], "source": [ "for chunk in app.stream(\n", " {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n", "):\n", " chunk[\"messages\"][-1].pretty_print()\n" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.4" } }, "nbformat": 4, "nbformat_minor": 2 }