This commit is contained in:
		
							parent
							
								
									eb72885f5b
								
							
						
					
					
						commit
						e040b2e728
					
				| 
						 | 
				
			
			@ -7,7 +7,7 @@ from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
 | 
			
		|||
from langchain_openai import ChatOpenAI
 | 
			
		||||
from typing_extensions import Self
 | 
			
		||||
 | 
			
		||||
from hellocomputer.config import Settings, settings, StorageEngines
 | 
			
		||||
from hellocomputer.config import settings, StorageEngines
 | 
			
		||||
from hellocomputer.models import AvailableModels
 | 
			
		||||
 | 
			
		||||
from . import DDB
 | 
			
		||||
| 
						 | 
				
			
			@ -172,7 +172,9 @@ class SessionDB(DDB):
 | 
			
		|||
 | 
			
		||||
    @property
 | 
			
		||||
    def llmsql(self):
 | 
			
		||||
        return SQLDatabase(self.engine, ignore_tables=["metadata"])
 | 
			
		||||
        return SQLDatabase(
 | 
			
		||||
            self.engine
 | 
			
		||||
        )  ## Cannot ignore tables because it creates a connection at import time
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def sql_toolkit(self) -> SQLDatabaseToolkit:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,24 +1,35 @@
 | 
			
		|||
from typing import Literal
 | 
			
		||||
from langgraph.graph import END, START, StateGraph
 | 
			
		||||
 | 
			
		||||
from langgraph.graph import END, START, MessagesState, StateGraph
 | 
			
		||||
 | 
			
		||||
from hellocomputer.nodes import (
 | 
			
		||||
    intent,
 | 
			
		||||
    answer_general,
 | 
			
		||||
    answer_visualization,
 | 
			
		||||
)
 | 
			
		||||
from hellocomputer.config import settings
 | 
			
		||||
 | 
			
		||||
from hellocomputer.tools import extract_sid
 | 
			
		||||
from hellocomputer.tools.db import SQLSubgraph
 | 
			
		||||
from hellocomputer.db.sessions import SessionDB
 | 
			
		||||
from hellocomputer.state import SidState
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
 | 
			
		||||
def route_intent(state: SidState) -> Literal["general", "query", "visualization"]:
 | 
			
		||||
    messages = state["messages"]
 | 
			
		||||
    last_message = messages[-1]
 | 
			
		||||
    return last_message.content
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
workflow = StateGraph(MessagesState)
 | 
			
		||||
def get_sid(state: SidState) -> Literal["ok"]:
    """Resolve the session of the conversation and store it in the state.

    Reads the session id marker from the last message and stores a
    ``SessionDB`` built from it under the ``sid`` key of the state.

    Args:
        state (SidState): Graph state; its last message must carry the
            session id marker understood by ``extract_sid``

    Returns:
        Literal["ok"]: Always ``"ok"``
    """
    print(state["messages"])
    last_message = state["messages"][-1]
    # extract_sid works on text, so pass the message content rather than
    # the message object itself.
    sid = extract_sid(last_message.content)
    # SidState is a TypedDict, i.e. a plain dict at runtime: keys must be set
    # by subscription, attribute assignment raises AttributeError.
    # NOTE(review): other call sites build the session with
    # SessionDB(settings=settings).set_session(...) -- confirm the constructor
    # arguments expected here.
    state["sid"] = SessionDB(sid)
    return "ok"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
sql_subgraph = SQLSubgraph()
 | 
			
		||||
 | 
			
		||||
workflow = StateGraph(SidState)
 | 
			
		||||
 | 
			
		||||
# Nodes
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -27,14 +38,13 @@ workflow.add_node("answer_general", answer_general)
 | 
			
		|||
workflow.add_node("answer_visualization", answer_visualization)
 | 
			
		||||
 | 
			
		||||
# Edges
 | 
			
		||||
 | 
			
		||||
workflow.add_edge(START, "intent")
 | 
			
		||||
