Nothing works because I need to find a way to pass configuration to the graph
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-08-07 11:29:29 +02:00
parent f16bb6b8cf
commit eb72885f5b
6 changed files with 148 additions and 17 deletions

View file

@ -3,16 +3,18 @@ from pathlib import Path
import duckdb
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
from typing_extensions import Self
from hellocomputer.config import Settings, StorageEngines
from hellocomputer.config import Settings, settings, StorageEngines
from hellocomputer.models import AvailableModels
from . import DDB
class SessionDB(DDB):
def __init__(self, settings: Settings, sid: str):
super().__init__(settings=settings)
def set_session(self, sid):
self.sid = sid
# Override storage engine for sessions
if settings.storage_engine == StorageEngines.gcs:
@ -171,3 +173,13 @@ class SessionDB(DDB):
@property
def llmsql(self):
    """SQLDatabase view of this session's engine for LLM toolkits.

    The internal ``metadata`` table is hidden so the model never
    queries or describes it.
    """
    hidden_tables = ["metadata"]
    return SQLDatabase(self.engine, ignore_tables=hidden_tables)
@property
def sql_toolkit(self) -> SQLDatabaseToolkit:
    """Build a SQLDatabaseToolkit over this session's database.

    NOTE(review): this reads the module-level ``settings`` rather than
    the configuration passed to ``__init__`` — confirm that is intended
    (the commit message suggests config plumbing is still unresolved).
    """
    chat_model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_medium,
        temperature=0.3,
    )
    return SQLDatabaseToolkit(db=self.llmsql, llm=chat_model)

View file

@ -5,9 +5,11 @@ from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.nodes import (
intent,
answer_general,
answer_query,
answer_visualization,
)
from hellocomputer.config import settings
from hellocomputer.tools.db import SQLSubgraph
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
@ -22,7 +24,6 @@ workflow = StateGraph(MessagesState)
workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization)
# Edges
@ -38,7 +39,12 @@ workflow.add_conditional_edges(
},
)
workflow.add_edge("answer_general", END)
workflow.add_edge("answer_query", END)
workflow.add_edge("answer_visualization", END)
# SQL Subgraph
workflow = SQLSubgraph().add_subgraph(
workflow=workflow, origin="intent", destination=END
)
app = workflow.compile()

View file

@ -1,6 +1,7 @@
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
@ -35,17 +36,17 @@ async def answer_general(state: MessagesState):
return {"messages": [await chain.ainvoke({})]}
async def answer_query(state: MessagesState):
    """Produce a SQL-flavoured answer using the small Llama model.

    ``state`` is currently unused: the prompt is invoked with an empty
    variable mapping.
    """
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    sql_prompt = await Prompts.sql()
    pipeline = sql_prompt | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
# async def answer_query(state: MessagesState):
# llm = ChatOpenAI(
# base_url=settings.llm_base_url,
# api_key=settings.llm_api_key,
# model=AvailableModels.llama_small,
# temperature=0,
# )
# prompt = await Prompts.sql()
# chain = prompt | llm
#
# return {"messages": [await chain.ainvoke({})]}
async def answer_visualization(state: MessagesState):

View file

View file

@ -0,0 +1,96 @@
from typing import Type
from langchain.tools import BaseTool
from langgraph.prebuilt import ToolNode
from langgraph.graph import MessagesState, StateGraph
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
from hellocomputer.models import AvailableModels
## This in case I need to create more ReAct agents
class DuckdbQueryInput(BaseModel):
    """Arguments accepted by the DuckDB SQL query tool."""

    # Natural-language question the agent translates into SQL.
    query: str = Field(description="Question to be translated to a SQL statement")
    # Selects which session's database the query runs against.
    session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
    """LangChain tool that runs a SQL query in a session's DuckDB database."""

    name: str = "sql_query"
    # Parenthesized so the three fragments join into ONE literal. Without the
    # parentheses the second and third lines were separate (discarded) string
    # statements, silently truncating the description the LLM sees.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Execute *query* synchronously against the session's database."""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Bug fix: the signature promises a str but nothing was returned;
        # mirror the async variant's placeholder result.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Execute *query* asynchronously (same blocking call under the hood)."""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        return "Table"
class SQLSubgraph:
    """
    Creates the question-answering agent that generates and runs SQL
    queries
    """

    async def call_model(self, state: MessagesState):
        """Agent node: let the LLM answer or emit SQL tool calls.

        Args:
            state (MessagesState): Current conversation state.

        Returns:
            dict: ``{"messages": [response]}`` to append the model reply.
        """
        # NOTE(review): SessionDB's visible __init__ takes (settings, sid);
        # no session id is passed here — confirm how the session is resolved.
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit
        agent_llm = ChatOpenAI(
            base_url=settings.llm_base_url,
            api_key=settings.llm_api_key,
            model=AvailableModels.firefunction_2,
            temperature=0.5,
            max_tokens=256,
        ).bind_tools(sql_toolkit.get_tools())
        messages = state["messages"]
        # Bug fix: ainvoke is a coroutine and must be awaited — the original
        # put the coroutine object itself into the message list.
        response = await agent_llm.ainvoke(messages)
        return {"messages": [response]}

    @property
    def query_tool_node(self) -> ToolNode:
        """ToolNode wrapping the SQL toolkit's tools for this session DB."""
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit
        return ToolNode(sql_toolkit.get_tools())

    def add_subgraph(
        self, workflow: StateGraph, origin: str, destination: str
    ) -> StateGraph:
        """Creates the nodes and edges of the subgraph given a workflow

        Args:
            workflow (StateGraph): Workflow that will get nodes and edges added
            origin (str): Origin node
            destination (str): Destination node

        Returns:
            StateGraph: Resulting workflow
        """

        def should_continue(state: MessagesState):
            messages = state["messages"]
            last_message = messages[-1]
            # Route tool calls to the tool node; otherwise the agent has a
            # final answer, so hand off to the destination node.
            if last_message.tool_calls:
                return "sql_tool_node"
            return destination

        workflow.add_node("sql_agent", self.call_model)
        workflow.add_node("sql_tool_node", self.query_tool_node)
        workflow.add_edge(origin, "sql_agent")
        # Bug fix: the conditional edges were attached to a non-existent node
        # named "agent" (the node is "sql_agent"), the router sent tool-call
        # turns to *destination*, and an unconditional sql_agent ->
        # sql_tool_node edge forced every turn through the tool node.
        # Rewired to the standard ReAct loop: agent -> (tool | destination),
        # tool -> agent.
        workflow.add_conditional_edges("sql_agent", should_continue)
        workflow.add_edge("sql_tool_node", "sql_agent")
        return workflow

View file

@ -0,0 +1,16 @@
from langchain.tools import BaseTool
from pydantic import BaseModel, Field
class PlotHistogramInput(BaseModel):
    """Arguments accepted by the histogram-plotting tool."""

    # Column whose values are binned into the histogram.
    column_name: str = Field(description="Name of the column containing the values")
    # Table in the session database that holds the column.
    table_name: str = Field(description="Name of the table that contains the data")
    # Number of histogram bins requested by the caller.
    num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
name: str = "plot_histogram"
description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""