Nothing works because I need to find a way to pass configuration to the graph
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
f16bb6b8cf
commit
eb72885f5b
|
@ -3,16 +3,18 @@ from pathlib import Path
|
|||
|
||||
import duckdb
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
|
||||
from langchain_openai import ChatOpenAI
|
||||
from typing_extensions import Self
|
||||
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.config import Settings, settings, StorageEngines
|
||||
from hellocomputer.models import AvailableModels
|
||||
|
||||
from . import DDB
|
||||
|
||||
|
||||
class SessionDB(DDB):
|
||||
def __init__(self, settings: Settings, sid: str):
|
||||
super().__init__(settings=settings)
|
||||
def set_session(self, sid):
|
||||
self.sid = sid
|
||||
# Override storage engine for sessions
|
||||
if settings.storage_engine == StorageEngines.gcs:
|
||||
|
@ -171,3 +173,13 @@ class SessionDB(DDB):
|
|||
@property
def llmsql(self):
    # LangChain SQLDatabase view over this session's SQLAlchemy engine.
    # The internal "metadata" table is hidden so the LLM tooling never
    # queries or describes it.
    return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
||||
|
||||
@property
def sql_toolkit(self) -> SQLDatabaseToolkit:
    """Build a SQL toolkit bound to this session's database.

    Returns:
        SQLDatabaseToolkit: toolkit whose tools operate on ``self.llmsql``
        using a medium Llama chat model.
    """
    # NOTE(review): reads the module-level `settings`, not an instance
    # attribute — presumably intentional, but verify against SessionDB's
    # constructor, which also receives a Settings object.
    chat_model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_medium,
        temperature=0.3,
    )
    toolkit = SQLDatabaseToolkit(db=self.llmsql, llm=chat_model)
    return toolkit
|
||||
|
|
|
@ -5,9 +5,11 @@ from langgraph.graph import END, START, MessagesState, StateGraph
|
|||
from hellocomputer.nodes import (
|
||||
intent,
|
||||
answer_general,
|
||||
answer_query,
|
||||
answer_visualization,
|
||||
)
|
||||
from hellocomputer.config import settings
|
||||
|
||||
from hellocomputer.tools.db import SQLSubgraph
|
||||
|
||||
|
||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||
|
@ -22,7 +24,6 @@ workflow = StateGraph(MessagesState)
|
|||
|
||||
workflow.add_node("intent", intent)
|
||||
workflow.add_node("answer_general", answer_general)
|
||||
workflow.add_node("answer_query", answer_query)
|
||||
workflow.add_node("answer_visualization", answer_visualization)
|
||||
|
||||
# Edges
|
||||
|
@ -38,7 +39,12 @@ workflow.add_conditional_edges(
|
|||
},
|
||||
)
|
||||
workflow.add_edge("answer_general", END)
|
||||
workflow.add_edge("answer_query", END)
|
||||
workflow.add_edge("answer_visualization", END)
|
||||
|
||||
# SQL Subgraph
|
||||
|
||||
workflow = SQLSubgraph().add_subgraph(
|
||||
workflow=workflow, origin="intent", destination=END
|
||||
)
|
||||
|
||||
app = workflow.compile()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from hellocomputer.models import AvailableModels
|
||||
|
@ -35,17 +36,17 @@ async def answer_general(state: MessagesState):
|
|||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_query(state: MessagesState):
    """Answer a data question by running the SQL prompt through a small model.

    Args:
        state: Accumulated graph messages (currently not forwarded to the
            chain — the prompt is invoked with an empty mapping).

    Returns:
        dict: ``{"messages": [...]}`` update for the MessagesState graph.
    """
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,  # deterministic output for SQL generation
    )
    prompt = await Prompts.sql()
    chain = prompt | llm

    # NOTE(review): `state` is unused; the chain receives no variables.
    # Confirm whether the user question should be passed here.
    return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_visualization(state: MessagesState):
|
||||
|
|
0
src/hellocomputer/tools/__init__.py
Normal file
0
src/hellocomputer/tools/__init__.py
Normal file
96
src/hellocomputer/tools/db.py
Normal file
96
src/hellocomputer/tools/db.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
from typing import Type
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langgraph.graph import MessagesState, StateGraph
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.models import AvailableModels
|
||||
|
||||
|
||||
## This in case I need to create more ReAct agents
|
||||
|
||||
|
||||
class DuckdbQueryInput(BaseModel):
    """Argument schema for DuckdbQueryTool."""

    query: str = Field(description="Question to be translated to a SQL statement")
    session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||
