This commit is contained in:
parent
eb72885f5b
commit
e040b2e728
|
@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from hellocomputer.config import Settings, settings, StorageEngines
|
from hellocomputer.config import settings, StorageEngines
|
||||||
from hellocomputer.models import AvailableModels
|
from hellocomputer.models import AvailableModels
|
||||||
|
|
||||||
from . import DDB
|
from . import DDB
|
||||||
|
@ -172,7 +172,9 @@ class SessionDB(DDB):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def llmsql(self):
|
def llmsql(self):
|
||||||
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
return SQLDatabase(
|
||||||
|
self.engine
|
||||||
|
) ## Cannot ignore tables because it creates a connection at import time
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sql_toolkit(self) -> SQLDatabaseToolkit:
|
def sql_toolkit(self) -> SQLDatabaseToolkit:
|
||||||
|
|
|
@ -1,24 +1,35 @@
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
from langgraph.graph import END, START, StateGraph
|
||||||
|
|
||||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
|
||||||
|
|
||||||
from hellocomputer.nodes import (
|
from hellocomputer.nodes import (
|
||||||
intent,
|
intent,
|
||||||
answer_general,
|
answer_general,
|
||||||
answer_visualization,
|
answer_visualization,
|
||||||
)
|
)
|
||||||
from hellocomputer.config import settings
|
|
||||||
|
|
||||||
|
from hellocomputer.tools import extract_sid
|
||||||
from hellocomputer.tools.db import SQLSubgraph
|
from hellocomputer.tools.db import SQLSubgraph
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
from hellocomputer.state import SidState
|
||||||
|
|
||||||
|
|
||||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
def route_intent(state: SidState) -> Literal["general", "query", "visualization"]:
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
return last_message.content
|
return last_message.content
|
||||||
|
|
||||||
|
|
||||||
workflow = StateGraph(MessagesState)
|
def get_sid(state: SidState) -> dict:
    """Resolve the session for the current run from the latest message.

    Takes the last entry of ``state["messages"]``, extracts the session id
    embedded in its text with ``extract_sid`` and returns a partial state
    update that stores the resulting ``SessionDB`` under ``"sid"``.

    Args:
        state (SidState): Current graph state; must hold at least one message.

    Returns:
        dict: ``{"sid": SessionDB}`` to be merged into the graph state.

    Raises:
        IndexError: If there are no messages, or the last message carries no
            session marker (propagated from ``extract_sid``).
    """
    last_message = state["messages"][-1]
    # extract_sid applies a regex, so it needs the message text; message
    # objects expose it as ``.content``, plain strings are used as they are.
    text = getattr(last_message, "content", last_message)
    sid = extract_sid(text)
    # The state is a TypedDict (a plain dict): attribute assignment on it
    # fails, so the new value is handed back as a state update instead.
    # NOTE(review): other call sites construct SessionDB with a settings
    # argument as well -- confirm the expected constructor signature.
    return {"sid": SessionDB(sid)}
|
||||||
|
|
||||||
|
|
||||||
|
sql_subgraph = SQLSubgraph()
|
||||||
|
|
||||||
|
workflow = StateGraph(SidState)
|
||||||
|
|
||||||
# Nodes
|
# Nodes
|
||||||
|
|
||||||
|
@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general)
|
||||||
workflow.add_node("answer_visualization", answer_visualization)
|
workflow.add_node("answer_visualization", answer_visualization)
|
||||||
|
|
||||||
# Edges
|
# Edges
|
||||||
|
|
||||||
workflow.add_edge(START, "intent")
|
workflow.add_edge(START, "intent")
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"intent",
|
"intent",
|
||||||
route_intent,
|
route_intent,
|
||||||
{
|
{
|
||||||
"general": "answer_general",
|
"general": "answer_general",
|
||||||
"query": "answer_query",
|
"query": sql_subgraph.start_node,
|
||||||
"visualization": "answer_visualization",
|
"visualization": "answer_visualization",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END)
|
||||||
|
|
||||||
# SQL Subgraph
|
# SQL Subgraph
|
||||||
|
|
||||||
workflow = SQLSubgraph().add_subgraph(
|
workflow = sql_subgraph.add_nodes_edges(
|
||||||
workflow=workflow, origin="intent", destination=END
|
workflow=workflow, origin="intent", destination=END
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,16 @@ from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.graph import app
|
from hellocomputer.graph import app
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["chat"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["chat"])
|
||||||
async def query(request: Request, sid: str = "", q: str = "") -> str:
|
async def query(request: Request, sid: str = "", q: str = "") -> str:
|
||||||
user = request.session.get("user") # noqa
|
user = request.session.get("user") # noqa
|
||||||
response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
|
content = f"{q}{os.linesep}******{sid}******"
|
||||||
|
response = await app.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=content)]},
|
||||||
|
)
|
||||||
return response["messages"][-1].content
|
return response["messages"][-1].content
|
||||||
|
|
8
src/hellocomputer/state.py
Normal file
8
src/hellocomputer/state.py
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
from typing import TypedDict, Annotated
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
||||||
|
|
||||||
|
class SidState(TypedDict):
    """State schema with the message history plus a session entry."""

    # Message list; the ``add_messages`` annotation from langgraph is attached
    # as metadata for this key (presumably the reducer that appends new
    # messages -- confirm against the langgraph documentation).
    messages: Annotated[list, add_messages]
    # NOTE(review): despite the short name this is annotated as a SessionDB
    # object, not a bare session-id string -- confirm with the code that
    # reads this key.
    sid: SessionDB
