Nothing works because I need to find a way to pass configuration to the graph
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
f16bb6b8cf
commit
eb72885f5b
|
@ -3,16 +3,18 @@ from pathlib import Path
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
from langchain_community.utilities.sql_database import SQLDatabase
|
from langchain_community.utilities.sql_database import SQLDatabase
|
||||||
|
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
from hellocomputer.config import Settings, settings, StorageEngines
|
||||||
|
from hellocomputer.models import AvailableModels
|
||||||
|
|
||||||
from . import DDB
|
from . import DDB
|
||||||
|
|
||||||
|
|
||||||
class SessionDB(DDB):
|
class SessionDB(DDB):
|
||||||
def __init__(self, settings: Settings, sid: str):
|
def set_session(self, sid):
|
||||||
super().__init__(settings=settings)
|
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
# Override storage engine for sessions
|
# Override storage engine for sessions
|
||||||
if settings.storage_engine == StorageEngines.gcs:
|
if settings.storage_engine == StorageEngines.gcs:
|
||||||
|
@ -171,3 +173,13 @@ class SessionDB(DDB):
|
||||||
@property
|
@property
|
||||||
def llmsql(self):
|
def llmsql(self):
|
||||||
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
||||||
|
|
||||||
|
@property
def sql_toolkit(self) -> SQLDatabaseToolkit:
    """Return a SQL toolkit over this session's database.

    A fresh ``ChatOpenAI`` client is created on every access, configured
    from the module-level ``settings`` (base URL and API key) with the
    ``llama_medium`` model at temperature 0.3. The toolkit wraps
    ``self.llmsql`` as its database.
    """
    toolkit_llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_medium,
        temperature=0.3,
    )
    toolkit = SQLDatabaseToolkit(db=self.llmsql, llm=toolkit_llm)
    return toolkit
