Now a graph powers the application
This commit is contained in:
parent
9d08a189b1
commit
ee5dbb7167
|
@ -1,8 +1,9 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import model_validator
|
||||
from pathlib import Path
|
||||
from typing import Self, Optional
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from hellocomputer.config import Settings, StorageEngines
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
|
||||
|
||||
class DDB:
|
||||
def __init__(
|
||||
|
|
|
@ -2,8 +2,8 @@ import os
|
|||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
from typing_extensions import Self
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
from typing_extensions import Self
|
||||
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@ from uuid import UUID, uuid4
|
|||
import duckdb
|
||||
import polars as pl
|
||||
|
||||
from hellocomputer.db import DDB
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db import DDB
|
||||
|
||||
|
||||
class UserDB(DDB):
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from langchain.output_parsers.enum import EnumOutputParser
|
||||
from enum import StrEnum
|
||||
import re
|
||||
from enum import StrEnum
|
||||
|
||||
from langchain.output_parsers.enum import EnumOutputParser
|
||||
|
||||
|
||||
def extract_code_block(response):
|
||||
|
|
93
src/hellocomputer/graph.py
Normal file
93
src/hellocomputer/graph.py
Normal file
|
@ -0,0 +1,93 @@
|
|||
from typing import Literal
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from hellocomputer.models import AvailableModels
|
||||
from hellocomputer.prompts import Prompts
|
||||
|
||||
|
||||
async def intent(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
query = messages[-1]
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.intent()
|
||||
chain = prompt | llm | initial_intent_parser
|
||||
|
||||
return {"messages": [await chain.ainvoke({"query", query})]}
|
||||
|
||||
|
||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
return last_message.content
|
||||
|
||||
|
||||
async def answer_general(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.general()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_query(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.sql()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_visualization(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.visualization()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
workflow.add_node("intent", intent)
|
||||
workflow.add_node("answer_general", answer_general)
|
||||
workflow.add_node("answer_query", answer_query)
|
||||
workflow.add_node("answer_visualization", answer_visualization)
|
||||
|
||||
workflow.add_edge(START, "intent")
|
||||
workflow.add_conditional_edges(
|
||||
"intent",
|
||||
route_intent,
|
||||
{
|
||||
"general": "answer_general",
|
||||
"query": "answer_query",
|
||||
"visualization": "answer_visualization",
|
||||
},
|
||||
)
|
||||
workflow.add_edge("answer_general", END)
|
||||
workflow.add_edge("answer_query", END)
|
||||
workflow.add_edge("answer_visualization", END)
|
||||
|
||||
app = workflow.compile()
|
|
@ -1,6 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
from anyio import open_file
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
|
||||
|
@ -19,8 +20,12 @@ class Prompts:
|
|||
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||
return PromptTemplate.from_template(await cls.getter("general"))
|
||||
|
||||
@classmethod
|
||||
async def sql(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
||||
return PromptTemplate.from_template(await cls.getter("sql"))
|
||||
|
||||
@classmethod
|
||||
async def visualization(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("visualization"))
|
||||
|
|
6
src/hellocomputer/prompts/general.md
Normal file
6
src/hellocomputer/prompts/general.md
Normal file
|
@ -0,0 +1,6 @@
|
|||
You've been asked to do a task you can't do. There are two kinds of questions you can answer:
|
||||
|
||||
1. A question that can be answered processing the data contained in the database. If this is the case answer the single word query
|
||||
2. Some data visualization that can be obtained by generated from the data contained in the database. if this is the case answer with the single word visualization.
|
||||
|
||||
Tell the user the request is not one of your skills.
|
|
@ -1,3 +0,0 @@
|
|||
You're a helpful assistant. Perform the following tasks:
|
||||
|
||||
* {query}
|
1
src/hellocomputer/prompts/sql.md
Normal file
1
src/hellocomputer/prompts/sql.md
Normal file
|
@ -0,0 +1 @@
|
|||
Apologise because this feature is under construction
|
|
@ -1,5 +0,0 @@
|
|||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
||||
|
||||
* {query}
|
||||
|
||||
Return only the sql statement without any additional text.
|
1
src/hellocomputer/prompts/visualization.md
Normal file
1
src/hellocomputer/prompts/visualization.md
Normal file
|
@ -0,0 +1 @@
|
|||
Apologise because this feature is under construction
|
|
@ -1,29 +1,15 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from langchain_core.messages import HumanMessage
|
||||
from starlette.requests import Request
|
||||
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
||||
from ..config import settings
|
||||
from ..models import Chat
|
||||
from hellocomputer.graph import app
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||
async def query(sid: str = "", q: str = "") -> str:
|
||||
llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
sid=sid,
|
||||
).load_folder()
|
||||
|
||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = str(db.query(query))
|
||||
|
||||
return result
|
||||
async def query(request: Request, sid: str = "", q: str = "") -> str:
|
||||
user = request.session.get("user") # noqa
|
||||
response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
|
||||
return response
|
||||
|
|
|
@ -3,8 +3,7 @@ from fastapi import APIRouter
|
|||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.config import StorageEngines, settings
|
||||
from hellocomputer.db.users import UserDB
|
||||
|
||||
router = APIRouter()
|
||||
|
|
|
@ -5,8 +5,7 @@ from fastapi import APIRouter, File, UploadFile
|
|||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..config import settings
|
||||
from ..config import StorageEngines
|
||||
from ..config import StorageEngines, settings
|
||||
from ..db.sessions import SessionDB
|
||||
from ..db.users import OwnershipDB
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import Type
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
|
||||
class DuckdbQueryInput(BaseModel):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.config import StorageEngines, Settings
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
settings = Settings(
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
import polars as pl
|
||||
import pytest
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.config import StorageEngines, Settings
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.users import OwnershipDB, UserDB
|
||||
|
||||
settings = Settings(
|
||||
|
|
Loading…
Reference in a new issue