workflow.add_conditional_edges(
 | 
			
		||||
    "intent",
 | 
			
		||||
    route_intent,
 | 
			
		||||
    {
 | 
			
		||||
        "general": "answer_general",
 | 
			
		||||
        "query": "answer_query",
 | 
			
		||||
        "query": sql_subgraph.start_node,
 | 
			
		||||
        "visualization": "answer_visualization",
 | 
			
		||||
    },
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -43,7 +53,7 @@ workflow.add_edge("answer_visualization", END)
 | 
			
		|||
 | 
			
		||||
# SQL Subgraph
 | 
			
		||||
 | 
			
		||||
workflow = SQLSubgraph().add_subgraph(
 | 
			
		||||
workflow = sql_subgraph.add_nodes_edges(
 | 
			
		||||
    workflow=workflow, origin="intent", destination=END
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,11 +5,16 @@ from starlette.requests import Request
 | 
			
		|||
 | 
			
		||||
from hellocomputer.graph import app
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
router = APIRouter()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.get("/query", response_class=PlainTextResponse, tags=["chat"])
 | 
			
		||||
async def query(request: Request, sid: str = "", q: str = "") -> str:
 | 
			
		||||
    user = request.session.get("user")  # noqa
 | 
			
		||||
    response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
 | 
			
		||||
    content = f"{q}{os.linesep}******{sid}******"
 | 
			
		||||
    response = await app.ainvoke(
 | 
			
		||||
        {"messages": [HumanMessage(content=content)]},
 | 
			
		||||
    )
 | 
			
		||||
    return response["messages"][-1].content
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										8
									
								
								src/hellocomputer/state.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								src/hellocomputer/state.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
from typing import TypedDict, Annotated
 | 
			
		||||
from langgraph.graph.message import add_messages
 | 
			
		||||
from hellocomputer.db.sessions import SessionDB
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SidState(TypedDict):
 | 
			
		||||
    messages: Annotated[list, add_messages]
 | 
			
		||||
    sid: SessionDB
 | 
			
		||||
| 
						 | 
				
			
			@ -1,45 +0,0 @@
 | 
			
		|||
from typing import Type
 | 
			
		||||
 | 
			
		||||
from langchain.tools import BaseTool
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from hellocomputer.config import settings
 | 
			
		||||
from hellocomputer.db.sessions import SessionDB
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DuckdbQueryInput(BaseModel):
 | 
			
		||||
    query: str = Field(description="Question to be translated to a SQL statement")
 | 
			
		||||
    session_id: str = Field(description="Session ID necessary to fetch the data")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DuckdbQueryTool(BaseTool):
 | 
			
		||||
    name: str = "sql_query"
 | 
			
		||||
    description: str = "Run a SQL query in the database containing all the datasets "
 | 
			
		||||
    "and provide a summary of the results, and the name of the table with them if the "
 | 
			
		||||
    "volume of the results is large"
 | 
			
		||||
    args_schema: Type[BaseModel] = DuckdbQueryInput
 | 
			
		||||
 | 
			
		||||
    def _run(self, query: str, session_id: str) -> str:
 | 
			
		||||
        """Run the query"""
 | 
			
		||||
        session = SessionDB(settings, session_id)
 | 
			
		||||
        session.db.sql(query)
 | 
			
		||||
 | 
			
		||||
    async def _arun(self, query: str, session_id: str) -> str:
 | 
			
		||||
        """Use the tool asynchronously."""
 | 
			
		||||
        session = SessionDB(settings, session_id)
 | 
			
		||||
        session.db.sql(query)
 | 
			
		||||
        return "Table"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotHistogramInput(BaseModel):
 | 
			
		||||
    column_name: str = Field(description="Name of the column containing the values")
 | 
			
		||||
    table_name: str = Field(description="Name of the table that contains the data")
 | 
			
		||||
    num_bins: int = Field(description="Number of bins of the histogram")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotHistogramTool(BaseTool):
 | 
			
		||||
    name: str = "plot_histogram"
 | 
			
		||||
    description: str = """
 | 
			
		||||
    Generate a histogram plot given a name of an existing table of the database, 
 | 
			
		||||
    and a name of a column in the table. The default number of bins is 10, but 
 | 
			
		||||
    you can forward the number of bins if you are requested to"""
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,74 @@
 | 
			
		|||
from typing import Type
 | 
			
		||||
 | 
			
		||||
from langchain.tools import BaseTool
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from hellocomputer.config import settings
 | 
			
		||||
from hellocomputer.db.sessions import SessionDB
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def extract_sid(message: str) -> str:
    """Extract the session id from the initial message.

    The session id is expected to be wrapped between two runs of six
    asterisks, e.g. ``******my-session******``. When several markers are
    present, the first one is returned.

    Args:
        message (str): Initial message to the entrypoint

    Returns:
        str: session id

    Raises:
        IndexError: If the message contains no session id marker
    """
    match = re.search(r"\*{6}(.*?)\*{6}", message)
    if match is None:
        # Same exception type the bare ``matches[0]`` lookup used to raise,
        # but with a message that says what is actually missing.
        raise IndexError(
            "No session id marker (******<sid>******) found in message"
        )
    return match.group(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def remove_sid(message: str) -> str:
    """Strip the trailing session id line from a message.

    The last line of the message is dropped unconditionally (it is not
    checked for the session id marker) and the remaining lines are joined
    back together with ``os.linesep``.

    Args:
        message (str): Message whose last line carries the session id

    Returns:
        str: The message without its last line
    """
    lines = message.splitlines()
    # Slice deletion is a no-op on an empty list, so an empty message
    # yields an empty string.
    del lines[-1:]
    return os.linesep.join(lines)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DuckdbQueryInput(BaseModel):
 | 
			
		||||
    query: str = Field(description="Question to be translated to a SQL statement")
 | 
			
		||||
    session_id: str = Field(description="Session ID necessary to fetch the data")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DuckdbQueryTool(BaseTool):
    """Tool that runs a SQL query against the database of a session."""

    name: str = "sql_query"
    # Parenthesised so the three fragments are concatenated into a single
    # string; without the parentheses only the first fragment was assigned
    # and the other two were no-op expression statements.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query"""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Return the same placeholder as the async variant instead of
        # falling through with None despite the ``-> str`` annotation.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously."""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        return "Table"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotHistogramInput(BaseModel):
 | 
			
		||||
    column_name: str = Field(description="Name of the column containing the values")
 | 
			
		||||
    table_name: str = Field(description="Name of the table that contains the data")
 | 
			
		||||
    num_bins: int = Field(description="Number of bins of the histogram")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotHistogramTool(BaseTool):
 | 
			
		||||
    name: str = "plot_histogram"
 | 
			
		||||
    description: str = """
 | 
			
		||||
    Generate a histogram plot given a name of an existing table of the database, 
 | 
			
		||||
    and a name of a column in the table. The default number of bins is 10, but 
 | 
			
		||||
    you can forward the number of bins if you are requested to"""
 | 
			
		||||
| 
						 | 
				
			
			@ -1,12 +1,13 @@
 | 
			
		|||
from typing import Type
 | 
			
		||||
from typing import Type, Literal
 | 
			
		||||
 | 
			
		||||
from langchain.tools import BaseTool
 | 
			
		||||
from langgraph.prebuilt import ToolNode
 | 
			
		||||
from langgraph.graph import MessagesState, StateGraph
 | 
			
		||||
from langgraph.graph import StateGraph
 | 
			
		||||
from langchain_openai import ChatOpenAI
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
from hellocomputer.config import settings
 | 
			
		||||
from hellocomputer.state import SidState
 | 
			
		||||
from hellocomputer.db.sessions import SessionDB
 | 
			
		||||
from hellocomputer.models import AvailableModels
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -44,8 +45,12 @@ class SQLSubgraph:
 | 
			
		|||
    queries
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    async def call_model(self, state: MessagesState):
 | 
			
		||||
        db = SessionDB(settings=settings)
 | 
			
		||||
    @property
 | 
			
		||||
    def start_node(self):
 | 
			
		||||
        return "sql_agent"
 | 
			
		||||
 | 
			
		||||
    async def call_model(self, state: SidState):
 | 
			
		||||
        db = SessionDB(settings=settings).set_session(state.sid)
 | 
			
		||||
        sql_toolkit = db.sql_toolkit
 | 
			
		||||
 | 
			
		||||
        agent_llm = ChatOpenAI(
 | 
			
		||||
| 
						 | 
				
			
			@ -66,7 +71,7 @@ class SQLSubgraph:
 | 
			
		|||
        sql_toolkit = db.sql_toolkit
 | 
			
		||||
        return ToolNode(sql_toolkit.get_tools())
 | 
			
		||||
 | 
			
		||||
    def add_subgraph(
 | 
			
		||||
    def add_nodes_edges(
 | 
			
		||||
        self, workflow: StateGraph, origin: str, destination: str
 | 
			
		||||
    ) -> StateGraph:
 | 
			
		||||
        """Creates the nodes and edges of the subgraph given a workflow
 | 
			
		||||
| 
						 | 
				
			
			@ -80,7 +85,7 @@ class SQLSubgraph:
 | 
			
		|||
            StateGraph: Resulting workflow
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def should_continue(state: MessagesState):
 | 
			
		||||
        def should_continue(state: SidState):
 | 
			
		||||
            messages = state["messages"]
 | 
			
		||||
            last_message = messages[-1]
 | 
			
		||||
            if last_message.tool_calls:
 | 
			
		||||
| 
						 | 
				
			
			@ -90,7 +95,7 @@ class SQLSubgraph:
 | 
			
		|||
        workflow.add_node("sql_agent", self.call_model)
 | 
			
		||||
        workflow.add_node("sql_tool_node", self.query_tool_node)
 | 
			
		||||
        workflow.add_edge(origin, "sql_agent")
 | 
			
		||||
        workflow.add_conditional_edges("agent", should_continue)
 | 
			
		||||
        workflow.add_conditional_edges("sql_agent", should_continue)
 | 
			
		||||
        workflow.add_edge("sql_agent", "sql_tool_node")
 | 
			
		||||
 | 
			
		||||
        return workflow
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										13
									
								
								test/test_tools.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								test/test_tools.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,13 @@
 | 
			
		|||
from hellocomputer.tools import extract_sid, remove_sid
 | 
			
		||||
 | 
			
		||||
message = """This is a message
 | 
			
		||||
******sid******
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_match_sid():
 | 
			
		||||
    assert extract_sid(message) == "sid"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_remove_sid():
 | 
			
		||||
    assert remove_sid(message) == "This is a message"
 | 
			
		||||
		Loading…
	
		Reference in a new issue