diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index 2c576ba..507b807 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit from langchain_openai import ChatOpenAI from typing_extensions import Self -from hellocomputer.config import Settings, settings, StorageEngines +from hellocomputer.config import settings, StorageEngines from hellocomputer.models import AvailableModels from . import DDB @@ -172,7 +172,9 @@ class SessionDB(DDB): @property def llmsql(self): - return SQLDatabase(self.engine, ignore_tables=["metadata"]) + return SQLDatabase( + self.engine + ) ## Cannot ignore tables because it creates a connection at import time @property def sql_toolkit(self) -> SQLDatabaseToolkit: diff --git a/src/hellocomputer/graph.py b/src/hellocomputer/graph.py index 5ab5f15..ab67bea 100644 --- a/src/hellocomputer/graph.py +++ b/src/hellocomputer/graph.py @@ -1,24 +1,35 @@ from typing import Literal +from langgraph.graph import END, START, StateGraph -from langgraph.graph import END, START, MessagesState, StateGraph from hellocomputer.nodes import ( intent, answer_general, answer_visualization, ) -from hellocomputer.config import settings +from hellocomputer.tools import extract_sid from hellocomputer.tools.db import SQLSubgraph +from hellocomputer.db.sessions import SessionDB +from hellocomputer.state import SidState -def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]: +def route_intent(state: SidState) -> Literal["general", "query", "visualization"]: messages = state["messages"] last_message = messages[-1] return last_message.content -workflow = StateGraph(MessagesState) +def get_sid(state: SidState) -> Literal["ok"]: + print(state["messages"]) + last_message = state["messages"][-1] + sid = extract_sid(last_message) + state.sid = SessionDB(sid) + + +sql_subgraph = SQLSubgraph() + +workflow = StateGraph(SidState) # Nodes @@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general) workflow.add_node("answer_visualization", answer_visualization) # Edges - workflow.add_edge(START, "intent") workflow.add_conditional_edges( "intent", route_intent, { "general": "answer_general", - "query": "answer_query", + "query": sql_subgraph.start_node, "visualization": "answer_visualization", }, ) @@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END) # SQL Subgraph -workflow = SQLSubgraph().add_subgraph( +workflow = sql_subgraph.add_nodes_edges( workflow=workflow, origin="intent", destination=END ) diff --git a/src/hellocomputer/routers/chat.py b/src/hellocomputer/routers/chat.py index 41bf4e2..ac0bc78 100644 --- a/src/hellocomputer/routers/chat.py +++ b/src/hellocomputer/routers/chat.py @@ -5,11 +5,16 @@ from starlette.requests import Request from hellocomputer.graph import app +import os + router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["chat"]) async def query(request: Request, sid: str = "", q: str = "") -> str: user = request.session.get("user") # noqa - response = await app.ainvoke({"messages": [HumanMessage(content=q)]}) + content = f"{q}{os.linesep}******{sid}******" + response = await app.ainvoke( + {"messages": [HumanMessage(content=content)]}, + ) return response["messages"][-1].content diff --git a/src/hellocomputer/state.py b/src/hellocomputer/state.py new file mode 100644 index 0000000..1f4dce3 --- /dev/null +++ b/src/hellocomputer/state.py @@ -0,0 +1,8 @@ +from typing import TypedDict, Annotated +from langgraph.graph.message import add_messages +from hellocomputer.db.sessions import SessionDB + + +class SidState(TypedDict): + messages: Annotated[list, add_messages] + sid: SessionDB diff --git a/src/hellocomputer/tools.py b/src/hellocomputer/tools.py deleted file mode 100644 index e654c99..0000000 --- a/src/hellocomputer/tools.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Type - -from langchain.tools import BaseTool -from pydantic import BaseModel, Field - -from hellocomputer.config import settings -from hellocomputer.db.sessions import SessionDB - - -class DuckdbQueryInput(BaseModel): - query: str = Field(description="Question to be translated to a SQL statement") - session_id: str = Field(description="Session ID necessary to fetch the data") - - -class DuckdbQueryTool(BaseTool): - name: str = "sql_query" - description: str = "Run a SQL query in the database containing all the datasets " - "and provide a summary of the results, and the name of the table with them if the " - "volume of the results is large" - args_schema: Type[BaseModel] = DuckdbQueryInput - - def _run(self, query: str, session_id: str) -> str: - """Run the query""" - session = SessionDB(settings, session_id) - session.db.sql(query) - - async def _arun(self, query: str, session_id: str) -> str: - """Use the tool asynchronously.""" - session = SessionDB(settings, session_id) - session.db.sql(query) - return "Table" - - -class PlotHistogramInput(BaseModel): - column_name: str = Field(description="Name of the column containing the values") - table_name: str = Field(description="Name of the table that contains the data") - num_bins: int = Field(description="Number of bins of the histogram") - - -class PlotHistogramTool(BaseTool): - name: str = "plot_histogram" - description: str = """ - Generate a histogram plot given a name of an existing table of the database, - and a name of a column in the table. The default number of bins is 10, but - you can forward the number of bins if you are requested to""" diff --git a/src/hellocomputer/tools/__init__.py b/src/hellocomputer/tools/__init__.py index e69de29..3d14a7f 100644 --- a/src/hellocomputer/tools/__init__.py +++ b/src/hellocomputer/tools/__init__.py @@ -0,0 +1,74 @@ +from typing import Type + +from langchain.tools import BaseTool +from pydantic import BaseModel, Field + +from hellocomputer.config import settings +from hellocomputer.db.sessions import SessionDB + +import re +import os + + +def extract_sid(message: str) -> str: + """Extract the session id from the initial message + + Args: + message (str): Initial message to the entrypoint + + Returns: + str: session id + """ + pattern = r"\*{6}(.*?)\*{6}" + matches = re.findall(pattern, message) + return matches[0] + + +def remove_sid(message: str) -> str: + """_summary_ + + Args: + message (str): _description_ + + Returns: + str: _description_ + """ + return os.linesep.join(message.splitlines()[:-1]) + + +class DuckdbQueryInput(BaseModel): + query: str = Field(description="Question to be translated to a SQL statement") + session_id: str = Field(description="Session ID necessary to fetch the data") + + +class DuckdbQueryTool(BaseTool): + name: str = "sql_query" + description: str = "Run a SQL query in the database containing all the datasets " + "and provide a summary of the results, and the name of the table with them if the " + "volume of the results is large" + args_schema: Type[BaseModel] = DuckdbQueryInput + + def _run(self, query: str, session_id: str) -> str: + """Run the query""" + session = SessionDB(settings, session_id) + session.db.sql(query) + + async def _arun(self, query: str, session_id: str) -> str: + """Use the tool asynchronously.""" + session = SessionDB(settings, session_id) + session.db.sql(query) + return "Table" + + +class PlotHistogramInput(BaseModel): + column_name: str = Field(description="Name of the column containing the values") + table_name: str = Field(description="Name of the table that contains the data") + num_bins: int = Field(description="Number of bins of the histogram") + + +class PlotHistogramTool(BaseTool): + name: str = "plot_histogram" + description: str = """ + Generate a histogram plot given a name of an existing table of the database, + and a name of a column in the table. The default number of bins is 10, but + you can forward the number of bins if you are requested to""" diff --git a/src/hellocomputer/tools/db.py b/src/hellocomputer/tools/db.py index b21332a..8888738 100644 --- a/src/hellocomputer/tools/db.py +++ b/src/hellocomputer/tools/db.py @@ -1,12 +1,13 @@ -from typing import Type +from typing import Type, Literal from langchain.tools import BaseTool from langgraph.prebuilt import ToolNode -from langgraph.graph import MessagesState, StateGraph +from langgraph.graph import StateGraph from langchain_openai import ChatOpenAI from pydantic import BaseModel, Field from hellocomputer.config import settings +from hellocomputer.state import SidState from hellocomputer.db.sessions import SessionDB from hellocomputer.models import AvailableModels @@ -44,8 +45,12 @@ class SQLSubgraph: queries """ - async def call_model(self, state: MessagesState): - db = SessionDB(settings=settings) + @property + def start_node(self): + return "sql_agent" + + async def call_model(self, state: SidState): + db = SessionDB(settings=settings).set_session(state.sid) sql_toolkit = db.sql_toolkit agent_llm = ChatOpenAI( @@ -66,7 +71,7 @@ class SQLSubgraph: sql_toolkit = db.sql_toolkit return ToolNode(sql_toolkit.get_tools()) - def add_subgraph( + def add_nodes_edges( self, workflow: StateGraph, origin: str, destination: str ) -> StateGraph: """Creates the nodes and edges of the subgraph given a workflow @@ -80,7 +85,7 @@ class SQLSubgraph: StateGraph: Resulting workflow """ - def should_continue(state: MessagesState): + def should_continue(state: SidState): messages = state["messages"] last_message = messages[-1] if last_message.tool_calls: @@ -90,7 +95,7 @@ class SQLSubgraph: workflow.add_node("sql_agent", self.call_model) workflow.add_node("sql_tool_node", self.query_tool_node) workflow.add_edge(origin, "sql_agent") - workflow.add_conditional_edges("agent", should_continue) + workflow.add_conditional_edges("sql_agent", should_continue) workflow.add_edge("sql_agent", "sql_tool_node") return workflow diff --git a/test/test_tools.py b/test/test_tools.py new file mode 100644 index 0000000..5143197 --- /dev/null +++ b/test/test_tools.py @@ -0,0 +1,13 @@ +from hellocomputer.tools import extract_sid, remove_sid + +message = """This is a message +******sid****** +""" + + +def test_match_sid(): + assert extract_sid(message) == "sid" + + +def test_remove_sid(): + assert remove_sid(message) == "This is a message"