diff --git a/notebooks/fireworks_tools.ipynb b/notebooks/fireworks_tools.ipynb new file mode 100644 index 0000000..a1dd0c4 --- /dev/null +++ b/notebooks/fireworks_tools.ipynb @@ -0,0 +1,353 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import LLMMathChain\n", + "from langchain_openai import ChatOpenAI\n", + "from langchain.agents import AgentExecutor, create_openai_tools_agent\n", + "from langchain import hub\n", + "from pydantic import BaseModel, Field\n", + "from langchain.tools import BaseTool\n", + "from typing import Type\n", + "from hellocomputer.config import settings\n", + "from hellocomputer.models import AvailableModels\n", + "\n", + "# Get the prompt to use - you can modify this!\n", + "prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n", + "\n", + "llm = ChatOpenAI(\n", + " base_url=settings.llm_base_url,\n", + " api_key=settings.llm_api_key,\n", + " model=AvailableModels.firefunction_2,\n", + " temperature=0.5,\n", + " max_tokens=256,\n", + ")\n", + "\n", + "math_llm = ChatOpenAI(\n", + " base_url=settings.llm_base_url,\n", + " api_key=settings.llm_api_key,\n", + " model=AvailableModels.mixtral_8x7b,\n", + " temperature=0.5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class CalculatorInput(BaseModel):\n", + " query: str = Field(description=\"should be a math expression\")\n", + "\n", + "class CustomCalculatorTool(BaseTool):\n", + " name: str = \"Calculator\"\n", + " description: str = \"Tool to evaluate mathemetical expressions\"\n", + " args_schema: Type[BaseModel] = CalculatorInput\n", + "\n", + " def _run(self, query: str) -> str:\n", + " \"\"\"Use the tool.\"\"\"\n", + " return LLMMathChain.from_llm(llm=math_llm, verbose=True).invoke(query)\n", + "\n", + " async def _arun(self, query: str) -> str:\n", + " \"\"\"Use the tool asynchronously.\"\"\"\n", + " return 
await LLMMathChain.from_llm(llm=math_llm, verbose=True).ainvoke(query)\n", + "\n", + "tools = [\n", + "    CustomCalculatorTool()\n", + "]\n", + "\n", + "agent = create_openai_tools_agent(llm, tools, prompt)\n", + "\n", + "agent = AgentExecutor(agent=agent, tools=tools, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n" + ] + } + ], + "source": [ + "print(agent.invoke({\"input\": \"What is the capital of USA?\"}))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", + "\u001b[32;1m\u001b[1;3m\n", + "Invoking: `Calculator` with `{'query': '100/25'}`\n", + "\n", + "\n", + "\u001b[0m\n", + "\n", + "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", + "100/25\u001b[32;1m\u001b[1;3m```text\n", + "100 / 25\n", + "```\n", + "...numexpr.evaluate(\"100 / 25\")...\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n" + ] + } + ], + "source": [ + "print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Now let's modify this code and make it a Graph" + ] + }, + { + 
"cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from langgraph.prebuilt import ToolNode\n", + "from langchain_core.messages import AIMessage\n", + "\n", + "tool_node = ToolNode(tools=tools)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", + "50 / 2\u001b[32;1m\u001b[1;3m```text\n", + "50 / 2\n", + "```\n", + "...numexpr.evaluate(\"50 / 2\")...\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Call the tools node manually\n", + "\n", + "message_with_single_tool_call = AIMessage(\n", + " content=\"\",\n", + " tool_calls=[\n", + " {\n", + " \"name\": \"Calculator\",\n", + " \"args\": {\"query\": \"50 / 2\"},\n", + " \"id\": \"tool_call_id\",\n", + " \"type\": \"tool_call\",\n", + " }\n", + " ],\n", + ")\n", + "\n", + "tool_node.invoke({\"messages\": [message_with_single_tool_call]})" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# Bind the tools to the agent so it knows what to call\n", + "\n", + "agent = ChatOpenAI(\n", + " base_url=\"https://api.fireworks.ai/inference/v1\",\n", + " api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n", + " model=\"accounts/fireworks/models/firefunction-v2\",\n", + " temperature=0.5,\n", + " max_tokens=256,\n", + ").bind_tools(tools)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"[{'name': 'Calculator',\n", + " 'args': {'query': '234/7'},\n", + " 'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n", + " 'type': 'tool_call'}]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.invoke(\"234/7\").tool_calls" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Literal\n", + "\n", + "from langgraph.graph import StateGraph, MessagesState\n", + "\n", + "\n", + "def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n", + " messages = state[\"messages\"]\n", + " last_message = messages[-1]\n", + " if last_message.tool_calls:\n", + " return \"tools\"\n", + " return \"__end__\"\n", + "\n", + "\n", + "def call_model(state: MessagesState):\n", + " messages = state[\"messages\"]\n", + " response = agent.invoke(messages)\n", + " return {\"messages\": [response]}\n", + "\n", + "\n", + "workflow = StateGraph(MessagesState)\n", + "\n", + "# Define the two nodes we will cycle between\n", + "workflow.add_node(\"agent\", call_model)\n", + "workflow.add_node(\"tools\", tool_node)\n", + "\n", + "workflow.add_edge(\"__start__\", \"agent\")\n", + "workflow.add_conditional_edges(\n", + " \"agent\",\n", + " should_continue,\n", + ")\n", + "workflow.add_edge(\"tools\", \"agent\")\n", + "\n", + "app = workflow.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "================================\u001b[1m Human Message \u001b[0m=================================\n", + "\n", + "What's twelve times twelve?\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "Tool Calls:\n", + " Calculator (call_dEdudo8txWMpLKMek070TIWd)\n", + " Call ID: call_dEdudo8txWMpLKMek070TIWd\n", + " Args:\n", + " query: 12 * 12\n", + "\n", + "\n", + "\u001b[1m> 
Entering new LLMMathChain chain...\u001b[0m\n", + "12 * 12\u001b[32;1m\u001b[1;3m```text\n", + "12 * 12\n", + "```\n", + "...numexpr.evaluate(\"12 * 12\")...\n", + "\u001b[0m\n", + "Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n", + "\u001b[1m> Finished chain.\u001b[0m\n", + "=================================\u001b[1m Tool Message \u001b[0m=================================\n", + "Name: Calculator\n", + "\n", + "{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n", + "==================================\u001b[1m Ai Message \u001b[0m==================================\n", + "\n", + "The answer is 144.\n" + ] + } + ], + "source": [ + "for chunk in app.stream(\n", + " {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n", + "):\n", + " chunk[\"messages\"][-1].pretty_print()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}