|
||||
|
||||
class DuckdbQueryTool(BaseTool):
    """Tool that runs a SQL query against a session's DuckDB database."""

    name: str = "sql_query"
    # Bug fix: the continuation strings were bare class-level expression
    # statements, so `description` was silently truncated to its first
    # fragment. Parenthesizing makes the implicit concatenation one string.
    description: str = (
        "Run a SQL query in the database containing all the datasets "
        "and provide a summary of the results, and the name of the table with them if the "
        "volume of the results is large"
    )
    args_schema: Type[BaseModel] = DuckdbQueryInput

    def _run(self, query: str, session_id: str) -> str:
        """Run the query synchronously against the session database."""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        # Bug fix: annotated `-> str` but returned None; mirror _arun so both
        # execution paths report the same result to the agent.
        return "Table"

    async def _arun(self, query: str, session_id: str) -> str:
        """Use the tool asynchronously."""
        session = SessionDB(settings, session_id)
        session.db.sql(query)
        return "Table"
|
||||
|
||||
|
||||
class SQLSubgraph:
    """
    Creates the question-answering agent that generates and runs SQL
    queries.
    """

    async def call_model(self, state: MessagesState):
        """Invoke the tool-bound LLM on the conversation so far.

        Args:
            state: MessagesState carrying the conversation history.

        Returns:
            dict: ``{"messages": [response]}`` update for the graph.
        """
        # NOTE(review): SessionDB.__init__ takes (settings, sid); no session
        # id is available here — configuration still needs to reach the
        # graph (see commit message). TODO confirm how sid is injected.
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit

        agent_llm = ChatOpenAI(
            base_url=settings.llm_base_url,
            api_key=settings.llm_api_key,
            model=AvailableModels.firefunction_2,
            temperature=0.5,
            max_tokens=256,
        ).bind_tools(sql_toolkit.get_tools())

        messages = state["messages"]
        # Bug fix: ainvoke is a coroutine and must be awaited; the original
        # stored the coroutine object itself in the message list.
        response = await agent_llm.ainvoke(messages)
        return {"messages": [response]}

    @property
    def query_tool_node(self) -> ToolNode:
        """ToolNode exposing the SQL toolkit's tools to the graph."""
        db = SessionDB(settings=settings)
        sql_toolkit = db.sql_toolkit
        return ToolNode(sql_toolkit.get_tools())

    def add_subgraph(
        self, workflow: StateGraph, origin: str, destination: str
    ) -> StateGraph:
        """Creates the nodes and edges of the subgraph given a workflow

        Args:
            workflow (StateGraph): Workflow that will get nodes and edges added
            origin (str): Origin node
            destination (str): Destination node

        Returns:
            StateGraph: Resulting workflow
        """

        def should_continue(state: MessagesState):
            # Route to the tool node while the model keeps requesting tool
            # calls; otherwise hand the final answer to `destination`.
            messages = state["messages"]
            last_message = messages[-1]
            if last_message.tool_calls:
                return "sql_tool_node"
            return destination

        workflow.add_node("sql_agent", self.call_model)
        workflow.add_node("sql_tool_node", self.query_tool_node)
        workflow.add_edge(origin, "sql_agent")
        # Bug fix: the conditional edge referenced a nonexistent node
        # ("agent" instead of "sql_agent"), and an unconditional
        # sql_agent -> sql_tool_node edge contradicted the conditional
        # routing. Rewired as the standard agent/tool loop: the agent either
        # calls tools (and sees their results) or finishes at `destination`.
        workflow.add_conditional_edges("sql_agent", should_continue)
        workflow.add_edge("sql_tool_node", "sql_agent")

        return workflow
|
16
src/hellocomputer/tools/viz.py
Normal file
16
src/hellocomputer/tools/viz.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from langchain.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PlotHistogramInput(BaseModel):
    """Argument schema for PlotHistogramTool."""

    column_name: str = Field(description="Name of the column containing the values")
    table_name: str = Field(description="Name of the table that contains the data")
    num_bins: int = Field(description="Number of bins of the histogram")
|
||||
|
||||
|
||||
class PlotHistogramTool(BaseTool):
    """Tool declaration for plotting a histogram from a database table.

    NOTE(review): only `name` and `description` are defined here — no
    `args_schema` (PlotHistogramInput above is presumably intended) and no
    `_run`/`_arun` implementation are visible; confirm this is still WIP.
    """

    name: str = "plot_histogram"
    description: str = """
    Generate a histogram plot given a name of an existing table of the database,
    and a name of a column in the table. The default number of bins is 10, but
    you can forward the number of bins if you are requested to"""
|
Loading…
Reference in a new issue