Allow notebooks
This commit is contained in:
parent
9dea79b4a4
commit
9e1da05276
353
notebooks/fireworks_tools.ipynb
Normal file
353
notebooks/fireworks_tools.ipynb
Normal file
|
@ -0,0 +1,353 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.chains import LLMMathChain\n",
|
||||
"from langchain_openai import ChatOpenAI\n",
|
||||
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
|
||||
"from langchain import hub\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from langchain.tools import BaseTool\n",
|
||||
"from typing import Type\n",
|
||||
"from hellocomputer.config import settings\n",
|
||||
"from hellocomputer.models import AvailableModels\n",
|
||||
"\n",
|
||||
"# Get the prompt to use - you can modify this!\n",
|
||||
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
|
||||
"\n",
|
||||
"llm = ChatOpenAI(\n",
|
||||
" base_url=settings.llm_base_url,\n",
|
||||
" api_key=settings.llm_api_key,\n",
|
||||
" model=AvailableModels.firefunction_2,\n",
|
||||
" temperature=0.5,\n",
|
||||
" max_tokens=256,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"math_llm = ChatOpenAI(\n",
|
||||
" base_url=settings.llm_base_url,\n",
|
||||
" api_key=settings.llm_api_key,\n",
|
||||
" model=AvailableModels.mixtral_8x7b,\n",
|
||||
" temperature=0.5,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class CalculatorInput(BaseModel):\n",
|
||||
" query: str = Field(description=\"should be a math expression\")\n",
|
||||
"\n",
|
||||
"class CustomCalculatorTool(BaseTool):\n",
|
||||
" name: str = \"Calculator\"\n",
|
||||
" description: str = \"Tool to evaluate mathemetical expressions\"\n",
|
||||
" args_schema: Type[BaseModel] = CalculatorInput\n",
|
||||
"\n",
|
||||
" def _run(self, query: str) -> str:\n",
|
||||
" \"\"\"Use the tool.\"\"\"\n",
|
||||
" return LLMMathChain.from_llm(llm=math_llm, verbose=True).invoke(query)\n",
|
||||
"\n",
|
||||
" async def _arun(self, query: str) -> str:\n",
|
||||
" \"\"\"Use the tool asynchronously.\"\"\"\n",
|
||||
" return LLMMathChain.from_llm(llm=math_llm, verbose=True).ainvoke(query)\n",
|
||||
"\n",
|
||||
"tools = [\n",
|
||||
" CustomCalculatorTool()\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"agent = create_openai_tools_agent(llm, tools, prompt)\n",
|
||||
"\n",
|
||||
"agent = AgentExecutor(agent=agent, tools=tools, verbose=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(agent.invoke({\"input\": \"What is the capital of USA?\"}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
||||
"\u001b[32;1m\u001b[1;3m\n",
|
||||
"Invoking: `Calculator` with `{'query': '100/25'}`\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"100/25\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"100 / 25\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"100 / 25\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n",
|
||||
"\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Now let's modify this code and make it a Graph"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langgraph.prebuilt import ToolNode\n",
|
||||
"from langchain_core.messages import AIMessage\n",
|
||||
"\n",
|
||||
"tool_node = ToolNode(tools=tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"50 / 2\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"50 / 2\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"50 / 2\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Call the tools node manually\n",
|
||||
"\n",
|
||||
"message_with_single_tool_call = AIMessage(\n",
|
||||
" content=\"\",\n",
|
||||
" tool_calls=[\n",
|
||||
" {\n",
|
||||
" \"name\": \"Calculator\",\n",
|
||||
" \"args\": {\"query\": \"50 / 2\"},\n",
|
||||
" \"id\": \"tool_call_id\",\n",
|
||||
" \"type\": \"tool_call\",\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"tool_node.invoke({\"messages\": [message_with_single_tool_call]})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Bind the tools to the agent so it knows what to call\n",
|
||||
"\n",
|
||||
"agent = ChatOpenAI(\n",
|
||||
" base_url=\"https://api.fireworks.ai/inference/v1\",\n",
|
||||
" api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n",
|
||||
" model=\"accounts/fireworks/models/firefunction-v2\",\n",
|
||||
" temperature=0.5,\n",
|
||||
" max_tokens=256,\n",
|
||||
").bind_tools(tools)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[{'name': 'Calculator',\n",
|
||||
" 'args': {'query': '234/7'},\n",
|
||||
" 'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n",
|
||||
" 'type': 'tool_call'}]"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"agent.invoke(\"234/7\").tool_calls"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from typing import Literal\n",
|
||||
"\n",
|
||||
"from langgraph.graph import StateGraph, MessagesState\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n",
|
||||
" messages = state[\"messages\"]\n",
|
||||
" last_message = messages[-1]\n",
|
||||
" if last_message.tool_calls:\n",
|
||||
" return \"tools\"\n",
|
||||
" return \"__end__\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def call_model(state: MessagesState):\n",
|
||||
" messages = state[\"messages\"]\n",
|
||||
" response = agent.invoke(messages)\n",
|
||||
" return {\"messages\": [response]}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"workflow = StateGraph(MessagesState)\n",
|
||||
"\n",
|
||||
"# Define the two nodes we will cycle between\n",
|
||||
"workflow.add_node(\"agent\", call_model)\n",
|
||||
"workflow.add_node(\"tools\", tool_node)\n",
|
||||
"\n",
|
||||
"workflow.add_edge(\"__start__\", \"agent\")\n",
|
||||
"workflow.add_conditional_edges(\n",
|
||||
" \"agent\",\n",
|
||||
" should_continue,\n",
|
||||
")\n",
|
||||
"workflow.add_edge(\"tools\", \"agent\")\n",
|
||||
"\n",
|
||||
"app = workflow.compile()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
||||
"\n",
|
||||
"What's twelve times twelve?\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"Tool Calls:\n",
|
||||
" Calculator (call_dEdudo8txWMpLKMek070TIWd)\n",
|
||||
" Call ID: call_dEdudo8txWMpLKMek070TIWd\n",
|
||||
" Args:\n",
|
||||
" query: 12 * 12\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
|
||||
"12 * 12\u001b[32;1m\u001b[1;3m```text\n",
|
||||
"12 * 12\n",
|
||||
"```\n",
|
||||
"...numexpr.evaluate(\"12 * 12\")...\n",
|
||||
"\u001b[0m\n",
|
||||
"Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n",
|
||||
"\u001b[1m> Finished chain.\u001b[0m\n",
|
||||
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
|
||||
"Name: Calculator\n",
|
||||
"\n",
|
||||
"{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"\n",
|
||||
"The answer is 144.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for chunk in app.stream(\n",
|
||||
" {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n",
|
||||
"):\n",
|
||||
" chunk[\"messages\"][-1].pretty_print()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in a new issue