diff --git a/src/hellocomputer/config.py b/src/hellocomputer/config.py index 0fda686..3d07e4f 100644 --- a/src/hellocomputer/config.py +++ b/src/hellocomputer/config.py @@ -1,8 +1,9 @@ -from pydantic_settings import BaseSettings, SettingsConfigDict -from pydantic import model_validator -from pathlib import Path -from typing import Self, Optional from enum import StrEnum +from pathlib import Path +from typing import Optional, Self + +from pydantic import model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class StorageEngines(StrEnum): diff --git a/src/hellocomputer/db/__init__.py b/src/hellocomputer/db/__init__.py index 5042077..04c3f45 100644 --- a/src/hellocomputer/db/__init__.py +++ b/src/hellocomputer/db/__init__.py @@ -1,6 +1,7 @@ -from hellocomputer.config import Settings, StorageEngines from sqlalchemy import create_engine +from hellocomputer.config import Settings, StorageEngines + class DDB: def __init__( diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index eb6559b..76cbf8a 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -2,8 +2,8 @@ import os from pathlib import Path import duckdb -from typing_extensions import Self from langchain_community.utilities.sql_database import SQLDatabase +from typing_extensions import Self from hellocomputer.config import Settings, StorageEngines diff --git a/src/hellocomputer/db/users.py b/src/hellocomputer/db/users.py index b4f84dc..019be70 100644 --- a/src/hellocomputer/db/users.py +++ b/src/hellocomputer/db/users.py @@ -7,8 +7,8 @@ from uuid import UUID, uuid4 import duckdb import polars as pl -from hellocomputer.db import DDB from hellocomputer.config import Settings, StorageEngines +from hellocomputer.db import DDB class UserDB(DDB): diff --git a/src/hellocomputer/extraction.py b/src/hellocomputer/extraction.py index 8152a0f..bc4ddce 100644 --- a/src/hellocomputer/extraction.py +++ b/src/hellocomputer/extraction.py @@ 
-1,6 +1,7 @@
-from langchain.output_parsers.enum import EnumOutputParser
-from enum import StrEnum
 import re
+from enum import StrEnum
+
+from langchain.output_parsers.enum import EnumOutputParser
 
 
 def extract_code_block(response):
diff --git a/src/hellocomputer/graph.py b/src/hellocomputer/graph.py
new file mode 100644
index 0000000..1d04575
--- /dev/null
+++ b/src/hellocomputer/graph.py
@@ -0,0 +1,85 @@
+from typing import Literal
+
+from langchain_openai import ChatOpenAI
+from langgraph.graph import END, START, MessagesState, StateGraph
+
+from hellocomputer.config import settings
+from hellocomputer.extraction import initial_intent_parser
+from hellocomputer.models import AvailableModels
+from hellocomputer.prompts import Prompts
+
+
+def _small_llm() -> ChatOpenAI:
+    """Build the temperature-0 chat client that every graph node uses."""
+    return ChatOpenAI(
+        base_url=settings.llm_base_url,
+        api_key=settings.llm_api_key,
+        model=AvailableModels.llama_small,
+        temperature=0,
+    )
+
+
+async def intent(state: MessagesState):
+    """Classify the latest message and return the parsed intent."""
+    query = state["messages"][-1]
+    prompt = await Prompts.intent()
+    chain = prompt | _small_llm() | initial_intent_parser
+
+    # The chain input must be a mapping of template variables (a dict, not a
+    # set literal); pass the message text rather than the message object.
+    return {"messages": [await chain.ainvoke({"query": query.content})]}
+
+
+def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
+    """Return the content of the last message, used as the routing key."""
+    return state["messages"][-1].content
+
+
+async def answer_general(state: MessagesState):
+    """Run the ``general`` prompt through the model and return its reply."""
+    prompt = await Prompts.general()
+    chain = prompt | _small_llm()
+
+    # The prompt is invoked with no inputs, so ``state`` is not read here.
+    return {"messages": [await chain.ainvoke({})]}
+
+
+async def answer_query(state: MessagesState):
+    """Run the ``sql`` prompt through the model and return its reply."""
+    prompt = await Prompts.sql()
+    chain = prompt | _small_llm()
+
+    return {"messages": [await chain.ainvoke({})]}
+
+
+async def answer_visualization(state: MessagesState):
+    """Run the ``visualization`` prompt through the model and return its reply."""
+    prompt = await Prompts.visualization()
+    chain = prompt | _small_llm()
+
+    return {"messages": [await chain.ainvoke({})]}
+
+
+# Graph wiring: START -> intent -> one answer_* node -> END.
+workflow = StateGraph(MessagesState)
+
+workflow.add_node("intent", intent)
+workflow.add_node("answer_general", answer_general)
+workflow.add_node("answer_query", answer_query)
+workflow.add_node("answer_visualization", answer_visualization)
+
+workflow.add_edge(START, "intent")
+workflow.add_conditional_edges(
+    "intent",
+    route_intent,
+    {
+        "general": "answer_general",
+        "query": "answer_query",
+        "visualization": "answer_visualization",
+    },
+)
+workflow.add_edge("answer_general", END)
+workflow.add_edge("answer_query", END)
+workflow.add_edge("answer_visualization", END)
+
+app = workflow.compile()
diff --git a/src/hellocomputer/prompts.py b/src/hellocomputer/prompts.py
index a0fd80c..5cf5d15 100644
--- a/src/hellocomputer/prompts.py
+++ b/src/hellocomputer/prompts.py
@@ -1,6 +1,7 @@
+from pathlib import Path
+
 from anyio import open_file
 from langchain_core.prompts import PromptTemplate
-from pathlib import Path
 
 import hellocomputer
 
@@ -19,8 +20,12 @@ class Prompts:
 
     @classmethod
     async def general(cls):
-        return PromptTemplate.from_template(await cls.getter("general_prompt"))
+        return PromptTemplate.from_template(await cls.getter("general"))
 
     @classmethod
     async def sql(cls):
-        return PromptTemplate.from_template(await cls.getter("sql_prompt"))
+        return PromptTemplate.from_template(await cls.getter("sql"))
+
+    @classmethod
+    async def visualization(cls):
+        return PromptTemplate.from_template(await cls.getter("visualization"))
diff --git a/src/hellocomputer/prompts/general.md b/src/hellocomputer/prompts/general.md
new file mode 100644
index 0000000..c341c15
--- /dev/null
+++ b/src/hellocomputer/prompts/general.md
@@ -0,0 +1,6 @@
+You've been asked to do a task you can't do.
There are two kinds of questions you can answer: + +1. A question that can be answered by processing the data contained in the database. If this is the case answer with the single word query. +2. Some data visualization that can be generated from the data contained in the database. If this is the case answer with the single word visualization. + +Tell the user the request is not one of your skills. diff --git a/src/hellocomputer/prompts/general_prompt.md b/src/hellocomputer/prompts/general_prompt.md deleted file mode 100644 index 73e6c97..0000000 --- a/src/hellocomputer/prompts/general_prompt.md +++ /dev/null @@ -1,3 +0,0 @@ -You're a helpful assistant. Perform the following tasks: - -* {query} diff --git a/src/hellocomputer/prompts/sql.md b/src/hellocomputer/prompts/sql.md new file mode 100644 index 0000000..ebd5bff --- /dev/null +++ b/src/hellocomputer/prompts/sql.md @@ -0,0 +1 @@ +Apologise because this feature is under construction \ No newline at end of file diff --git a/src/hellocomputer/prompts/sql_prompt.md b/src/hellocomputer/prompts/sql_prompt.md deleted file mode 100644 index 0853c71..0000000 --- a/src/hellocomputer/prompts/sql_prompt.md +++ /dev/null @@ -1,5 +0,0 @@ -You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: - -* {query} - -Return only the sql statement without any additional text.
\ No newline at end of file
diff --git a/src/hellocomputer/prompts/visualization.md b/src/hellocomputer/prompts/visualization.md
new file mode 100644
index 0000000..ebd5bff
--- /dev/null
+++ b/src/hellocomputer/prompts/visualization.md
@@ -0,0 +1 @@
+Apologise because this feature is under construction
\ No newline at end of file
diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py
index 7ac6708..54b8349 100644
--- a/src/hellocomputer/routers/analysis.py
+++ b/src/hellocomputer/routers/analysis.py
@@ -1,29 +1,21 @@
 from fastapi import APIRouter
 from fastapi.responses import PlainTextResponse
+from langchain_core.messages import HumanMessage
+from starlette.requests import Request
 
-from hellocomputer.config import StorageEngines
-from hellocomputer.db.sessions import SessionDB
-from hellocomputer.extraction import extract_code_block
-
-from ..config import settings
-from ..models import Chat
+from hellocomputer.graph import app
 
 router = APIRouter()
 
 
 @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
-async def query(sid: str = "", q: str = "") -> str:
-    llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
-    db = SessionDB(
-        StorageEngines.gcs,
-        gcs_access=settings.gcs_access,
-        gcs_secret=settings.gcs_secret,
-        bucket=settings.gcs_bucketname,
-        sid=sid,
-    ).load_folder()
-
-    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
-    query = extract_code_block(chat.last_response_content())
-    result = str(db.query(query))
-
-    return result
+async def query(request: Request, sid: str = "", q: str = "") -> str:
+    """Send the question ``q`` through the graph and return the reply text.
+
+    ``sid`` is currently unused by this handler.
+    """
+    user = request.session.get("user")  # noqa
+    # ainvoke yields the graph's final state, not a string; the reply text is
+    # the content of the last message in that state.
+    response = await app.ainvoke({"messages": [HumanMessage(content=q)]})
+    return response["messages"][-1].content
diff --git a/src/hellocomputer/routers/auth.py b/src/hellocomputer/routers/auth.py
index c9c986b..026e23d 100644
--- a/src/hellocomputer/routers/auth.py
+++ b/src/hellocomputer/routers/auth.py
@@ -3,8 +3,7 @@ from fastapi import APIRouter
 from
fastapi.responses import HTMLResponse, RedirectResponse from starlette.requests import Request -from hellocomputer.config import settings -from hellocomputer.config import StorageEngines +from hellocomputer.config import StorageEngines, settings from hellocomputer.db.users import UserDB router = APIRouter() diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index 6ae2f12..517746f 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -5,8 +5,7 @@ from fastapi import APIRouter, File, UploadFile from fastapi.responses import JSONResponse from starlette.requests import Request -from ..config import settings -from ..config import StorageEngines +from ..config import StorageEngines, settings from ..db.sessions import SessionDB from ..db.users import OwnershipDB diff --git a/src/hellocomputer/tools.py b/src/hellocomputer/tools.py index cf8a4a2..b031d7f 100644 --- a/src/hellocomputer/tools.py +++ b/src/hellocomputer/tools.py @@ -1,8 +1,10 @@ -from pydantic import BaseModel, Field from typing import Type + from langchain.tools import BaseTool -from hellocomputer.db.sessions import SessionDB +from pydantic import BaseModel, Field + from hellocomputer.config import settings +from hellocomputer.db.sessions import SessionDB class DuckdbQueryInput(BaseModel): diff --git a/test/test_data.py b/test/test_data.py index bda3c72..022b00c 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,7 +1,7 @@ from pathlib import Path import hellocomputer -from hellocomputer.config import StorageEngines, Settings +from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.sessions import SessionDB settings = Settings( diff --git a/test/test_query.py b/test/test_query.py index 0e59151..7820fc0 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,7 +1,6 @@ from pathlib import Path import hellocomputer -import polars as pl import pytest from hellocomputer.config import Settings, 
StorageEngines from hellocomputer.db.sessions import SessionDB diff --git a/test/test_user.py b/test/test_user.py index 959826d..454cc4a 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -1,7 +1,7 @@ from pathlib import Path import hellocomputer -from hellocomputer.config import StorageEngines, Settings +from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.users import OwnershipDB, UserDB settings = Settings(