Got to correctly import
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-08-20 08:30:26 +02:00
parent eb72885f5b
commit e040b2e728
8 changed files with 134 additions and 62 deletions

View file

@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from typing_extensions import Self from typing_extensions import Self
from hellocomputer.config import Settings, settings, StorageEngines from hellocomputer.config import settings, StorageEngines
from hellocomputer.models import AvailableModels from hellocomputer.models import AvailableModels
from . import DDB from . import DDB
@ -172,7 +172,9 @@ class SessionDB(DDB):
@property @property
def llmsql(self): def llmsql(self):
return SQLDatabase(self.engine, ignore_tables=["metadata"]) return SQLDatabase(
self.engine
) ## Cannot ignore tables because it creates a connection at import time
@property @property
def sql_toolkit(self) -> SQLDatabaseToolkit: def sql_toolkit(self) -> SQLDatabaseToolkit:

View file

@ -1,24 +1,35 @@
from typing import Literal from typing import Literal
from langgraph.graph import END, START, StateGraph
from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.nodes import ( from hellocomputer.nodes import (
intent, intent,
answer_general, answer_general,
answer_visualization, answer_visualization,
) )
from hellocomputer.config import settings
from hellocomputer.tools import extract_sid
from hellocomputer.tools.db import SQLSubgraph from hellocomputer.tools.db import SQLSubgraph
from hellocomputer.db.sessions import SessionDB
from hellocomputer.state import SidState
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]: def route_intent(state: SidState) -> Literal["general", "query", "visualization"]:
messages = state["messages"] messages = state["messages"]
last_message = messages[-1] last_message = messages[-1]
return last_message.content return last_message.content
workflow = StateGraph(MessagesState) def get_sid(state: SidState) -> Literal["ok"]:
print(state["messages"])
last_message = state["messages"][-1]
sid = extract_sid(last_message)
state.sid = SessionDB(sid)
sql_subgraph = SQLSubgraph()
workflow = StateGraph(SidState)
# Nodes # Nodes
@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_visualization", answer_visualization) workflow.add_node("answer_visualization", answer_visualization)
# Edges # Edges
workflow.add_edge(START, "intent") workflow.add_edge(START, "intent")
workflow.add_conditional_edges( workflow.add_conditional_edges(
"intent", "intent",
route_intent, route_intent,
{ {
"general": "answer_general", "general": "answer_general",
"query": "answer_query", "query": sql_subgraph.start_node,
"visualization": "answer_visualization", "visualization": "answer_visualization",
}, },
) )
@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END)
# SQL Subgraph # SQL Subgraph
workflow = SQLSubgraph().add_subgraph( workflow = sql_subgraph.add_nodes_edges(
workflow=workflow, origin="intent", destination=END workflow=workflow, origin="intent", destination=END
) )

View file

@ -5,11 +5,16 @@ from starlette.requests import Request
from hellocomputer.graph import app from hellocomputer.graph import app
import os
router = APIRouter() router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["chat"]) @router.get("/query", response_class=PlainTextResponse, tags=["chat"])
async def query(request: Request, sid: str = "", q: str = "") -> str: async def query(request: Request, sid: str = "", q: str = "") -> str:
user = request.session.get("user") # noqa user = request.session.get("user") # noqa
response = await app.ainvoke({"messages": [HumanMessage(content=q)]}) content = f"{q}{os.linesep}******{sid}******"
response = await app.ainvoke(
{"messages": [HumanMessage(content=content)]},
)
return response["messages"][-1].content return response["messages"][-1].content

View file

@ -0,0 +1,8 @@
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from hellocomputer.db.sessions import SessionDB
class SidState(TypedDict):
    """State schema with the conversation messages and a session entry.

    Declared as a TypedDict, so instances are plain dicts and the fields
    are read and written by key (``state["messages"]``, ``state["sid"]``).
    """

    # List of messages, annotated with langgraph's ``add_messages``.
    # Presumably langgraph uses the annotation as the merge function for this
    # key -- TODO confirm against the langgraph documentation.
    messages: Annotated[list, add_messages]
    # NOTE(review): named ``sid`` but typed as a SessionDB object rather than
    # a string id -- confirm which of the two is expected to be stored here.
    sid: SessionDB

View file