|
||||||
|
|
|
@ -5,9 +5,11 @@ from langgraph.graph import END, START, MessagesState, StateGraph
|
||||||
from hellocomputer.nodes import (
|
from hellocomputer.nodes import (
|
||||||
intent,
|
intent,
|
||||||
answer_general,
|
answer_general,
|
||||||
answer_query,
|
|
||||||
answer_visualization,
|
answer_visualization,
|
||||||
)
|
)
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
|
||||||
|
from hellocomputer.tools.db import SQLSubgraph
|
||||||
|
|
||||||
|
|
||||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||||
|
@ -22,7 +24,6 @@ workflow = StateGraph(MessagesState)
|
||||||
|
|
||||||
workflow.add_node("intent", intent)
|
workflow.add_node("intent", intent)
|
||||||
workflow.add_node("answer_general", answer_general)
|
workflow.add_node("answer_general", answer_general)
|
||||||
workflow.add_node("answer_query", answer_query)
|
|
||||||
workflow.add_node("answer_visualization", answer_visualization)
|
workflow.add_node("answer_visualization", answer_visualization)
|
||||||
|
|
||||||
# Edges
|
# Edges
|
||||||
|
@ -38,7 +39,12 @@ workflow.add_conditional_edges(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
workflow.add_edge("answer_general", END)
|
workflow.add_edge("answer_general", END)
|
||||||
workflow.add_edge("answer_query", END)
|
|
||||||
workflow.add_edge("answer_visualization", END)
|
workflow.add_edge("answer_visualization", END)
|
||||||
|
|
||||||
|
# SQL Subgraph
|
||||||
|
|
||||||
|
workflow = SQLSubgraph().add_subgraph(
|
||||||
|
workflow=workflow, origin="intent", destination=END
|
||||||
|
)
|
||||||
|
|
||||||
app = workflow.compile()
|
app = workflow.compile()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from langchain_openai import ChatOpenAI
|
from langchain_openai import ChatOpenAI
|
||||||
from langgraph.graph import MessagesState
|
from langgraph.graph import MessagesState
|
||||||
|
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.extraction import initial_intent_parser
|
from hellocomputer.extraction import initial_intent_parser
|
||||||
from hellocomputer.models import AvailableModels
|
from hellocomputer.models import AvailableModels
|
||||||
|
@ -35,17 +36,17 @@ async def answer_general(state: MessagesState):
|
||||||
return {"messages": [await chain.ainvoke({})]}
|
return {"messages": [await chain.ainvoke({})]}
|
||||||
|
|
||||||
|
|
||||||
async def answer_query(state: MessagesState):
|
# async def answer_query(state: MessagesState):
|
||||||
llm = ChatOpenAI(
|
# llm = ChatOpenAI(
|
||||||
base_url=settings.llm_base_url,
|
# base_url=settings.llm_base_url,
|
||||||
api_key=settings.llm_api_key,
|
# api_key=settings.llm_api_key,
|
||||||
model=AvailableModels.llama_small,
|
# model=AvailableModels.llama_small,
|
||||||
temperature=0,
|
# temperature=0,
|
||||||
)
|
# )
|
||||||
prompt = await Prompts.sql()
|
# prompt = await Prompts.sql()
|
||||||
chain = prompt | llm
|
# chain = prompt | llm
|
||||||
|
#
|
||||||
return {"messages": [await chain.ainvoke({})]}
|
# return {"messages": [await chain.ainvoke({})]}
|
||||||
|
|
||||||
|
|
||||||
async def answer_visualization(state: MessagesState):
|
async def answer_visualization(state: MessagesState):
|
||||||
|
|
0
src/hellocomputer/tools/__init__.py
Normal file
0
src/hellocomputer/tools/__init__.py
Normal file
96
src/hellocomputer/tools/db.py
Normal file
96
src/hellocomputer/tools/db.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from langgraph.prebuilt import ToolNode
|
||||||
|
from langgraph.graph import MessagesState, StateGraph
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
from hellocomputer.models import AvailableModels
|
||||||
|
|
||||||
|
|
||||||
|
# This is in case I need to create more ReAct agents
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryInput(BaseModel):
    """Arguments accepted by the SQL query tool.

    Both fields are required strings: the question to answer and the
    identifier of the session whose data should be queried.
    """

    query: str = Field(
        description="Question to be translated to a SQL statement",
    )
    session_id: str = Field(
        description="Session ID necessary to fetch the data",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryTool(BaseTool):
    """LangChain tool that executes a query against a session's database.

    The query result is currently discarded: both the synchronous and the
    asynchronous entry points only return the fixed marker string
    ``"Table"`` once the query has been executed.
    """

    name: str = "sql_query"
    # The parentheses are required. Without them only the first literal is
    # assigned and the following literals are no-op expression statements,
    # which silently truncates the description.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _execute(self, query: str, session_id: str) -> str:
        """Run ``query`` in the database of session ``session_id``.

        Shared by ``_run`` and ``_arun`` so both paths behave identically.
        """
        # NOTE(review): SessionDB no longer takes the session id in its
        # constructor; it is selected through set_session(). Confirm that
        # DDB.__init__ needs nothing besides ``settings``.
        session = SessionDB(settings=settings)
        session.set_session(session_id)
        session.db.sql(query)
        return "Table"

    def _run(self, query: str, session_id: str) -> str:
        """Run the query.

        Args:
            query: Statement passed to the session database.
            session_id: Identifier of the session whose data is queried.

        Returns:
            The marker string ``"Table"``.
        """
        return self._execute(query, session_id)

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously.

        The database call itself is still executed synchronously.

        Args:
            query: Statement passed to the session database.
            session_id: Identifier of the session whose data is queried.

        Returns:
            The marker string ``"Table"``.
        """
        return self._execute(query, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSubgraph:
    """
    Creates the question-answering agent that generates and runs SQL
    queries
    """

    async def call_model(self, state: MessagesState):
        """Ask the tool-enabled model for the next message.

        Args:
            state: Graph state; its ``"messages"`` list is sent to the model.

        Returns:
            A state update whose ``"messages"`` list holds the model reply.
        """
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit

        agent_llm = ChatOpenAI(
            base_url=settings.llm_base_url,
            api_key=settings.llm_api_key,
            model=AvailableModels.firefunction_2,
            temperature=0.5,
            max_tokens=256,
        ).bind_tools(sql_toolkit.get_tools())

        messages = state["messages"]
        # ainvoke() returns a coroutine. It has to be awaited, otherwise the
        # coroutine object itself is stored in the state as the "message".
        response = await agent_llm.ainvoke(messages)
        return {"messages": [response]}

    @property
    def query_tool_node(self) -> ToolNode:
        """Build a ``ToolNode`` from the tools of the session SQL toolkit."""
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit
        return ToolNode(sql_toolkit.get_tools())

    def add_subgraph(
        self, workflow: StateGraph, origin: str, destination: str
    ) -> StateGraph:
        """Creates the nodes and edges of the subgraph given a workflow

        The subgraph is wired as an agent/tool loop: ``origin`` leads to the
        agent node, tool calls are routed to the tool node, the tool node
        returns to the agent, and a reply without tool calls leaves the
        loop towards ``destination``.

        Args:
            workflow (StateGraph): Workflow that will get nodes and edges added
            origin (str): Origin node
            destination (str): Destination node

        Returns:
            StateGraph: Resulting workflow
        """

        def should_continue(state: MessagesState) -> str:
            """Pick the node that follows the agent's last message."""
            last_message = state["messages"][-1]
            if last_message.tool_calls:
                # Pending tool calls must be executed by the tool node.
                return "sql_tool_node"
            return destination

        workflow.add_node("sql_agent", self.call_model)
        workflow.add_node("sql_tool_node", self.query_tool_node)
        workflow.add_edge(origin, "sql_agent")
        # The router hangs off "sql_agent", the node registered above; no
        # node called "agent" is ever added to the workflow.
        workflow.add_conditional_edges(
            "sql_agent", should_continue, ["sql_tool_node", destination]
        )
        # Tool results go back to the agent so it can use them.
        workflow.add_edge("sql_tool_node", "sql_agent")

        return workflow
|
16
src/hellocomputer/tools/viz.py
Normal file
16
src/hellocomputer/tools/viz.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class PlotHistogramInput(BaseModel):
    """Inputs describing a histogram plot: a table, a column and a bin count."""

    column_name: str = Field(description="Name of the column containing the values")
    table_name: str = Field(description="Name of the table that contains the data")
    # The plot_histogram tool description states that the default number of
    # bins is 10, so the field carries that default instead of being required.
    num_bins: int = Field(default=10, description="Number of bins of the histogram")
|
||||||
|
|
||||||
|
|
||||||
|
class PlotHistogramTool(BaseTool):
    """Tool declaration for plotting a histogram of one table column.

    Only the tool metadata is declared here; this class does not define a
    ``_run`` implementation.
    """

    name: str = "plot_histogram"
    description: str = (
        "\n"
        "Generate a histogram plot given a name of an existing table of the database,\n"
        "and a name of a column in the table. The default number of bins is 10, but\n"
        "you can forward the number of bins if you are requested to"
    )
    # Attach the input model defined above as the tool's argument schema.
    args_schema: type[BaseModel] = PlotHistogramInput
|
Loading…
Reference in a new issue