Compare commits
	
		
			2 commits
		
	
	
		
			9dea79b4a4
			...
			f16bb6b8cf
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | f16bb6b8cf | ||
|  | 9e1da05276 | 
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							|  | @ -166,4 +166,3 @@ test/data/output/* | |||
| 
 | ||||
| .pytest_cache | ||||
| .ruff_cache | ||||
| *.ipynb | ||||
							
								
								
									
										353
									
								
								notebooks/fireworks_tools.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										353
									
								
								notebooks/fireworks_tools.ipynb
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,353 @@ | |||
| { | ||||
|  "cells": [ | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 1, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from langchain.chains import LLMMathChain\n", | ||||
|     "from langchain_openai import ChatOpenAI\n", | ||||
|     "from langchain.agents import AgentExecutor, create_openai_tools_agent\n", | ||||
|     "from langchain import hub\n", | ||||
|     "from pydantic import BaseModel, Field\n", | ||||
|     "from langchain.tools import BaseTool\n", | ||||
|     "from typing import Type\n", | ||||
|     "from hellocomputer.config import settings\n", | ||||
|     "from hellocomputer.models import AvailableModels\n", | ||||
|     "\n", | ||||
|     "# Get the prompt to use - you can modify this!\n", | ||||
|     "prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n", | ||||
|     "\n", | ||||
|     "llm = ChatOpenAI(\n", | ||||
|     "    base_url=settings.llm_base_url,\n", | ||||
|     "    api_key=settings.llm_api_key,\n", | ||||
|     "    model=AvailableModels.firefunction_2,\n", | ||||
|     "    temperature=0.5,\n", | ||||
|     "    max_tokens=256,\n", | ||||
|     ")\n", | ||||
|     "\n", | ||||
|     "math_llm = ChatOpenAI(\n", | ||||
|     "    base_url=settings.llm_base_url,\n", | ||||
|     "    api_key=settings.llm_api_key,\n", | ||||
|     "    model=AvailableModels.mixtral_8x7b,\n", | ||||
|     "    temperature=0.5,\n", | ||||
|     ")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 2, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "class CalculatorInput(BaseModel):\n", | ||||
|     "    query: str = Field(description=\"should be a math expression\")\n", | ||||
|     "\n", | ||||
|     "class CustomCalculatorTool(BaseTool):\n", | ||||
|     "    name: str = \"Calculator\"\n", | ||||
|     "    description: str = \"Tool to evaluate mathemetical expressions\"\n", | ||||
|     "    args_schema: Type[BaseModel] = CalculatorInput\n", | ||||
|     "\n", | ||||
|     "    def _run(self, query: str) -> str:\n", | ||||
|     "        \"\"\"Use the tool.\"\"\"\n", | ||||
|     "        return LLMMathChain.from_llm(llm=math_llm, verbose=True).invoke(query)\n", | ||||
|     "\n", | ||||
|     "    async def _arun(self, query: str) -> str:\n", | ||||
|     "        \"\"\"Use the tool asynchronously.\"\"\"\n", | ||||
|     "        return LLMMathChain.from_llm(llm=math_llm, verbose=True).ainvoke(query)\n", | ||||
|     "\n", | ||||
|     "tools = [\n", | ||||
|     "  CustomCalculatorTool()\n", | ||||
|     "]\n", | ||||
|     "\n", | ||||
|     "agent = create_openai_tools_agent(llm, tools, prompt)\n", | ||||
|     "\n", | ||||
|     "agent = AgentExecutor(agent=agent, tools=tools, verbose=True)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 3, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", | ||||
|       "\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||
|       "{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(agent.invoke({\"input\": \"What is the capital of USA?\"}))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 4, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", | ||||
|       "\u001b[32;1m\u001b[1;3m\n", | ||||
|       "Invoking: `Calculator` with `{'query': '100/25'}`\n", | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "\u001b[0m\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||
|       "100/25\u001b[32;1m\u001b[1;3m```text\n", | ||||
|       "100 / 25\n", | ||||
|       "```\n", | ||||
|       "...numexpr.evaluate(\"100 / 25\")...\n", | ||||
|       "\u001b[0m\n", | ||||
|       "Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n", | ||||
|       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||
|       "\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||
|       "{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "markdown", | ||||
|    "metadata": {}, | ||||
|    "source": [ | ||||
|     "# Now let's modify this code and make it a Graph" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 6, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from langgraph.prebuilt import ToolNode\n", | ||||
|     "from langchain_core.messages import AIMessage\n", | ||||
|     "\n", | ||||
|     "tool_node = ToolNode(tools=tools)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 7, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||
|       "50 / 2\u001b[32;1m\u001b[1;3m```text\n", | ||||
|       "50 / 2\n", | ||||
|       "```\n", | ||||
|       "...numexpr.evaluate(\"50 / 2\")...\n", | ||||
|       "\u001b[0m\n", | ||||
|       "Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n", | ||||
|       "\u001b[1m> Finished chain.\u001b[0m\n" | ||||
|      ] | ||||
|     }, | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 7, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "# Call the tools node manually\n", | ||||
|     "\n", | ||||
|     "message_with_single_tool_call = AIMessage(\n", | ||||
|     "    content=\"\",\n", | ||||
|     "    tool_calls=[\n", | ||||
|     "        {\n", | ||||
|     "            \"name\": \"Calculator\",\n", | ||||
|     "            \"args\": {\"query\": \"50 / 2\"},\n", | ||||
|     "            \"id\": \"tool_call_id\",\n", | ||||
|     "            \"type\": \"tool_call\",\n", | ||||
|     "        }\n", | ||||
|     "    ],\n", | ||||
|     ")\n", | ||||
|     "\n", | ||||
|     "tool_node.invoke({\"messages\": [message_with_single_tool_call]})" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 8, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "# Bind the tools to the agent so it knows what to call\n", | ||||
|     "\n", | ||||
|     "agent = ChatOpenAI(\n", | ||||
|     "    base_url=\"https://api.fireworks.ai/inference/v1\",\n", | ||||
|     "    api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n", | ||||
|     "    model=\"accounts/fireworks/models/firefunction-v2\",\n", | ||||
|     "    temperature=0.5,\n", | ||||
|     "    max_tokens=256,\n", | ||||
|     ").bind_tools(tools)" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 9, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "data": { | ||||
|       "text/plain": [ | ||||
|        "[{'name': 'Calculator',\n", | ||||
|        "  'args': {'query': '234/7'},\n", | ||||
|        "  'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n", | ||||
|        "  'type': 'tool_call'}]" | ||||
|       ] | ||||
|      }, | ||||
|      "execution_count": 9, | ||||
|      "metadata": {}, | ||||
|      "output_type": "execute_result" | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "agent.invoke(\"234/7\").tool_calls" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 10, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "from typing import Literal\n", | ||||
|     "\n", | ||||
|     "from langgraph.graph import StateGraph, MessagesState\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n", | ||||
|     "    messages = state[\"messages\"]\n", | ||||
|     "    last_message = messages[-1]\n", | ||||
|     "    if last_message.tool_calls:\n", | ||||
|     "        return \"tools\"\n", | ||||
|     "    return \"__end__\"\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "def call_model(state: MessagesState):\n", | ||||
|     "    messages = state[\"messages\"]\n", | ||||
|     "    response = agent.invoke(messages)\n", | ||||
|     "    return {\"messages\": [response]}\n", | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "workflow = StateGraph(MessagesState)\n", | ||||
|     "\n", | ||||
|     "# Define the two nodes we will cycle between\n", | ||||
|     "workflow.add_node(\"agent\", call_model)\n", | ||||
|     "workflow.add_node(\"tools\", tool_node)\n", | ||||
|     "\n", | ||||
|     "workflow.add_edge(\"__start__\", \"agent\")\n", | ||||
|     "workflow.add_conditional_edges(\n", | ||||
|     "    \"agent\",\n", | ||||
|     "    should_continue,\n", | ||||
|     ")\n", | ||||
|     "workflow.add_edge(\"tools\", \"agent\")\n", | ||||
|     "\n", | ||||
|     "app = workflow.compile()" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": 11, | ||||
|    "metadata": {}, | ||||
|    "outputs": [ | ||||
|     { | ||||
|      "name": "stdout", | ||||
|      "output_type": "stream", | ||||
|      "text": [ | ||||
|       "================================\u001b[1m Human Message \u001b[0m=================================\n", | ||||
|       "\n", | ||||
|       "What's twelve times twelve?\n", | ||||
|       "==================================\u001b[1m Ai Message \u001b[0m==================================\n", | ||||
|       "Tool Calls:\n", | ||||
|       "  Calculator (call_dEdudo8txWMpLKMek070TIWd)\n", | ||||
|       " Call ID: call_dEdudo8txWMpLKMek070TIWd\n", | ||||
|       "  Args:\n", | ||||
|       "    query: 12 * 12\n", | ||||
|       "\n", | ||||
|       "\n", | ||||
|       "\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n", | ||||
|       "12 * 12\u001b[32;1m\u001b[1;3m```text\n", | ||||
|       "12 * 12\n", | ||||
|       "```\n", | ||||
|       "...numexpr.evaluate(\"12 * 12\")...\n", | ||||
|       "\u001b[0m\n", | ||||
|       "Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n", | ||||
|       "\u001b[1m> Finished chain.\u001b[0m\n", | ||||
|       "=================================\u001b[1m Tool Message \u001b[0m=================================\n", | ||||
|       "Name: Calculator\n", | ||||
|       "\n", | ||||
|       "{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n", | ||||
|       "==================================\u001b[1m Ai Message \u001b[0m==================================\n", | ||||
|       "\n", | ||||
|       "The answer is 144.\n" | ||||
|      ] | ||||
|     } | ||||
|    ], | ||||
|    "source": [ | ||||
|     "for chunk in app.stream(\n", | ||||
|     "    {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n", | ||||
|     "):\n", | ||||
|     "    chunk[\"messages\"][-1].pretty_print()\n" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
|    "cell_type": "code", | ||||
|    "execution_count": null, | ||||
|    "metadata": {}, | ||||
|    "outputs": [], | ||||
|    "source": [] | ||||
|   } | ||||
|  ], | ||||
|  "metadata": { | ||||
|   "kernelspec": { | ||||
|    "display_name": ".venv", | ||||
|    "language": "python", | ||||
|    "name": "python3" | ||||
|   }, | ||||
|   "language_info": { | ||||
|    "codemirror_mode": { | ||||
|     "name": "ipython", | ||||
|     "version": 3 | ||||
|    }, | ||||
|    "file_extension": ".py", | ||||
|    "mimetype": "text/x-python", | ||||
|    "name": "python", | ||||
|    "nbconvert_exporter": "python", | ||||
|    "pygments_lexer": "ipython3", | ||||
|    "version": "3.12.4" | ||||
|   } | ||||
|  }, | ||||
|  "nbformat": 4, | ||||
|  "nbformat_minor": 2 | ||||
| } | ||||
|  | @ -15,16 +15,19 @@ class DuckdbQueryInput(BaseModel): | |||
| class DuckdbQueryTool(BaseTool): | ||||
|     name: str = "sql_query" | ||||
|     description: str = "Run a SQL query in the database containing all the datasets " | ||||
|     "and provide a summary of the results" | ||||
|     "and provide a summary of the results, and the name of the table with them if the " | ||||
|     "volume of the results is large" | ||||
|     args_schema: Type[BaseModel] = DuckdbQueryInput | ||||
| 
 | ||||
|     def _run(self, query: str, session_id: str) -> str: | ||||
|         """Run the query""" | ||||
|         db = SessionDB(settings, session_id) | ||||
|         session = SessionDB(settings, session_id) | ||||
|         session.db.sql(query) | ||||
| 
 | ||||
|     async def _arun(self, query: str, session_id: str) -> str: | ||||
|         """Use the tool asynchronously.""" | ||||
|         db = SessionDB(settings, session_id) | ||||
|         session = SessionDB(settings, session_id) | ||||
|         session.db.sql(query) | ||||
|         return "Table" | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue