From eb72885f5b17820aaaa11e56c0c8fb6009627225 Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Wed, 7 Aug 2024 11:29:29 +0200 Subject: [PATCH] Nothing works because I need to find a way to pass configuration to the graph --- src/hellocomputer/db/sessions.py | 18 +++++- src/hellocomputer/graph.py | 12 +++- src/hellocomputer/nodes.py | 23 +++---- src/hellocomputer/tools/__init__.py | 0 src/hellocomputer/tools/db.py | 96 +++++++++++++++++++++++++++++ src/hellocomputer/tools/viz.py | 16 +++++ 6 files changed, 148 insertions(+), 17 deletions(-) create mode 100644 src/hellocomputer/tools/__init__.py create mode 100644 src/hellocomputer/tools/db.py create mode 100644 src/hellocomputer/tools/viz.py diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index 76cbf8a..2c576ba 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -3,16 +3,18 @@ from pathlib import Path import duckdb from langchain_community.utilities.sql_database import SQLDatabase +from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit +from langchain_openai import ChatOpenAI from typing_extensions import Self -from hellocomputer.config import Settings, StorageEngines +from hellocomputer.config import Settings, settings, StorageEngines +from hellocomputer.models import AvailableModels from . 
import DDB class SessionDB(DDB): - def __init__(self, settings: Settings, sid: str): - super().__init__(settings=settings) + def set_session(self, sid): self.sid = sid # Override storage engine for sessions if settings.storage_engine == StorageEngines.gcs: @@ -171,3 +173,13 @@ class SessionDB(DDB): @property def llmsql(self): return SQLDatabase(self.engine, ignore_tables=["metadata"]) + + @property + def sql_toolkit(self) -> SQLDatabaseToolkit: + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_medium, + temperature=0.3, + ) + return SQLDatabaseToolkit(db=self.llmsql, llm=llm) diff --git a/src/hellocomputer/graph.py b/src/hellocomputer/graph.py index fb4d497..5ab5f15 100644 --- a/src/hellocomputer/graph.py +++ b/src/hellocomputer/graph.py @@ -5,9 +5,11 @@ from langgraph.graph import END, START, MessagesState, StateGraph from hellocomputer.nodes import ( intent, answer_general, - answer_query, answer_visualization, ) +from hellocomputer.config import settings + +from hellocomputer.tools.db import SQLSubgraph def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]: @@ -22,7 +24,6 @@ workflow = StateGraph(MessagesState) workflow.add_node("intent", intent) workflow.add_node("answer_general", answer_general) -workflow.add_node("answer_query", answer_query) workflow.add_node("answer_visualization", answer_visualization) # Edges @@ -38,7 +39,12 @@ workflow.add_conditional_edges( }, ) workflow.add_edge("answer_general", END) -workflow.add_edge("answer_query", END) workflow.add_edge("answer_visualization", END) +# SQL Subgraph + +workflow = SQLSubgraph().add_subgraph( + workflow=workflow, origin="intent", destination=END +) + app = workflow.compile() diff --git a/src/hellocomputer/nodes.py b/src/hellocomputer/nodes.py index 9232b09..615d833 100644 --- a/src/hellocomputer/nodes.py +++ b/src/hellocomputer/nodes.py @@ -1,6 +1,7 @@ from langchain_openai import ChatOpenAI from 
langgraph.graph import MessagesState + from hellocomputer.config import settings from hellocomputer.extraction import initial_intent_parser from hellocomputer.models import AvailableModels @@ -35,17 +36,17 @@ async def answer_general(state: MessagesState): return {"messages": [await chain.ainvoke({})]} -async def answer_query(state: MessagesState): - llm = ChatOpenAI( - base_url=settings.llm_base_url, - api_key=settings.llm_api_key, - model=AvailableModels.llama_small, - temperature=0, - ) - prompt = await Prompts.sql() - chain = prompt | llm - - return {"messages": [await chain.ainvoke({})]} +# async def answer_query(state: MessagesState): +# llm = ChatOpenAI( +# base_url=settings.llm_base_url, +# api_key=settings.llm_api_key, +# model=AvailableModels.llama_small, +# temperature=0, +# ) +# prompt = await Prompts.sql() +# chain = prompt | llm +# +# return {"messages": [await chain.ainvoke({})]} async def answer_visualization(state: MessagesState): diff --git a/src/hellocomputer/tools/__init__.py b/src/hellocomputer/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hellocomputer/tools/db.py b/src/hellocomputer/tools/db.py new file mode 100644 index 0000000..b21332a --- /dev/null +++ b/src/hellocomputer/tools/db.py @@ -0,0 +1,96 @@ +from typing import Type + +from langchain.tools import BaseTool +from langgraph.prebuilt import ToolNode +from langgraph.graph import MessagesState, StateGraph +from langchain_openai import ChatOpenAI +from pydantic import BaseModel, Field + +from hellocomputer.config import settings +from hellocomputer.db.sessions import SessionDB +from hellocomputer.models import AvailableModels + + +## This in case I need to create more ReAct agents + + +class DuckdbQueryInput(BaseModel): + query: str = Field(description="Question to be translated to a SQL statement") + session_id: str = Field(description="Session ID necessary to fetch the data") + + +class DuckdbQueryTool(BaseTool): + name: str = "sql_query" + description: 
str = "Run a SQL query in the database containing all the datasets " + "and provide a summary of the results, and the name of the table with them if the " + "volume of the results is large" + args_schema: Type[BaseModel] = DuckdbQueryInput + + def _run(self, query: str, session_id: str) -> str: + """Run the query""" + session = SessionDB(settings=settings); session.set_session(session_id) + session.db.sql(query) + + async def _arun(self, query: str, session_id: str) -> str: + """Use the tool asynchronously.""" + session = SessionDB(settings=settings); session.set_session(session_id) + session.db.sql(query) + return "Table" + + + class SQLSubgraph: + """ + Creates the question-answering agent that generates and runs SQL + queries + """ + + async def call_model(self, state: MessagesState): + db = SessionDB(settings=settings) + sql_toolkit = db.sql_toolkit + + agent_llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.firefunction_2, + temperature=0.5, + max_tokens=256, + ).bind_tools(sql_toolkit.get_tools()) + + messages = state["messages"] + response = await agent_llm.ainvoke(messages) + return {"messages": [response]} + + @property + def query_tool_node(self) -> ToolNode: + db = SessionDB(settings=settings) + sql_toolkit = db.sql_toolkit + return ToolNode(sql_toolkit.get_tools()) + + def add_subgraph( + self, workflow: StateGraph, origin: str, destination: str + ) -> StateGraph: + """Creates the nodes and edges of the subgraph given a workflow + + Args: + workflow (StateGraph): Workflow that will get nodes and edges added + origin (str): Origin node + destination (str): Destination node + + Returns: + StateGraph: Resulting workflow + """ + + def should_continue(state: MessagesState): + messages = state["messages"] + last_message = messages[-1] + if last_message.tool_calls: + return destination + return "__end__" + + workflow.add_node("sql_agent", self.call_model) + workflow.add_node("sql_tool_node", self.query_tool_node) + workflow.add_edge(origin, "sql_agent") + 
workflow.add_conditional_edges("sql_agent", should_continue) + workflow.add_edge("sql_agent", "sql_tool_node") + + return workflow diff --git a/src/hellocomputer/tools/viz.py b/src/hellocomputer/tools/viz.py new file mode 100644 index 0000000..59432fe --- /dev/null +++ b/src/hellocomputer/tools/viz.py @@ -0,0 +1,16 @@ +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + + +class PlotHistogramInput(BaseModel): + column_name: str = Field(description="Name of the column containing the values") + table_name: str = Field(description="Name of the table that contains the data") + num_bins: int = Field(description="Number of bins of the histogram") + + +class PlotHistogramTool(BaseTool): + name: str = "plot_histogram" + description: str = """ + Generate a histogram plot given a name of an existing table of the database, + and a name of a column in the table. The default number of bins is 10, but + you can forward the number of bins if you are requested to"""