Allow notebooks
This commit is contained in:
		
							parent
							
								
									9dea79b4a4
								
							
						
					
					
						commit
						9e1da05276
					
				
							
								
								
									
										353
									
								
								notebooks/fireworks_tools.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										353
									
								
								notebooks/fireworks_tools.ipynb
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,353 @@ | ||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 1, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "from langchain.chains import LLMMathChain\n", | ||||||
|  |     "from langchain_openai import ChatOpenAI\n", | ||||||
|  |     "from langchain.agents import AgentExecutor, create_openai_tools_agent\n", | ||||||
|  |     "from langchain import hub\n", | ||||||
|  |     "from pydantic import BaseModel, Field\n", | ||||||
|  |     "from langchain.tools import BaseTool\n", | ||||||
|  |     "from typing import Type\n", | ||||||
|  |     "from hellocomputer.config import settings\n", | ||||||
|  |     "from hellocomputer.models import AvailableModels\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Get the prompt to use - you can modify this!\n", | ||||||
|  |     "prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "llm = ChatOpenAI(\n", | ||||||
|  |     "    base_url=settings.llm_base_url,\n", | ||||||
|  |     "    api_key=settings.llm_api_key,\n", | ||||||
|  |     "    model=AvailableModels.firefunction_2,\n", | ||||||
|  |     "    temperature=0.5,\n", | ||||||
|  |     "    max_tokens=256,\n", | ||||||
|  |     ")\n", | ||||||
|  |     "\n", | ||||||
|  |     "math_llm = ChatOpenAI(\n", | ||||||
|  |     "    base_url=settings.llm_base_url,\n", | ||||||
|  |     "    api_key=settings.llm_api_key,\n", | ||||||
|  |     "    model=AvailableModels.mixtral_8x7b,\n", | ||||||
|  |     "    temperature=0.5,\n", | ||||||
|  |     ")" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 2, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "class CalculatorInput(BaseModel):\n", | ||||||
|  |     "    query: str = Field(description=\"should be a math expression\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "class CustomCalculatorTool(BaseTool):\n", | ||||||
|  |     "    name: str = \"Calculator\"\n", | ||||||
|  |     "    description: str = \"Tool to evaluate mathemetical expressions\"\n", | ||||||
|  |     "    args_schema: Type[BaseModel] = CalculatorInput\n", | ||||||
|  |     "\n", | ||||||
|  |     "    def _run(self, query: str) -> str:\n", | ||||||
|  |     "        \"\"\"Use the tool.\"\"\"\n", | ||||||
|  |     "        return LLMMathChain.from_llm(llm=math_llm, verbose=True).invoke(query)\n", | ||||||
|  |     "\n", | ||||||
|  |     "    async def _arun(self, query: str) -> str:\n", | ||||||
|  |     "        \"\"\"Use the tool asynchronously.\"\"\"\n", | ||||||
|  |     "        return LLMMathChain.from_llm(llm=math_llm, verbose=True).ainvoke(query)\n", | ||||||
|  |     "\n", | ||||||
|  |     "tools = [\n", | ||||||
|  |     "  CustomCalculatorTool()\n", | ||||||
|  |     "]\n", | ||||||
|  |     "\n", | ||||||
|  |     "agent = create_openai_tools_agent(llm, tools, prompt)\n", | ||||||
|  |     "\n", | ||||||
|  |     "agent = AgentExecutor(agent=agent, tools=tools, verbose=True)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 3, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", | ||||||
|  |       "\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||||
|  |       "{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "print(agent.invoke({\"input\": \"What is the capital of USA?\"}))" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 4, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", | ||||||
|  |       "\u001b[32;1m\u001b[1;3m\n", | ||||||
|  |       "Invoking: `Calculator` with `{'query': '100/25'}`\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[0m\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||||
|  |       "100/25\u001b[32;1m\u001b[1;3m```text\n", | ||||||
|  |       "100 / 25\n", | ||||||
|  |       "```\n", | ||||||
|  |       "...numexpr.evaluate(\"100 / 25\")...\n", | ||||||
|  |       "\u001b[0m\n", | ||||||
|  |       "Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n", | ||||||
|  |       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||||
|  |       "\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||||
|  |       "{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "markdown", | ||||||
|  |    "metadata": {}, | ||||||
|  |    "source": [ | ||||||
|  |     "# Now let's modify this code and make it a Graph" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 6, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "from langgraph.prebuilt import ToolNode\n", | ||||||
|  |     "from langchain_core.messages import AIMessage\n", | ||||||
|  |     "\n", | ||||||
|  |     "tool_node = ToolNode(tools=tools)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 7, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||||
|  |       "50 / 2\u001b[32;1m\u001b[1;3m```text\n", | ||||||
|  |       "50 / 2\n", | ||||||
|  |       "```\n", | ||||||
|  |       "...numexpr.evaluate(\"50 / 2\")...\n", | ||||||
|  |       "\u001b[0m\n", | ||||||
|  |       "Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n", | ||||||
|  |       "\u001b[1m> Finished chain.\u001b[0m\n" | ||||||
|  |      ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "execution_count": 7, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "execute_result" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "# Call the tools node manually\n", | ||||||
|  |     "\n", | ||||||
|  |     "message_with_single_tool_call = AIMessage(\n", | ||||||
|  |     "    content=\"\",\n", | ||||||
|  |     "    tool_calls=[\n", | ||||||
|  |     "        {\n", | ||||||
|  |     "            \"name\": \"Calculator\",\n", | ||||||
|  |     "            \"args\": {\"query\": \"50 / 2\"},\n", | ||||||
|  |     "            \"id\": \"tool_call_id\",\n", | ||||||
|  |     "            \"type\": \"tool_call\",\n", | ||||||
|  |     "        }\n", | ||||||
|  |     "    ],\n", | ||||||
|  |     ")\n", | ||||||
|  |     "\n", | ||||||
|  |     "tool_node.invoke({\"messages\": [message_with_single_tool_call]})" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 8, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "# Bind the tools to the agent so it knows what to call\n", | ||||||
|  |     "\n", | ||||||
|  |     "agent = ChatOpenAI(\n", | ||||||
|  |     "    base_url=\"https://api.fireworks.ai/inference/v1\",\n", | ||||||
|  |     "    api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n", | ||||||
|  |     "    model=\"accounts/fireworks/models/firefunction-v2\",\n", | ||||||
|  |     "    temperature=0.5,\n", | ||||||
|  |     "    max_tokens=256,\n", | ||||||
|  |     ").bind_tools(tools)" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 9, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "data": { | ||||||
|  |       "text/plain": [ | ||||||
|  |        "[{'name': 'Calculator',\n", | ||||||
|  |        "  'args': {'query': '234/7'},\n", | ||||||
|  |        "  'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n", | ||||||
|  |        "  'type': 'tool_call'}]" | ||||||
|  |       ] | ||||||
|  |      }, | ||||||
|  |      "execution_count": 9, | ||||||
|  |      "metadata": {}, | ||||||
|  |      "output_type": "execute_result" | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "agent.invoke(\"234/7\").tool_calls" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 10, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "from typing import Literal\n", | ||||||
|  |     "\n", | ||||||
|  |     "from langgraph.graph import StateGraph, MessagesState\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n", | ||||||
|  |     "def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n", | ||||||
|  |     "    messages = state[\"messages\"]\n", | ||||||
|  |     "    last_message = messages[-1]\n", | ||||||
|  |     "    if last_message.tool_calls:\n", | ||||||
|  |     "        return \"tools\"\n", | ||||||
|  |     "    return \"__end__\"\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n", | ||||||
|  |     "def call_model(state: MessagesState):\n", | ||||||
|  |     "    messages = state[\"messages\"]\n", | ||||||
|  |     "    response = agent.invoke(messages)\n", | ||||||
|  |     "    return {\"messages\": [response]}\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n", | ||||||
|  |     "workflow = StateGraph(MessagesState)\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Define the two nodes we will cycle between\n", | ||||||
|  |     "workflow.add_node(\"agent\", call_model)\n", | ||||||
|  |     "workflow.add_node(\"tools\", tool_node)\n", | ||||||
|  |     "\n", | ||||||
|  |     "workflow.add_edge(\"__start__\", \"agent\")\n", | ||||||
|  |     "workflow.add_conditional_edges(\n", | ||||||
|  |     "    \"agent\",\n", | ||||||
|  |     "    should_continue,\n", | ||||||
|  |     ")\n", | ||||||
|  |     "workflow.add_edge(\"tools\", \"agent\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "app = workflow.compile()" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 11, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "================================\u001b[1m Human Message \u001b[0m=================================\n", | ||||||
|  |       "\n", | ||||||
|  |       "What's twelve times twelve?\n", | ||||||
|  |       "==================================\u001b[1m Ai Message \u001b[0m==================================\n", | ||||||
|  |       "Tool Calls:\n", | ||||||
|  |       "  Calculator (call_dEdudo8txWMpLKMek070TIWd)\n", | ||||||
|  |       " Call ID: call_dEdudo8txWMpLKMek070TIWd\n", | ||||||
|  |       "  Args:\n", | ||||||
|  |       "    query: 12 * 12\n", | ||||||
|  |       "\n", | ||||||
|  |       "\n", | ||||||
|  |       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||||
|  |       "12 * 12\u001b[32;1m\u001b[1;3m```text\n", | ||||||
|  |       "12 * 12\n", | ||||||
|  |       "```\n", | ||||||
|  |       "...numexpr.evaluate(\"12 * 12\")...\n", | ||||||
|  |       "\u001b[0m\n", | ||||||
|  |       "Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n", | ||||||
|  |       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||||
|  |       "=================================\u001b[1m Tool Message \u001b[0m=================================\n", | ||||||
|  |       "Name: Calculator\n", | ||||||
|  |       "\n", | ||||||
|  |       "{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n", | ||||||
|  |       "==================================\u001b[1m Ai Message \u001b[0m==================================\n", | ||||||
|  |       "\n", | ||||||
|  |       "The answer is 144.\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "for chunk in app.stream(\n", | ||||||
|  |     "    {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n", | ||||||
|  |     "):\n", | ||||||
|  |     "    chunk[\"messages\"][-1].pretty_print()\n" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": ".venv", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.12.4" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 2 | ||||||
|  | } | ||||||
		Loading…
	
		Reference in a new issue