diff --git a/src/hellocomputer/db/users.py b/src/hellocomputer/db/users.py index 019be70..13232f8 100644 --- a/src/hellocomputer/db/users.py +++ b/src/hellocomputer/db/users.py @@ -1,7 +1,7 @@ import json import os from datetime import datetime -from typing import List +from typing import List, Dict from uuid import UUID, uuid4 import duckdb @@ -64,7 +64,13 @@ class OwnershipDB(DDB): elif settings.storage_engine == StorageEngines.local: self.path_prefix = settings.path / "owners" - def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None): + def set_ownership( + self, + user_email: str, + sid: str, + session_name: str, + record_id: UUID | None = None, + ): now = datetime.now().isoformat() record_id = uuid4() if record_id is None else record_id query = f""" @@ -73,6 +79,7 @@ class OwnershipDB(DDB): SELECT '{user_email}' as email, '{sid}' as sid, + '{session_name}' as session_name, '{now}' as timestamp ) TO '{self.path_prefix}/{record_id}.csv'""" @@ -85,12 +92,12 @@ class OwnershipDB(DDB): return sid - def sessions(self, user_email: str) -> List[str]: + def sessions(self, user_email: str) -> List[Dict[str, str]]: try: return ( self.db.sql(f""" SELECT - sid + sid, session_name FROM '{self.path_prefix}/*.csv' WHERE @@ -100,8 +107,7 @@ class OwnershipDB(DDB): LIMIT 10 """) .pl() - .to_series() - .to_list() + .to_dicts() ) # If the table does not exist except duckdb.duckdb.IOException: diff --git a/src/hellocomputer/graph.py b/src/hellocomputer/graph.py index b0f8daf..fb4d497 100644 --- a/src/hellocomputer/graph.py +++ b/src/hellocomputer/graph.py @@ -1,27 +1,13 @@ from typing import Literal -from langchain_openai import ChatOpenAI from langgraph.graph import END, START, MessagesState, StateGraph -from hellocomputer.config import settings -from hellocomputer.extraction import initial_intent_parser -from hellocomputer.models import AvailableModels -from hellocomputer.prompts import Prompts - - -async def intent(state: MessagesState): - messages = state["messages"] - query = messages[-1] - llm = ChatOpenAI( - base_url=settings.llm_base_url, - api_key=settings.llm_api_key, - model=AvailableModels.llama_small, - temperature=0, - ) - prompt = await Prompts.intent() - chain = prompt | llm | initial_intent_parser - - return {"messages": [await chain.ainvoke({"query", query.content})]} +from hellocomputer.nodes import ( + intent, + answer_general, + answer_query, + answer_visualization, +) def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]: @@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualiza return last_message.content -async def answer_general(state: MessagesState): - llm = ChatOpenAI( - base_url=settings.llm_base_url, - api_key=settings.llm_api_key, - model=AvailableModels.llama_small, - temperature=0, - ) - prompt = await Prompts.general() - chain = prompt | llm - - return {"messages": [await chain.ainvoke({})]} - - -async def answer_query(state: MessagesState): - llm = ChatOpenAI( - base_url=settings.llm_base_url, - api_key=settings.llm_api_key, - model=AvailableModels.llama_small, - temperature=0, - ) - prompt = await Prompts.sql() - chain = prompt | llm - - return {"messages": [await chain.ainvoke({})]} - - -async def answer_visualization(state: MessagesState): - llm = ChatOpenAI( - base_url=settings.llm_base_url, - api_key=settings.llm_api_key, - model=AvailableModels.llama_small, - temperature=0, - ) - prompt = await Prompts.visualization() - chain = prompt | llm - - return {"messages": [await chain.ainvoke({})]} - - workflow = StateGraph(MessagesState) +# Nodes + workflow.add_node("intent", intent) workflow.add_node("answer_general", answer_general) workflow.add_node("answer_query", answer_query) workflow.add_node("answer_visualization", answer_visualization) +# Edges + workflow.add_edge(START, "intent") workflow.add_conditional_edges( "intent", diff --git a/src/hellocomputer/nodes.py b/src/hellocomputer/nodes.py new file mode 100644 index 0000000..9232b09 --- /dev/null +++ b/src/hellocomputer/nodes.py @@ -0,0 +1,61 @@ +from langchain_openai import ChatOpenAI +from langgraph.graph import MessagesState + +from hellocomputer.config import settings +from hellocomputer.extraction import initial_intent_parser +from hellocomputer.models import AvailableModels +from hellocomputer.prompts import Prompts + + +async def intent(state: MessagesState): + messages = state["messages"] + query = messages[-1] + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_small, + temperature=0, + ) + prompt = await Prompts.intent() + chain = prompt | llm | initial_intent_parser + + return {"messages": [await chain.ainvoke({"query", query.content})]} + + +async def answer_general(state: MessagesState): + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_small, + temperature=0, + ) + prompt = await Prompts.general() + chain = prompt | llm + + return {"messages": [await chain.ainvoke({})]} + + +async def answer_query(state: MessagesState): + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_small, + temperature=0, + ) + prompt = await Prompts.sql() + chain = prompt | llm + + return {"messages": [await chain.ainvoke({})]} + + +async def answer_visualization(state: MessagesState): + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_small, + temperature=0, + ) + prompt = await Prompts.visualization() + chain = prompt | llm + + return {"messages": [await chain.ainvoke({})]} diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index 517746f..11da88d 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile from fastapi.responses import JSONResponse from starlette.requests import Request -from ..config import StorageEngines, settings +from ..config import settings from ..db.sessions import SessionDB from ..db.users import OwnershipDB +from ..auth import get_user_email router = APIRouter() @@ -22,7 +23,12 @@ router = APIRouter() @router.post("/upload", tags=["files"]) -async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""): +async def upload_file( + request: Request, + file: UploadFile = File(...), + sid: str = "", + session_name: str = "", +): async with aiofiles.tempfile.NamedTemporaryFile("wb") as f: content = await file.read() await f.write(content) @@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ( SessionDB( - StorageEngines.gcs, - gcs_access=settings.gcs_access, - gcs_secret=settings.gcs_secret, - bucket=settings.gcs_bucketname, + settings, sid=sid, ) .load_xls(f.name) .dump() ) - OwnershipDB( - StorageEngines.gcs, - gcs_access=settings.gcs_access, - gcs_secret=settings.gcs_secret, - bucket=settings.gcs_bucketname, - ).set_ownersip(request.session.get("user").get("email"), sid) + OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name) return JSONResponse( content={"message": "File uploaded successfully"}, status_code=200 diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index 5567f11..198bb97 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -1,11 +1,10 @@ -from typing import List +from typing import List, Dict from uuid import uuid4 from fastapi import APIRouter from fastapi.responses import PlainTextResponse from starlette.requests import Request -from hellocomputer.config import StorageEngines from hellocomputer.db.users import OwnershipDB from ..auth import get_user_email @@ -30,12 +29,7 @@ async def get_greeting() -> str: @router.get("/sessions") -async def get_sessions(request: Request) -> List[str]: +async def get_sessions(request: Request) -> List[Dict[str, str]]: user_email = get_user_email(request) - ownership = OwnershipDB( - StorageEngines.gcs, - gcs_access=settings.gcs_access, - gcs_secret=settings.gcs_secret, - bucket=settings.gcs_bucketname, - ) + ownership = OwnershipDB(settings) return ownership.sessions(user_email) diff --git a/src/hellocomputer/static/index.html b/src/hellocomputer/static/index.html index 1e1c522..6dfcec0 100644 --- a/src/hellocomputer/static/index.html +++ b/src/hellocomputer/static/index.html @@ -106,6 +106,8 @@ +