|
|
@ -1,45 +0,0 @@
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
|
||||||
from hellocomputer.db.sessions import SessionDB
|
|
||||||
|
|
||||||
|
|
||||||
class DuckdbQueryInput(BaseModel):
|
|
||||||
query: str = Field(description="Question to be translated to a SQL statement")
|
|
||||||
session_id: str = Field(description="Session ID necessary to fetch the data")
|
|
||||||
|
|
||||||
|
|
||||||
class DuckdbQueryTool(BaseTool):
|
|
||||||
name: str = "sql_query"
|
|
||||||
description: str = "Run a SQL query in the database containing all the datasets "
|
|
||||||
"and provide a summary of the results, and the name of the table with them if the "
|
|
||||||
"volume of the results is large"
|
|
||||||
args_schema: Type[BaseModel] = DuckdbQueryInput
|
|
||||||
|
|
||||||
def _run(self, query: str, session_id: str) -> str:
|
|
||||||
"""Run the query"""
|
|
||||||
session = SessionDB(settings, session_id)
|
|
||||||
session.db.sql(query)
|
|
||||||
|
|
||||||
async def _arun(self, query: str, session_id: str) -> str:
|
|
||||||
"""Use the tool asynchronously."""
|
|
||||||
session = SessionDB(settings, session_id)
|
|
||||||
session.db.sql(query)
|
|
||||||
return "Table"
|
|
||||||
|
|
||||||
|
|
||||||
class PlotHistogramInput(BaseModel):
|
|
||||||
column_name: str = Field(description="Name of the column containing the values")
|
|
||||||
table_name: str = Field(description="Name of the table that contains the data")
|
|
||||||
num_bins: int = Field(description="Number of bins of the histogram")
|
|
||||||
|
|
||||||
|
|
||||||
class PlotHistogramTool(BaseTool):
|
|
||||||
name: str = "plot_histogram"
|
|
||||||
description: str = """
|
|
||||||
Generate a histogram plot given a name of an existing table of the database,
|
|
||||||
and a name of a column in the table. The default number of bins is 10, but
|
|
||||||
you can forward the number of bins if you are requested to"""
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def extract_sid(message: str) -> str:
    """Return the session id embedded in a message.

    The id is the text enclosed between two runs of six asterisks, e.g.
    ``******my-session******``. When several markers are present the first
    one wins.

    Args:
        message (str): Message text containing the session marker.

    Returns:
        str: The session id found between the asterisk runs.

    Raises:
        IndexError: If the message contains no session marker.
    """
    sid_marker = re.compile(r"\*{6}(.*?)\*{6}")
    found = sid_marker.findall(message)
    # Indexing (rather than unpacking) keeps the IndexError on "no marker".
    return found[0]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_sid(message: str) -> str:
    """Return the message without its final line.

    The final line is assumed to be the ``******sid******`` marker line; it
    is dropped unconditionally, whatever it contains. The remaining lines
    are joined back with ``os.linesep``.

    Args:
        message (str): Message text whose last line should be discarded.

    Returns:
        str: All lines but the last, joined with the platform line
            separator (empty string when there is at most one line).
    """
    lines = message.splitlines()
    kept = lines[:-1]
    return os.linesep.join(kept)
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryInput(BaseModel):
    # Argument schema for the "sql_query" tool (used as its ``args_schema``).
    # The ``description`` strings are part of the schema, not mere comments.

    # Text of the request to run against the session data.
    query: str = Field(description="Question to be translated to a SQL statement")
    # Identifier used to open the matching session database.
    session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryTool(BaseTool):
    """Tool that runs a SQL query against the database of a session."""

    name: str = "sql_query"
    # The parentheses are required: without them only the first literal is
    # assigned and the following ones are no-op expression statements, which
    # silently truncates the description.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query.

        Args:
            query (str): Statement handed to the session database.
            session_id (str): Session whose database is queried.

        Returns:
            str: The fixed marker ``"Table"``, same as the async variant.
        """
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # NOTE(review): the query result itself is discarded -- confirm
        # whether a summary of it should be returned instead.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously.

        Delegates to the synchronous implementation so both entry points
        stay in agreement.
        """
        return self._run(query, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
class PlotHistogramInput(BaseModel):
    # Input fields for a histogram plot; presumably the argument schema of
    # the "plot_histogram" tool -- TODO confirm, no ``args_schema`` reference
    # is visible here.

    # Column holding the values to be binned.
    column_name: str = Field(description="Name of the column containing the values")
    # Table the column belongs to.
    table_name: str = Field(description="Name of the table that contains the data")
    # Required field: no default is declared here.
    num_bins: int = Field(description="Number of bins of the histogram")
|
||||||
|
|
||||||
|
|
||||||
|
class PlotHistogramTool(BaseTool):
    # Declares the name and description of the histogram tool.
    # NOTE(review): the visible definition has no ``_run`` / ``args_schema``;
    # confirm whether the implementation is still pending.
    name: str = "plot_histogram"
    # NOTE(review): the text mentions a default of 10 bins, but no default
    # for the number of bins is declared in the visible code -- confirm.
    description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""
|
|
@ -1,12 +1,13 @@
|
||||||
from typing import Type
|
from typing import Type, Literal
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langgraph.prebuilt import ToolNode
|
from langgraph.prebuilt import ToolNode
|
||||||
from langgraph.graph import MessagesState, StateGraph
|
from langgraph.graph import StateGraph
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.state import SidState
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
from hellocomputer.models import AvailableModels
|
from hellocomputer.models import AvailableModels
|
||||||
|
|
||||||
|
@ -44,8 +45,12 @@ class SQLSubgraph:
|
||||||
queries
|
queries
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def call_model(self, state: MessagesState):
|
@property
|
||||||
db = SessionDB(settings=settings)
|
def start_node(self):
|
||||||
|
return "sql_agent"
|
||||||
|
|
||||||
|
async def call_model(self, state: SidState):
|
||||||
|
db = SessionDB(settings=settings).set_session(state.sid)
|
||||||
sql_toolkit = db.sql_toolkit
|
sql_toolkit = db.sql_toolkit
|
||||||
|
|
||||||
agent_llm = ChatOpenAI(
|
agent_llm = ChatOpenAI(
|
||||||
|
@ -66,7 +71,7 @@ class SQLSubgraph:
|
||||||
sql_toolkit = db.sql_toolkit
|
sql_toolkit = db.sql_toolkit
|
||||||
return ToolNode(sql_toolkit.get_tools())
|
return ToolNode(sql_toolkit.get_tools())
|
||||||
|
|
||||||
def add_subgraph(
|
def add_nodes_edges(
|
||||||
self, workflow: StateGraph, origin: str, destination: str
|
self, workflow: StateGraph, origin: str, destination: str
|
||||||
) -> StateGraph:
|
) -> StateGraph:
|
||||||
"""Creates the nodes and edges of the subgraph given a workflow
|
"""Creates the nodes and edges of the subgraph given a workflow
|
||||||
|
@ -80,7 +85,7 @@ class SQLSubgraph:
|
||||||
StateGraph: Resulting workflow
|
StateGraph: Resulting workflow
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def should_continue(state: MessagesState):
|
def should_continue(state: SidState):
|
||||||
messages = state["messages"]
|
messages = state["messages"]
|
||||||
last_message = messages[-1]
|
last_message = messages[-1]
|
||||||
if last_message.tool_calls:
|
if last_message.tool_calls:
|
||||||
|
@ -90,7 +95,7 @@ class SQLSubgraph:
|
||||||
workflow.add_node("sql_agent", self.call_model)
|
workflow.add_node("sql_agent", self.call_model)
|
||||||
workflow.add_node("sql_tool_node", self.query_tool_node)
|
workflow.add_node("sql_tool_node", self.query_tool_node)
|
||||||
workflow.add_edge(origin, "sql_agent")
|
workflow.add_edge(origin, "sql_agent")
|
||||||
workflow.add_conditional_edges("agent", should_continue)
|
workflow.add_conditional_edges("sql_agent", should_continue)
|
||||||
workflow.add_edge("sql_agent", "sql_tool_node")
|
workflow.add_edge("sql_agent", "sql_tool_node")
|
||||||
|
|
||||||
return workflow
|
return workflow
|
||||||
|
|
13
test/test_tools.py
Normal file
13
test/test_tools.py
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
from hellocomputer.tools import extract_sid, remove_sid
|
||||||
|
|
||||||
|
message = """This is a message
|
||||||
|
******sid******
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_sid():
    """extract_sid returns the text between the six-asterisk markers."""
    extracted = extract_sid(message)
    assert extracted == "sid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_sid():
    """remove_sid drops the trailing marker line and keeps the text."""
    remaining = remove_sid(message)
    assert remaining == "This is a message"
|
Loading…
Reference in a new issue