@ -1,45 +0,0 @@
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
class DuckdbQueryInput(BaseModel):
    # Argument schema of DuckdbQueryTool (it is assigned to that tool's
    # ``args_schema``). Documented with comments instead of a docstring so
    # the pydantic model itself stays untouched.
    # NOTE(review): the description calls ``query`` a question to be
    # translated to SQL, yet DuckdbQueryTool hands it verbatim to
    # ``session.db.sql`` -- confirm which of the two is intended.
    query: str = Field(description="Question to be translated to a SQL statement")
    # Identifier handed to SessionDB to select the session.
    session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
    # LangChain tool that runs a SQL statement against the database of one
    # session. Arguments are described by DuckdbQueryInput.

    name: str = "sql_query"
    # The fragments are wrapped in parentheses so that they are concatenated
    # into a single string. Without them the assignment ends on its first
    # line and the remaining literals are no-op expression statements, which
    # truncates the description.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query

        Args:
            query (str): Statement passed to ``session.db.sql``
            session_id (str): Session whose database is queried

        Returns:
            str: The same fixed marker that ``_arun`` returns
        """
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Previously fell through and returned None despite the ``-> str``
        # annotation; now consistent with the asynchronous variant.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously.

        Args:
            query (str): Statement passed to ``session.db.sql``
            session_id (str): Session whose database is queried

        Returns:
            str: The fixed marker ``"Table"``
        """
        session = SessionDB(settings, session_id)
        # NOTE(review): the result of the query is discarded here; only the
        # fixed marker below is reported back to the caller.
        session.db.sql(query)
        return "Table"
class PlotHistogramInput(BaseModel):
    # Pydantic model with the three inputs of a histogram request: which
    # column, from which table, and how many bins. Documented with comments
    # instead of a docstring so the pydantic model itself stays untouched.
    column_name: str = Field(description="Name of the column containing the values")
    table_name: str = Field(description="Name of the table that contains the data")
    # NOTE(review): no default is declared here although the description of
    # PlotHistogramTool mentions a default of 10 bins -- confirm where that
    # default is meant to be applied.
    num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
    # LangChain tool declaration for plotting a histogram of a table column.
    # Only the tool name and its description are declared in this block; no
    # args_schema and no _run/_arun implementation are visible here.
    # NOTE(review): presumably meant to use PlotHistogramInput as its
    # args_schema -- confirm.
    name: str = "plot_histogram"
    # The description is a runtime string and is reproduced exactly as it
    # appears in the source, including its leading newline.
    description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""

View file

@ -0,0 +1,74 @@
from typing import Type
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
import re
import os
def extract_sid(message: str) -> str:
    """Extract the session id from the initial message

    The session id is the text enclosed between two runs of six asterisks
    (``******<sid>******``). When the message holds several markers, the
    first one is returned.

    Args:
        message (str): Initial message to the entrypoint

    Returns:
        str: session id

    Raises:
        IndexError: If the message contains no session id marker
    """
    # Only the first marker is needed, so stop scanning at the first hit.
    match = re.search(r"\*{6}(.*?)\*{6}", message)
    if match is None:
        # Same exception type as the former ``matches[0]`` lookup, but with
        # a message that names the actual problem.
        raise IndexError(
            "no session id marker (******<sid>******) found in message"
        )
    return match.group(1)
def remove_sid(message: str) -> str:
    """Remove the trailing session id marker line from a message

    The last line is dropped only when it consists solely of a session id
    marker (``******<sid>******``, surrounding whitespace ignored). A
    message without such a trailing marker keeps all of its lines.

    Args:
        message (str): Message, optionally ending with a session id line

    Returns:
        str: The remaining lines joined with ``os.linesep``
    """
    lines = message.splitlines()
    # Check the last line before dropping it: slicing unconditionally would
    # silently discard real content from messages that carry no marker.
    if lines and re.fullmatch(r"\*{6}.*?\*{6}", lines[-1].strip()):
        lines = lines[:-1]
    return os.linesep.join(lines)
class DuckdbQueryInput(BaseModel):
    # Argument schema of DuckdbQueryTool (it is assigned to that tool's
    # ``args_schema``). Documented with comments instead of a docstring so
    # the pydantic model itself stays untouched.
    # NOTE(review): the description calls ``query`` a question to be
    # translated to SQL, yet DuckdbQueryTool hands it verbatim to
    # ``session.db.sql`` -- confirm which of the two is intended.
    query: str = Field(description="Question to be translated to a SQL statement")
    # Identifier handed to SessionDB to select the session.
    session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
    # LangChain tool that runs a SQL statement against the database of one
    # session. Arguments are described by DuckdbQueryInput.

    name: str = "sql_query"
    # The fragments are wrapped in parentheses so that they are concatenated
    # into a single string. Without them the assignment ends on its first
    # line and the remaining literals are no-op expression statements, which
    # truncates the description.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query

        Args:
            query (str): Statement passed to ``session.db.sql``
            session_id (str): Session whose database is queried

        Returns:
            str: The same fixed marker that ``_arun`` returns
        """
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Previously fell through and returned None despite the ``-> str``
        # annotation; now consistent with the asynchronous variant.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously.

        Args:
            query (str): Statement passed to ``session.db.sql``
            session_id (str): Session whose database is queried

        Returns:
            str: The fixed marker ``"Table"``
        """
        session = SessionDB(settings, session_id)
        # NOTE(review): the result of the query is discarded here; only the
        # fixed marker below is reported back to the caller.
        session.db.sql(query)
        return "Table"
