Compare commits

...

2 commits

Author SHA1 Message Date
Guillem Borrell f16bb6b8cf Working on SQL tools
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
2024-07-29 17:20:56 +02:00
Guillem Borrell 9e1da05276 Allow notebooks 2024-07-29 17:20:51 +02:00
3 changed files with 360 additions and 5 deletions

3
.gitignore vendored
View file

@ -165,5 +165,4 @@ cython_debug/
test/data/output/*
.pytest_cache
.ruff_cache
*.ipynb
.ruff_cache

View file

@ -0,0 +1,353 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMMathChain\n",
"from langchain_openai import ChatOpenAI\n",
"from langchain.agents import AgentExecutor, create_openai_tools_agent\n",
"from langchain import hub\n",
"from pydantic import BaseModel, Field\n",
"from langchain.tools import BaseTool\n",
"from typing import Type\n",
"from hellocomputer.config import settings\n",
"from hellocomputer.models import AvailableModels\n",
"\n",
"# Get the prompt to use - you can modify this!\n",
"prompt = hub.pull(\"hwchase17/openai-tools-agent\")\n",
"\n",
"llm = ChatOpenAI(\n",
" base_url=settings.llm_base_url,\n",
" api_key=settings.llm_api_key,\n",
" model=AvailableModels.firefunction_2,\n",
" temperature=0.5,\n",
" max_tokens=256,\n",
")\n",
"\n",
"math_llm = ChatOpenAI(\n",
" base_url=settings.llm_base_url,\n",
" api_key=settings.llm_api_key,\n",
" model=AvailableModels.mixtral_8x7b,\n",
" temperature=0.5,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class CalculatorInput(BaseModel):\n",
" query: str = Field(description=\"should be a math expression\")\n",
"\n",
"class CustomCalculatorTool(BaseTool):\n",
" name: str = \"Calculator\"\n",
" description: str = \"Tool to evaluate mathemetical expressions\"\n",
" args_schema: Type[BaseModel] = CalculatorInput\n",
"\n",
" def _run(self, query: str) -> str:\n",
" \"\"\"Use the tool.\"\"\"\n",
" return LLMMathChain.from_llm(llm=math_llm, verbose=True).invoke(query)\n",
"\n",
" async def _arun(self, query: str) -> str:\n",
" \"\"\"Use the tool asynchronously.\"\"\"\n",
" return LLMMathChain.from_llm(llm=math_llm, verbose=True).ainvoke(query)\n",
"\n",
"tools = [\n",
" CustomCalculatorTool()\n",
"]\n",
"\n",
"agent = create_openai_tools_agent(llm, tools, prompt)\n",
"\n",
"agent = AgentExecutor(agent=agent, tools=tools, verbose=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3mThe capital of USA is Washington, D.C.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"{'input': 'What is the capital of USA?', 'output': 'The capital of USA is Washington, D.C.'}\n"
]
}
],
"source": [
"print(agent.invoke({\"input\": \"What is the capital of USA?\"}))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
"\u001b[32;1m\u001b[1;3m\n",
"Invoking: `Calculator` with `{'query': '100/25'}`\n",
"\n",
"\n",
"\u001b[0m\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"100/25\u001b[32;1m\u001b[1;3m```text\n",
"100 / 25\n",
"```\n",
"...numexpr.evaluate(\"100 / 25\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m4.0\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"\u001b[36;1m\u001b[1;3m{'question': '100/25', 'answer': 'Answer: 4.0'}\u001b[0m\u001b[32;1m\u001b[1;3mThe answer is 4.0.\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"{'input': 'What is 100 divided by 25?', 'output': 'The answer is 4.0.'}\n"
]
}
],
"source": [
"print(agent.invoke({\"input\": \"What is 100 divided by 25?\"}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Now let's modify this code and make it a Graph"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from langgraph.prebuilt import ToolNode\n",
"from langchain_core.messages import AIMessage\n",
"\n",
"tool_node = ToolNode(tools=tools)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"50 / 2\u001b[32;1m\u001b[1;3m```text\n",
"50 / 2\n",
"```\n",
"...numexpr.evaluate(\"50 / 2\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m25.0\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"{'messages': [ToolMessage(content='{\"question\": \"50 / 2\", \"answer\": \"Answer: 25.0\"}', name='Calculator', tool_call_id='tool_call_id')]}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Call the tools node manually\n",
"\n",
"message_with_single_tool_call = AIMessage(\n",
" content=\"\",\n",
" tool_calls=[\n",
" {\n",
" \"name\": \"Calculator\",\n",
" \"args\": {\"query\": \"50 / 2\"},\n",
" \"id\": \"tool_call_id\",\n",
" \"type\": \"tool_call\",\n",
" }\n",
" ],\n",
")\n",
"\n",
"tool_node.invoke({\"messages\": [message_with_single_tool_call]})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Bind the tools to the agent so it knows what to call\n",
"\n",
"agent = ChatOpenAI(\n",
" base_url=\"https://api.fireworks.ai/inference/v1\",\n",
" api_key=\"bGWp5ErQNI7rP8GOcGBJmyC5QMV7z8UdBpLAseTaxhbAk6u1\",\n",
" model=\"accounts/fireworks/models/firefunction-v2\",\n",
" temperature=0.5,\n",
" max_tokens=256,\n",
").bind_tools(tools)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'name': 'Calculator',\n",
" 'args': {'query': '234/7'},\n",
" 'id': 'call_mP5fctn8N6vilM1Yewb5QxRY',\n",
" 'type': 'tool_call'}]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"agent.invoke(\"234/7\").tool_calls"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from typing import Literal\n",
"\n",
"from langgraph.graph import StateGraph, MessagesState\n",
"\n",
"\n",
"def should_continue(state: MessagesState) -> Literal[\"tools\", \"__end__\"]:\n",
" messages = state[\"messages\"]\n",
" last_message = messages[-1]\n",
" if last_message.tool_calls:\n",
" return \"tools\"\n",
" return \"__end__\"\n",
"\n",
"\n",
"def call_model(state: MessagesState):\n",
" messages = state[\"messages\"]\n",
" response = agent.invoke(messages)\n",
" return {\"messages\": [response]}\n",
"\n",
"\n",
"workflow = StateGraph(MessagesState)\n",
"\n",
"# Define the two nodes we will cycle between\n",
"workflow.add_node(\"agent\", call_model)\n",
"workflow.add_node(\"tools\", tool_node)\n",
"\n",
"workflow.add_edge(\"__start__\", \"agent\")\n",
"workflow.add_conditional_edges(\n",
" \"agent\",\n",
" should_continue,\n",
")\n",
"workflow.add_edge(\"tools\", \"agent\")\n",
"\n",
"app = workflow.compile()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"================================\u001b[1m Human Message \u001b[0m=================================\n",
"\n",
"What's twelve times twelve?\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"Tool Calls:\n",
" Calculator (call_dEdudo8txWMpLKMek070TIWd)\n",
" Call ID: call_dEdudo8txWMpLKMek070TIWd\n",
" Args:\n",
" query: 12 * 12\n",
"\n",
"\n",
"\u001b[1m> Entering new LLMMathChain chain...\u001b[0m\n",
"12 * 12\u001b[32;1m\u001b[1;3m```text\n",
"12 * 12\n",
"```\n",
"...numexpr.evaluate(\"12 * 12\")...\n",
"\u001b[0m\n",
"Answer: \u001b[33;1m\u001b[1;3m144\u001b[0m\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"=================================\u001b[1m Tool Message \u001b[0m=================================\n",
"Name: Calculator\n",
"\n",
"{\"question\": \"12 * 12\", \"answer\": \"Answer: 144\"}\n",
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
"\n",
"The answer is 144.\n"
]
}
],
"source": [
"for chunk in app.stream(\n",
" {\"messages\": [(\"human\", \"What's twelve times twelve?\")]}, stream_mode=\"values\"\n",
"):\n",
" chunk[\"messages\"][-1].pretty_print()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -15,16 +15,19 @@ class DuckdbQueryInput(BaseModel):
class DuckdbQueryTool(BaseTool):
    """Tool that runs a SQL query through the ``db`` handle of a session's SessionDB."""

    name: str = "sql_query"
    # Parenthesised so the adjacent literals concatenate into one description;
    # written as separate statements, only the first fragment is assigned and
    # the remaining strings are discarded expression statements.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query.

        Args:
            query: SQL text passed to ``session.db.sql``.
            session_id: Identifier used to open the SessionDB.

        Returns:
            The same placeholder string as ``_arun``.
        """
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Return the placeholder the async path returns, so the declared
        # ``-> str`` holds instead of implicitly returning None.
        # NOTE(review): the description promises a summary of the results;
        # the real return value is still to be implemented.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously.

        Args:
            query: SQL text passed to ``session.db.sql``.
            session_id: Identifier used to open the SessionDB.

        Returns:
            The placeholder string ``"Table"``.
        """
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        return "Table"