Got to correctly import
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-08-20 08:30:26 +02:00
parent eb72885f5b
commit e040b2e728
8 changed files with 134 additions and 62 deletions

View file

@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from typing_extensions import Self
from hellocomputer.config import Settings, settings, StorageEngines
from hellocomputer.config import settings, StorageEngines
from hellocomputer.models import AvailableModels
from . import DDB
@ -172,7 +172,9 @@ class SessionDB(DDB):
@property
def llmsql(self):
return SQLDatabase(self.engine, ignore_tables=["metadata"])
return SQLDatabase(
self.engine
) ## Cannot ignore tables because it creates a connection at import time
@property
def sql_toolkit(self) -> SQLDatabaseToolkit:

View file

@ -1,24 +1,35 @@
from typing import Literal
from langgraph.graph import END, START, StateGraph
from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.nodes import (
intent,
answer_general,
answer_visualization,
)
from hellocomputer.config import settings
from hellocomputer.tools import extract_sid
from hellocomputer.tools.db import SQLSubgraph
from hellocomputer.db.sessions import SessionDB
from hellocomputer.state import SidState
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
def route_intent(state: SidState) -> Literal["general", "query", "visualization"]:
messages = state["messages"]
last_message = messages[-1]
return last_message.content
workflow = StateGraph(MessagesState)
def get_sid(state: SidState) -> Literal["ok"]:
print(state["messages"])
last_message = state["messages"][-1]
sid = extract_sid(last_message)
state.sid = SessionDB(sid)
sql_subgraph = SQLSubgraph()
workflow = StateGraph(SidState)
# Nodes
@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_visualization", answer_visualization)
# Edges
workflow.add_edge(START, "intent")
workflow.add_conditional_edges(
"intent",
route_intent,
{
"general": "answer_general",
"query": "answer_query",
"query": sql_subgraph.start_node,
"visualization": "answer_visualization",
},
)
@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END)
# SQL Subgraph
workflow = SQLSubgraph().add_subgraph(
workflow = sql_subgraph.add_nodes_edges(
workflow=workflow, origin="intent", destination=END
)

View file

@ -5,11 +5,16 @@ from starlette.requests import Request
from hellocomputer.graph import app
import os
router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["chat"])
async def query(request: Request, sid: str = "", q: str = "") -> str:
user = request.session.get("user") # noqa
response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
content = f"{q}{os.linesep}******{sid}******"
response = await app.ainvoke(
{"messages": [HumanMessage(content=content)]},
)
return response["messages"][-1].content

View file

@ -0,0 +1,8 @@
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from hellocomputer.db.sessions import SessionDB
class SidState(TypedDict):
messages: Annotated[list, add_messages]
sid: SessionDB

View file

@ -1,45 +0,0 @@
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
class DuckdbQueryInput(BaseModel):
query: str = Field(description="Question to be translated to a SQL statement")
session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
name: str = "sql_query"
description: str = "Run a SQL query in the database containing all the datasets "
"and provide a summary of the results, and the name of the table with them if the "
"volume of the results is large"
args_schema: Type[BaseModel] = DuckdbQueryInput
def _run(self, query: str, session_id: str) -> str:
"""Run the query"""
session = SessionDB(settings, session_id)
session.db.sql(query)
async def _arun(self, query: str, session_id: str) -> str:
"""Use the tool asynchronously."""
session = SessionDB(settings, session_id)
session.db.sql(query)
return "Table"
class PlotHistogramInput(BaseModel):
column_name: str = Field(description="Name of the column containing the values")
table_name: str = Field(description="Name of the table that contains the data")
num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
name: str = "plot_histogram"
description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""

View file

@ -0,0 +1,74 @@
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
import re
import os
def extract_sid(message: str) -> str:
"""Extract the session id from the initial message
Args:
message (str): Initial message to the entrypoint
Returns:
str: session id
"""
pattern = r"\*{6}(.*?)\*{6}"
matches = re.findall(pattern, message)
return matches[0]
def remove_sid(message: str) -> str:
"""_summary_
Args:
message (str): _description_
Returns:
str: _description_
"""
return os.linesep.join(message.splitlines()[:-1])
class DuckdbQueryInput(BaseModel):
query: str = Field(description="Question to be translated to a SQL statement")
session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
name: str = "sql_query"
description: str = "Run a SQL query in the database containing all the datasets "
"and provide a summary of the results, and the name of the table with them if the "
"volume of the results is large"
args_schema: Type[BaseModel] = DuckdbQueryInput
def _run(self, query: str, session_id: str) -> str:
"""Run the query"""
session = SessionDB(settings, session_id)
session.db.sql(query)
async def _arun(self, query: str, session_id: str) -> str:
"""Use the tool asynchronously."""
session = SessionDB(settings, session_id)
session.db.sql(query)
return "Table"
class PlotHistogramInput(BaseModel):
column_name: str = Field(description="Name of the column containing the values")
table_name: str = Field(description="Name of the table that contains the data")
num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
name: str = "plot_histogram"
description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""

View file

@ -1,12 +1,13 @@
from typing import Type
from typing import Type, Literal
from langchain.tools import BaseTool
from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph import StateGraph
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.state import SidState
from hellocomputer.db.sessions import SessionDB
from hellocomputer.models import AvailableModels
@ -44,8 +45,12 @@ class SQLSubgraph:
queries
"""
async def call_model(self, state: MessagesState):
db = SessionDB(settings=settings)
@property
def start_node(self):
return "sql_agent"
async def call_model(self, state: SidState):
db = SessionDB(settings=settings).set_session(state.sid)
sql_toolkit = db.sql_toolkit
agent_llm = ChatOpenAI(
@ -66,7 +71,7 @@ class SQLSubgraph:
sql_toolkit = db.sql_toolkit
return ToolNode(sql_toolkit.get_tools())
def add_subgraph(
def add_nodes_edges(
self, workflow: StateGraph, origin: str, destination: str
) -> StateGraph:
"""Creates the nodes and edges of the subgraph given a workflow
@ -80,7 +85,7 @@ class SQLSubgraph:
StateGraph: Resulting workflow
"""
def should_continue(state: MessagesState):
def should_continue(state: SidState):
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
@ -90,7 +95,7 @@ class SQLSubgraph:
workflow.add_node("sql_agent", self.call_model)
workflow.add_node("sql_tool_node", self.query_tool_node)
workflow.add_edge(origin, "sql_agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_conditional_edges("sql_agent", should_continue)
workflow.add_edge("sql_agent", "sql_tool_node")
return workflow

13
test/test_tools.py Normal file
View file

@ -0,0 +1,13 @@
from hellocomputer.tools import extract_sid, remove_sid
message = """This is a message
******sid******
"""
def test_match_sid():
assert extract_sid(message) == "sid"
def test_remove_sid():
assert remove_sid(message) == "This is a message"