class PlotHistogramInput(BaseModel):
    # Pydantic model with the three inputs of a histogram request: which
    # column, from which table, and how many bins. Documented with comments
    # instead of a docstring so the pydantic model itself stays untouched.
    column_name: str = Field(description="Name of the column containing the values")
    table_name: str = Field(description="Name of the table that contains the data")
    # NOTE(review): no default is declared here although the description of
    # PlotHistogramTool mentions a default of 10 bins -- confirm where that
    # default is meant to be applied.
    num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
    # LangChain tool declaration for plotting a histogram of a table column.
    # Only the tool name and its description are declared in this block; no
    # args_schema and no _run/_arun implementation are visible here.
    # NOTE(review): presumably meant to use PlotHistogramInput as its
    # args_schema -- confirm.
    name: str = "plot_histogram"
    # The description is a runtime string and is reproduced exactly as it
    # appears in the source, including its leading newline.
    description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""

View file

@ -1,12 +1,13 @@
from typing import Type from typing import Type, Literal
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langgraph.prebuilt import ToolNode from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState, StateGraph from langgraph.graph import StateGraph
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.state import SidState
from hellocomputer.db.sessions import SessionDB from hellocomputer.db.sessions import SessionDB
from hellocomputer.models import AvailableModels from hellocomputer.models import AvailableModels
@ -44,8 +45,12 @@ class SQLSubgraph:
queries queries
""" """
async def call_model(self, state: MessagesState): @property
db = SessionDB(settings=settings) def start_node(self):
return "sql_agent"
async def call_model(self, state: SidState):
db = SessionDB(settings=settings).set_session(state.sid)
sql_toolkit = db.sql_toolkit sql_toolkit = db.sql_toolkit
agent_llm = ChatOpenAI( agent_llm = ChatOpenAI(
@ -66,7 +71,7 @@ class SQLSubgraph:
sql_toolkit = db.sql_toolkit sql_toolkit = db.sql_toolkit
return ToolNode(sql_toolkit.get_tools()) return ToolNode(sql_toolkit.get_tools())
def add_subgraph( def add_nodes_edges(
self, workflow: StateGraph, origin: str, destination: str self, workflow: StateGraph, origin: str, destination: str
) -> StateGraph: ) -> StateGraph:
"""Creates the nodes and edges of the subgraph given a workflow """Creates the nodes and edges of the subgraph given a workflow
@ -80,7 +85,7 @@ class SQLSubgraph:
StateGraph: Resulting workflow StateGraph: Resulting workflow
""" """
def should_continue(state: MessagesState): def should_continue(state: SidState):
messages = state["messages"] messages = state["messages"]
last_message = messages[-1] last_message = messages[-1]
if last_message.tool_calls: if last_message.tool_calls:
@ -90,7 +95,7 @@ class SQLSubgraph:
workflow.add_node("sql_agent", self.call_model) workflow.add_node("sql_agent", self.call_model)
workflow.add_node("sql_tool_node", self.query_tool_node) workflow.add_node("sql_tool_node", self.query_tool_node)
workflow.add_edge(origin, "sql_agent") workflow.add_edge(origin, "sql_agent")
workflow.add_conditional_edges("agent", should_continue) workflow.add_conditional_edges("sql_agent", should_continue)
workflow.add_edge("sql_agent", "sql_tool_node") workflow.add_edge("sql_agent", "sql_tool_node")
return workflow return workflow

13
test/test_tools.py Normal file
View file

@ -0,0 +1,13 @@
from hellocomputer.tools import extract_sid, remove_sid
message = (
    # First line: the user text.
    "This is a message\n"
    # Second line: the session id wrapped in six asterisks on each side.
    "******sid******\n"
)
def test_match_sid():
    """extract_sid returns the text found between the asterisk markers."""
    found = extract_sid(message)
    assert found == "sid"
def test_remove_sid():
    """remove_sid leaves only the user text of the sample message."""
    remaining = remove_sid(message)
    assert remaining == "This is a message"