This commit is contained in:
parent
eb72885f5b
commit
e040b2e728
|
@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
|||
from langchain_openai import ChatOpenAI
|
||||
from typing_extensions import Self
|
||||
|
||||
from hellocomputer.config import Settings, settings, StorageEngines
|
||||
from hellocomputer.config import settings, StorageEngines
|
||||
from hellocomputer.models import AvailableModels
|
||||
|
||||
from . import DDB
|
||||
|
@ -172,7 +172,9 @@ class SessionDB(DDB):
|
|||
|
||||
@property
|
||||
def llmsql(self):
|
||||
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
||||
return SQLDatabase(
|
||||
self.engine
|
||||
) ## Cannot ignore tables because it creates a connection at import time
|
||||
|
||||
@property
|
||||
def sql_toolkit(self) -> SQLDatabaseToolkit:
|
||||
|
|
|
@ -1,24 +1,35 @@
|
|||
from typing import Literal
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
|
||||
from hellocomputer.nodes import (
|
||||
intent,
|
||||
answer_general,
|
||||
answer_visualization,
|
||||
)
|
||||
from hellocomputer.config import settings
|
||||
|
||||
from hellocomputer.tools import extract_sid
|
||||
from hellocomputer.tools.db import SQLSubgraph
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.state import SidState
|
||||
|
||||
|
||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||
def route_intent(state: SidState) -> Literal["general", "query", "visualization"]:
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
return last_message.content
|
||||
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
def get_sid(state: SidState) -> Literal["ok"]:
|
||||
print(state["messages"])
|
||||
last_message = state["messages"][-1]
|
||||
sid = extract_sid(last_message)
|
||||
state.sid = SessionDB(sid)
|
||||
|
||||
|
||||
sql_subgraph = SQLSubgraph()
|
||||
|
||||
workflow = StateGraph(SidState)
|
||||
|
||||
# Nodes
|
||||
|
||||
|
@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general)
|
|||
workflow.add_node("answer_visualization", answer_visualization)
|
||||
|
||||
# Edges
|
||||
|
||||
workflow.add_edge(START, "intent")
|
||||
workflow.add_conditional_edges(
|
||||
"intent",
|
||||
route_intent,
|
||||
{
|
||||
"general": "answer_general",
|
||||
"query": "answer_query",
|
||||
"query": sql_subgraph.start_node,
|
||||
"visualization": "answer_visualization",
|
||||
},
|
||||
)
|
||||
|
@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END)
|
|||
|
||||
# SQL Subgraph
|
||||
|
||||
workflow = SQLSubgraph().add_subgraph(
|
||||
workflow = sql_subgraph.add_nodes_edges(
|
||||
workflow=workflow, origin="intent", destination=END
|
||||
)
|
||||
|
||||
|
|
|
@ -5,11 +5,16 @@ from starlette.requests import Request
|
|||
|
||||
from hellocomputer.graph import app
|
||||
|
||||
import os
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/query", response_class=PlainTextResponse, tags=["chat"])
|
||||
async def query(request: Request, sid: str = "", q: str = "") -> str:
|
||||
user = request.session.get("user") # noqa
|
||||
response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
|
||||
content = f"{q}{os.linesep}******{sid}******"
|
||||
response = await app.ainvoke(
|
||||
{"messages": [HumanMessage(content=content)]},
|
||||
)
|
||||
return response["messages"][-1].content
|
||||
|
|
8
src/hellocomputer/state.py
Normal file
8
src/hellocomputer/state.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from typing import TypedDict, Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
|
||||
class SidState(TypedDict):
|
||||
messages: Annotated[list, add_messages]
|
||||
sid: SessionDB
|
|
@ -1,45 +0,0 @@
|
|||
from typing import Type
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
|
||||
class DuckdbQueryInput(BaseModel):
|
||||
query: str = Field(description="Question to be translated to a SQL statement")
|
||||
session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||
|
||||
|
||||
class DuckdbQueryTool(BaseTool):
|
||||
name: str = "sql_query"
|
||||
description: str = "Run a SQL query in the database containing all the datasets "
|
||||
"and provide a summary of the results, and the name of the table with them if the "
|
||||
"volume of the results is large"
|
||||
args_schema: Type[BaseModel] = DuckdbQueryInput
|
||||
|
||||
def _run(self, query: str, session_id: str) -> str:
|
||||
"""Run the query"""
|
||||
session = SessionDB(settings, session_id)
|
||||
session.db.sql(query)
|
||||
|
||||
async def _arun(self, query: str, session_id: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
session = SessionDB(settings, session_id)
|
||||
session.db.sql(query)
|
||||
return "Table"
|
||||
|
||||
|
||||
class PlotHistogramInput(BaseModel):
|
||||
column_name: str = Field(description="Name of the column containing the values")
|
||||
table_name: str = Field(description="Name of the table that contains the data")
|
||||
num_bins: int = Field(description="Number of bins of the histogram")
|
||||
|
||||
|
||||
class PlotHistogramTool(BaseTool):
|
||||
name: str = "plot_histogram"
|
||||
description: str = """
|
||||
Generate a histogram plot given a name of an existing table of the database,
|
||||
and a name of a column in the table. The default number of bins is 10, but
|
||||
you can forward the number of bins if you are requested to"""
|
|
@ -0,0 +1,74 @@
|
|||
from typing import Type
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
import re
|
||||
import os
|
||||
|
||||
|
||||
def extract_sid(message: str) -> str:
|
||||
"""Extract the session id from the initial message
|
||||
|
||||
Args:
|
||||
message (str): Initial message to the entrypoint
|
||||
|
||||
Returns:
|
||||
str: session id
|
||||
"""
|
||||
pattern = r"\*{6}(.*?)\*{6}"
|
||||
matches = re.findall(pattern, message)
|
||||
return matches[0]
|
||||
|
||||
|
||||
def remove_sid(message: str) -> str:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
message (str): _description_
|
||||
|
||||
Returns:
|
||||
str: _description_
|
||||
"""
|
||||
return os.linesep.join(message.splitlines()[:-1])
|
||||
|
||||
|
||||
class DuckdbQueryInput(BaseModel):
|
||||
query: str = Field(description="Question to be translated to a SQL statement")
|
||||
session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||
|
||||
|
||||
class DuckdbQueryTool(BaseTool):
|
||||
name: str = "sql_query"
|
||||
description: str = "Run a SQL query in the database containing all the datasets "
|
||||
"and provide a summary of the results, and the name of the table with them if the "
|
||||
"volume of the results is large"
|
||||
args_schema: Type[BaseModel] = DuckdbQueryInput
|
||||
|
||||
def _run(self, query: str, session_id: str) -> str:
|
||||
"""Run the query"""
|
||||
session = SessionDB(settings, session_id)
|
||||
session.db.sql(query)
|
||||
|
||||
async def _arun(self, query: str, session_id: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
session = SessionDB(settings, session_id)
|
||||
session.db.sql(query)
|
||||
return "Table"
|
||||
|
||||
|
||||
class PlotHistogramInput(BaseModel):
|
||||
column_name: str = Field(description="Name of the column containing the values")
|
||||
table_name: str = Field(description="Name of the table that contains the data")
|
||||
num_bins: int = Field(description="Number of bins of the histogram")
|
||||
|
||||
|
||||
class PlotHistogramTool(BaseTool):
|
||||
name: str = "plot_histogram"
|
||||
description: str = """
|
||||
Generate a histogram plot given a name of an existing table of the database,
|
||||
and a name of a column in the table. The default number of bins is 10, but
|
||||
you can forward the number of bins if you are requested to"""
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Type
|
||||
from typing import Type, Literal
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.graph import MessagesState, StateGraph
|
||||
from langgraph.graph import StateGraph
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.state import SidState
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.models import AvailableModels
|
||||
|
||||
|
@ -44,8 +45,12 @@ class SQLSubgraph:
|
|||
queries
|
||||
"""
|
||||
|
||||
async def call_model(self, state: MessagesState):
|
||||
db = SessionDB(settings=settings)
|
||||
@property
|
||||
def start_node(self):
|
||||
return "sql_agent"
|
||||
|
||||
async def call_model(self, state: SidState):
|
||||
db = SessionDB(settings=settings).set_session(state.sid)
|
||||
sql_toolkit = db.sql_toolkit
|
||||
|
||||
agent_llm = ChatOpenAI(
|
||||
|
@ -66,7 +71,7 @@ class SQLSubgraph:
|
|||
sql_toolkit = db.sql_toolkit
|
||||
return ToolNode(sql_toolkit.get_tools())
|
||||
|
||||
def add_subgraph(
|
||||
def add_nodes_edges(
|
||||
self, workflow: StateGraph, origin: str, destination: str
|
||||
) -> StateGraph:
|
||||
"""Creates the nodes and edges of the subgraph given a workflow
|
||||
|
@ -80,7 +85,7 @@ class SQLSubgraph:
|
|||
StateGraph: Resulting workflow
|
||||
"""
|
||||
|
||||
def should_continue(state: MessagesState):
|
||||
def should_continue(state: SidState):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
if last_message.tool_calls:
|
||||
|
@ -90,7 +95,7 @@ class SQLSubgraph:
|
|||
workflow.add_node("sql_agent", self.call_model)
|
||||
workflow.add_node("sql_tool_node", self.query_tool_node)
|
||||
workflow.add_edge(origin, "sql_agent")
|
||||
workflow.add_conditional_edges("agent", should_continue)
|
||||
workflow.add_conditional_edges("sql_agent", should_continue)
|
||||
workflow.add_edge("sql_agent", "sql_tool_node")
|
||||
|
||||
return workflow
|
||||
|
|
13
test/test_tools.py
Normal file
13
test/test_tools.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
from hellocomputer.tools import extract_sid, remove_sid
|
||||
|
||||
message = """This is a message
|
||||
******sid******
|
||||
"""
|
||||
|
||||
|
||||
def test_match_sid():
|
||||
assert extract_sid(message) == "sid"
|
||||
|
||||
|
||||
def test_remove_sid():
|
||||
assert remove_sid(message) == "This is a message"
|
Loading…
Reference in a new issue