Now a graph powers the application

This commit is contained in:
Guillem Borrell 2024-07-26 11:59:13 +02:00
parent 9d08a189b1
commit ee5dbb7167
19 changed files with 136 additions and 50 deletions

View file

@ -1,8 +1,9 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
from pathlib import Path
from typing import Self, Optional
from enum import StrEnum from enum import StrEnum
from pathlib import Path
from typing import Optional, Self
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class StorageEngines(StrEnum): class StorageEngines(StrEnum):

View file

@ -1,6 +1,7 @@
from hellocomputer.config import Settings, StorageEngines
from sqlalchemy import create_engine from sqlalchemy import create_engine
from hellocomputer.config import Settings, StorageEngines
class DDB: class DDB:
def __init__( def __init__(

View file

@ -2,8 +2,8 @@ import os
from pathlib import Path from pathlib import Path
import duckdb import duckdb
from typing_extensions import Self
from langchain_community.utilities.sql_database import SQLDatabase from langchain_community.utilities.sql_database import SQLDatabase
from typing_extensions import Self
from hellocomputer.config import Settings, StorageEngines from hellocomputer.config import Settings, StorageEngines

View file

@ -7,8 +7,8 @@ from uuid import UUID, uuid4
import duckdb import duckdb
import polars as pl import polars as pl
from hellocomputer.db import DDB
from hellocomputer.config import Settings, StorageEngines from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db import DDB
class UserDB(DDB): class UserDB(DDB):

View file

@ -1,6 +1,7 @@
from langchain.output_parsers.enum import EnumOutputParser
from enum import StrEnum
import re import re
from enum import StrEnum
from langchain.output_parsers.enum import EnumOutputParser
def extract_code_block(response): def extract_code_block(response):

View file

@ -0,0 +1,93 @@
from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts
async def intent(state: MessagesState):
    """Classify the latest user message into one of the supported intents.

    Takes the last message from the graph state, runs it through the
    intent-classification prompt chain, and appends the parsed intent
    (the string later consumed by ``route_intent``) to the state's
    message list.
    """
    messages = state["messages"]
    query = messages[-1]
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.intent()
    chain = prompt | llm | initial_intent_parser
    # Bug fix: the original invoked the chain with the SET {"query", query};
    # prompt templates expect a mapping, so the "query" placeholder was
    # never bound. Pass a dict instead.
    return {"messages": [await chain.ainvoke({"query": query})]}
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
    """Return the routing key produced by the ``intent`` node.

    The conditional-edge mapping in the graph translates this value into
    the name of the handler node to run next.
    """
    latest = state["messages"][-1]
    return latest.content
async def answer_general(state: MessagesState):
    """Handle a general (non-data) request using the 'general' prompt."""
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = (await Prompts.general()) | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
async def answer_query(state: MessagesState):
    """Handle a data-query request using the 'sql' prompt."""
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = (await Prompts.sql()) | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
async def answer_visualization(state: MessagesState):
    """Handle a visualization request using the 'visualization' prompt."""
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = (await Prompts.visualization()) | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
# Conversation graph: classify the user's intent first, then route to the
# single handler node for that intent; every handler terminates the run.
workflow = StateGraph(MessagesState)

for _name, _fn in (
    ("intent", intent),
    ("answer_general", answer_general),
    ("answer_query", answer_query),
    ("answer_visualization", answer_visualization),
):
    workflow.add_node(_name, _fn)

workflow.add_edge(START, "intent")
workflow.add_conditional_edges(
    "intent",
    route_intent,
    {
        "general": "answer_general",
        "query": "answer_query",
        "visualization": "answer_visualization",
    },
)
for _terminal in ("answer_general", "answer_query", "answer_visualization"):
    workflow.add_edge(_terminal, END)

# Compiled runnable used by the API layer (`app.ainvoke(...)`).
app = workflow.compile()

View file

@ -1,6 +1,7 @@
from pathlib import Path
from anyio import open_file from anyio import open_file
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from pathlib import Path
import hellocomputer import hellocomputer
@ -19,8 +20,12 @@ class Prompts:
@classmethod @classmethod
async def general(cls): async def general(cls):
return PromptTemplate.from_template(await cls.getter("general_prompt")) return PromptTemplate.from_template(await cls.getter("general"))
@classmethod @classmethod
async def sql(cls): async def sql(cls):
return PromptTemplate.from_template(await cls.getter("sql_prompt")) return PromptTemplate.from_template(await cls.getter("sql"))
@classmethod
async def visualization(cls):
return PromptTemplate.from_template(await cls.getter("visualization"))

View file

@ -0,0 +1,6 @@
You've been asked to do a task you can't do. There are two kinds of questions you can answer:
1. A question that can be answered by processing the data contained in the database. If this is the case, answer with the single word query
2. Some data visualization that can be generated from the data contained in the database. If this is the case, answer with the single word visualization.
Tell the user the request is not one of your skills.

View file

@ -1,3 +0,0 @@
You're a helpful assistant. Perform the following tasks:
* {query}

View file

@ -0,0 +1 @@
Apologise because this feature is under construction

View file

@ -1,5 +0,0 @@
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
* {query}
Return only the sql statement without any additional text.

View file

@ -0,0 +1 @@
Apologise because this feature is under construction

View file

@ -1,29 +1,15 @@
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from langchain_core.messages import HumanMessage
from starlette.requests import Request
from hellocomputer.config import StorageEngines from hellocomputer.graph import app
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block
from ..config import settings
from ..models import Chat
router = APIRouter() router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(request: Request, sid: str = "", q: str = "") -> str:
llm = Chat(api_key=settings.llm_api_key, temperature=0.5) user = request.session.get("user") # noqa
db = SessionDB( response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
StorageEngines.gcs, return response
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
).load_folder()
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result = str(db.query(query))
return result

View file

@ -3,8 +3,7 @@ from fastapi import APIRouter
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request from starlette.requests import Request
from hellocomputer.config import settings from hellocomputer.config import StorageEngines, settings
from hellocomputer.config import StorageEngines
from hellocomputer.db.users import UserDB from hellocomputer.db.users import UserDB
router = APIRouter() router = APIRouter()

View file

@ -5,8 +5,7 @@ from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.requests import Request from starlette.requests import Request
from ..config import settings from ..config import StorageEngines, settings
from ..config import StorageEngines
from ..db.sessions import SessionDB from ..db.sessions import SessionDB
from ..db.users import OwnershipDB from ..db.users import OwnershipDB

View file

@ -1,8 +1,10 @@
from pydantic import BaseModel, Field
from typing import Type from typing import Type
from langchain.tools import BaseTool from langchain.tools import BaseTool
from hellocomputer.db.sessions import SessionDB from pydantic import BaseModel, Field
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.db.sessions import SessionDB
class DuckdbQueryInput(BaseModel): class DuckdbQueryInput(BaseModel):

View file

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.config import StorageEngines, Settings from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB from hellocomputer.db.sessions import SessionDB
settings = Settings( settings = Settings(

View file

@ -1,7 +1,6 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
import polars as pl
import pytest import pytest
from hellocomputer.config import Settings, StorageEngines from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB from hellocomputer.db.sessions import SessionDB

View file

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.config import StorageEngines, Settings from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.users import OwnershipDB, UserDB from hellocomputer.db.users import OwnershipDB, UserDB
settings = Settings( settings = Settings(