Now a graph powers the application
This commit is contained in:
parent
9d08a189b1
commit
ee5dbb7167
|
@ -1,8 +1,9 @@
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
from pydantic import model_validator
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Self, Optional
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
from pydantic import model_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
class StorageEngines(StrEnum):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
|
|
||||||
|
|
||||||
class DDB:
|
class DDB:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -2,8 +2,8 @@ import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
from typing_extensions import Self
|
|
||||||
from langchain_community.utilities.sql_database import SQLDatabase
|
from langchain_community.utilities.sql_database import SQLDatabase
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
|
|
||||||
|
|
|
@ -7,8 +7,8 @@ from uuid import UUID, uuid4
|
||||||
import duckdb
|
import duckdb
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from hellocomputer.db import DDB
|
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
|
from hellocomputer.db import DDB
|
||||||
|
|
||||||
|
|
||||||
class UserDB(DDB):
|
class UserDB(DDB):
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from langchain.output_parsers.enum import EnumOutputParser
|
|
||||||
from enum import StrEnum
|
|
||||||
import re
|
import re
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
from langchain.output_parsers.enum import EnumOutputParser
|
||||||
|
|
||||||
|
|
||||||
def extract_code_block(response):
|
def extract_code_block(response):
|
||||||
|
|
93
src/hellocomputer/graph.py
Normal file
93
src/hellocomputer/graph.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||||
|
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.extraction import initial_intent_parser
|
||||||
|
from hellocomputer.models import AvailableModels
|
||||||
|
from hellocomputer.prompts import Prompts
|
||||||
|
|
||||||
|
|
||||||
|
async def intent(state: MessagesState):
    """Classify the most recent message of the conversation into an intent.

    Builds a ``prompt | llm | initial_intent_parser`` chain and runs it with
    the last entry of ``state["messages"]`` bound to the ``query`` input.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        A state update of the form ``{"messages": [<parser output>]}``.
    """
    messages = state["messages"]
    query = messages[-1]
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    prompt = await Prompts.intent()
    chain = prompt | llm | initial_intent_parser

    # The chain input has to be a mapping of prompt variable -> value.
    # ``{"query", query}`` (with a comma) builds a set, not a dict, so the
    # prompt never received its ``query`` variable.
    # NOTE(review): assumes the intent prompt declares a ``query`` variable.
    return {"messages": [await chain.ainvoke({"query": query})]}
|
||||||
|
|
||||||
|
|
||||||
|
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
    """Return the content of the newest message, used as the routing key.

    Args:
        state: Graph state; only ``state["messages"]`` is read.

    Returns:
        The ``content`` attribute of the last message in the state.
    """
    return state["messages"][-1].content
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_general(state: MessagesState):
    """Run the "general" prompt through the LLM and return its reply.

    ``state`` is accepted to satisfy the node signature but is not read;
    the chain is invoked with an empty input mapping.

    Returns:
        A state update of the form ``{"messages": [<llm reply>]}``.
    """
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = await Prompts.general() | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_query(state: MessagesState):
    """Run the "sql" prompt through the LLM and return its reply.

    ``state`` is accepted to satisfy the node signature but is not read;
    the chain is invoked with an empty input mapping.

    Returns:
        A state update of the form ``{"messages": [<llm reply>]}``.
    """
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = await Prompts.sql() | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_visualization(state: MessagesState):
    """Run the "visualization" prompt through the LLM and return its reply.

    ``state`` is accepted to satisfy the node signature but is not read;
    the chain is invoked with an empty input mapping.

    Returns:
        A state update of the form ``{"messages": [<llm reply>]}``.
    """
    model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,
        temperature=0,
    )
    pipeline = await Prompts.visualization() | model
    reply = await pipeline.ainvoke({})
    return {"messages": [reply]}
|
||||||
|
|
||||||
|
|
||||||
|
# Assemble the graph: one classifier node fans out to three answer nodes,
# each of which ends the run.
workflow = StateGraph(MessagesState)

# Nodes.
workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization)

# Entry point: every run starts by classifying the intent.
workflow.add_edge(START, "intent")

