Now a graph powers the application

Guillem Borrell 2024-07-26 11:59:13 +02:00
parent 9d08a189b1
commit ee5dbb7167
19 changed files with 136 additions and 50 deletions

View file

@@ -1,8 +1,9 @@
-from pydantic_settings import BaseSettings, SettingsConfigDict
-from pydantic import model_validator
-from pathlib import Path
-from typing import Self, Optional
from enum import StrEnum
+from pathlib import Path
+from typing import Optional, Self
+
+from pydantic import model_validator
+from pydantic_settings import BaseSettings, SettingsConfigDict

class StorageEngines(StrEnum):

View file

@@ -1,6 +1,7 @@
-from hellocomputer.config import Settings, StorageEngines
from sqlalchemy import create_engine
+
+from hellocomputer.config import Settings, StorageEngines

class DDB:
    def __init__(

View file

@@ -2,8 +2,8 @@ import os
from pathlib import Path
import duckdb
-from typing_extensions import Self
from langchain_community.utilities.sql_database import SQLDatabase
+from typing_extensions import Self
from hellocomputer.config import Settings, StorageEngines

View file

@@ -7,8 +7,8 @@ from uuid import UUID, uuid4
import duckdb
import polars as pl
-from hellocomputer.db import DDB
from hellocomputer.config import Settings, StorageEngines
+from hellocomputer.db import DDB

class UserDB(DDB):

View file

@@ -1,6 +1,7 @@
-from langchain.output_parsers.enum import EnumOutputParser
-from enum import StrEnum
import re
+from enum import StrEnum
+
+from langchain.output_parsers.enum import EnumOutputParser

def extract_code_block(response):
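
For orientation: the new graph module below imports initial_intent_parser from this file, and this hunk brings in EnumOutputParser and StrEnum, but the parser's definition is outside the diff. A minimal sketch of how such a parser could be assembled — the InitialIntent name and its members are assumptions, not code from this commit — might look like this:

# Hypothetical sketch; the real initial_intent_parser is not shown in this diff.
from enum import StrEnum

from langchain.output_parsers.enum import EnumOutputParser


class InitialIntent(StrEnum):
    # Members match the routes wired into the graph below.
    general = "general"
    query = "query"
    visualization = "visualization"


initial_intent_parser = EnumOutputParser(enum=InitialIntent)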

View file

@@ -0,0 +1,93 @@
from typing import Literal

from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph

from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts


async def intent(state: MessagesState):
    # Classify the latest user message into one of the supported intents.
    messages = state["messages"]
    query = messages[-1]
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.intent()
    chain = prompt | llm | initial_intent_parser
    return {"messages": [await chain.ainvoke({"query": query})]}


def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
    # The intent node stores the parsed intent as the last message in the state.
    messages = state["messages"]
    last_message = messages[-1]
    return last_message.content


async def answer_general(state: MessagesState):
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.general()
    chain = prompt | llm
    return {"messages": [await chain.ainvoke({})]}


async def answer_query(state: MessagesState):
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.sql()
    chain = prompt | llm
    return {"messages": [await chain.ainvoke({})]}


async def answer_visualization(state: MessagesState):
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.visualization()
    chain = prompt | llm
    return {"messages": [await chain.ainvoke({})]}


# Build the graph: classify the intent first, then route to the node that
# answers that kind of request.
workflow = StateGraph(MessagesState)
workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization)

workflow.add_edge(START, "intent")
workflow.add_conditional_edges(
    "intent",
    route_intent,
    {
        "general": "answer_general",
        "query": "answer_query",
        "visualization": "answer_visualization",
    },
)
workflow.add_edge("answer_general", END)
workflow.add_edge("answer_query", END)
workflow.add_edge("answer_visualization", END)

app = workflow.compile()
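
The compiled app is an async LangGraph runnable, and the chat router later in this commit drives it with app.ainvoke. A minimal sketch of exercising the graph outside FastAPI, assuming the hellocomputer package and a configured LLM endpoint are available (the sample question is made up):

# Standalone sketch, not part of the commit: invoke the compiled graph directly.
import asyncio

from langchain_core.messages import HumanMessage

from hellocomputer.graph import app


async def main() -> None:
    # The graph consumes a MessagesState-style dict with a list of messages
    # and returns the updated state once the routed node has answered.
    state = await app.ainvoke(
        {"messages": [HumanMessage(content="How many rows are in the uploaded table?")]}
    )
    print(state["messages"][-1].content)


asyncio.run(main())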

View file

@@ -1,6 +1,7 @@
+from pathlib import Path
+
from anyio import open_file
from langchain_core.prompts import PromptTemplate
-from pathlib import Path

import hellocomputer
@@ -19,8 +20,12 @@ class Prompts:
    @classmethod
    async def general(cls):
-        return PromptTemplate.from_template(await cls.getter("general_prompt"))
+        return PromptTemplate.from_template(await cls.getter("general"))

    @classmethod
    async def sql(cls):
-        return PromptTemplate.from_template(await cls.getter("sql_prompt"))
+        return PromptTemplate.from_template(await cls.getter("sql"))

+    @classmethod
+    async def visualization(cls):
+        return PromptTemplate.from_template(await cls.getter("visualization"))
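
The methods above delegate to a cls.getter classmethod whose body is outside this hunk. Judging from the imports (anyio's open_file, Path, and the hellocomputer package), it presumably reads a prompt template by name; a rough sketch under those assumptions — the prompts directory and the .md suffix are guesses, not part of this commit — could be:

# Hypothetical sketch of the getter referenced above; the real implementation
# lives outside this hunk. The directory layout and file suffix are assumptions.
from pathlib import Path

from anyio import open_file

import hellocomputer


class Prompts:
    @classmethod
    async def getter(cls, name: str) -> str:
        path = Path(hellocomputer.__file__).parent / "prompts" / f"{name}.md"
        async with await open_file(path) as f:
            return await f.read()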

View file

@@ -0,0 +1,6 @@
You've been asked to do a task you can't do. There are two kinds of questions you can answer:

1. A question that can be answered by processing the data contained in the database. If this is the case, answer with the single word query.
2. A data visualization that can be generated from the data contained in the database. If this is the case, answer with the single word visualization.

Tell the user the request is not one of your skills.

View file

@@ -1,3 +0,0 @@
You're a helpful assistant. Perform the following tasks:
* {query}

View file

@@ -0,0 +1 @@
Apologise because this feature is under construction

View file

@@ -1,5 +0,0 @@
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
* {query}
Return only the sql statement without any additional text.

View file

@@ -0,0 +1 @@
Apologise because this feature is under construction

View file

@@ -1,29 +1,15 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
+from langchain_core.messages import HumanMessage
+from starlette.requests import Request
-from hellocomputer.config import StorageEngines
-from hellocomputer.db.sessions import SessionDB
-from hellocomputer.extraction import extract_code_block
-from ..config import settings
-from ..models import Chat
+from hellocomputer.graph import app

router = APIRouter()

@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
-async def query(sid: str = "", q: str = "") -> str:
-    llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
-    db = SessionDB(
-        StorageEngines.gcs,
-        gcs_access=settings.gcs_access,
-        gcs_secret=settings.gcs_secret,
-        bucket=settings.gcs_bucketname,
-        sid=sid,
-    ).load_folder()
-    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
-    query = extract_code_block(chat.last_response_content())
-    result = str(db.query(query))
-    return result
+async def query(request: Request, sid: str = "", q: str = "") -> str:
+    user = request.session.get("user")  # noqa
+    response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
+    return response
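
One detail worth noting: the route is still declared with response_class=PlainTextResponse and annotated as returning str, while app.ainvoke returns the whole graph state. A hypothetical variant (not in this commit) that keeps the plain-text contract would unwrap the last message:

@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(request: Request, sid: str = "", q: str = "") -> str:
    user = request.session.get("user")  # noqa
    # Return only the text of the final message produced by the graph.
    state = await app.ainvoke({"messages": [HumanMessage(content=q)]})
    return state["messages"][-1].content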

View file

@@ -3,8 +3,7 @@ from fastapi import APIRouter
from fastapi.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request
-from hellocomputer.config import settings
-from hellocomputer.config import StorageEngines
+from hellocomputer.config import StorageEngines, settings
from hellocomputer.db.users import UserDB

router = APIRouter()

View file

@@ -5,8 +5,7 @@ from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from starlette.requests import Request
-from ..config import settings
-from ..config import StorageEngines
+from ..config import StorageEngines, settings
from ..db.sessions import SessionDB
from ..db.users import OwnershipDB

View file

@@ -1,8 +1,10 @@
-from pydantic import BaseModel, Field
from typing import Type
from langchain.tools import BaseTool
-from hellocomputer.db.sessions import SessionDB
+from pydantic import BaseModel, Field
+
+from hellocomputer.config import settings
+from hellocomputer.db.sessions import SessionDB

class DuckdbQueryInput(BaseModel):

View file

@@ -1,7 +1,7 @@
from pathlib import Path
import hellocomputer
-from hellocomputer.config import StorageEngines, Settings
+from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB

settings = Settings(

View file

@@ -1,7 +1,6 @@
from pathlib import Path
import hellocomputer
import polars as pl
import pytest
from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB

View file

@@ -1,7 +1,7 @@
from pathlib import Path
import hellocomputer
-from hellocomputer.config import StorageEngines, Settings
+from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.users import OwnershipDB, UserDB

settings = Settings(