# The value returned by route_intent selects the answer node.
workflow.add_conditional_edges(
    source="intent",
    path=route_intent,
    path_map={
        "general": "answer_general",
        "query": "answer_query",
        "visualization": "answer_visualization",
    },
)

# Every answer node is terminal.
workflow.add_edge("answer_general", END)
workflow.add_edge("answer_query", END)
workflow.add_edge("answer_visualization", END)

app = workflow.compile()
|
|
@ -1,6 +1,7 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
|
|
||||||
|
@ -19,8 +20,12 @@ class Prompts:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def general(cls):
|
async def general(cls):
|
||||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
return PromptTemplate.from_template(await cls.getter("general"))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def sql(cls):
|
async def sql(cls):
|
||||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
return PromptTemplate.from_template(await cls.getter("sql"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def visualization(cls):
|
||||||
|
return PromptTemplate.from_template(await cls.getter("visualization"))
|
||||||
|
|
6
src/hellocomputer/prompts/general.md
Normal file
6
src/hellocomputer/prompts/general.md
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
You've been asked to do a task you can't do. There are two kinds of questions you can answer:
|
||||||
|
|
||||||
|
1. A question that can be answered by processing the data contained in the database. If this is the case, answer with the single word query.
|
||||||
|
2. Some data visualization that can be generated from the data contained in the database. If this is the case, answer with the single word visualization.
|
||||||
|
|
||||||
|
Tell the user the request is not one of your skills.
|
|
@ -1,3 +0,0 @@
|
||||||
You're a helpful assistant. Perform the following tasks:
|
|
||||||
|
|
||||||
* {query}
|
|
1
src/hellocomputer/prompts/sql.md
Normal file
1
src/hellocomputer/prompts/sql.md
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Apologise because this feature is under construction
|
|
@ -1,5 +0,0 @@
|
||||||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
|
||||||
|
|
||||||
* {query}
|
|
||||||
|
|
||||||
Return only the sql statement without any additional text.
|
|
1
src/hellocomputer/prompts/visualization.md
Normal file
1
src/hellocomputer/prompts/visualization.md
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Apologise because this feature is under construction
|
|
@ -1,29 +1,15 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.config import StorageEngines
|
from hellocomputer.graph import app
|
||||||
from hellocomputer.db.sessions import SessionDB
|
|
||||||
from hellocomputer.extraction import extract_code_block
|
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..models import Chat
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
async def query(sid: str = "", q: str = "") -> str:
|
async def query(request: Request, sid: str = "", q: str = "") -> str:
|
||||||
llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
|
user = request.session.get("user") # noqa
|
||||||
db = SessionDB(
|
response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
|
||||||
StorageEngines.gcs,
|
return response
|
||||||
gcs_access=settings.gcs_access,
|
|
||||||
gcs_secret=settings.gcs_secret,
|
|
||||||
bucket=settings.gcs_bucketname,
|
|
||||||
sid=sid,
|
|
||||||
).load_folder()
|
|
||||||
|
|
||||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
|
||||||
query = extract_code_block(chat.last_response_content())
|
|
||||||
result = str(db.query(query))
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
|
@ -3,8 +3,7 @@ from fastapi import APIRouter
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import StorageEngines, settings
|
||||||
from hellocomputer.config import StorageEngines
|
|
||||||
from hellocomputer.db.users import UserDB
|
from hellocomputer.db.users import UserDB
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
|
@ -5,8 +5,7 @@ from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import StorageEngines, settings
|
||||||
from ..config import StorageEngines
|
|
||||||
from ..db.sessions import SessionDB
|
from ..db.sessions import SessionDB
|
||||||
from ..db.users import OwnershipDB
|
from ..db.users import OwnershipDB
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
||||||
|
|
||||||
class DuckdbQueryInput(BaseModel):
|
class DuckdbQueryInput(BaseModel):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.config import StorageEngines, Settings
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import polars as pl
|
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.config import StorageEngines, Settings
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from hellocomputer.db.users import OwnershipDB, UserDB
|
from hellocomputer.db.users import OwnershipDB, UserDB
|
||||||
|
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
|
|
Loading…
Reference in a new issue