File upload and session management work
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-27 22:05:29 +02:00
parent edd64d468b
commit e1ffeef646
9 changed files with 116 additions and 95 deletions

View file

@ -1,7 +1,7 @@
import json import json
import os import os
from datetime import datetime from datetime import datetime
from typing import List from typing import List, Dict
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import duckdb import duckdb
@ -64,7 +64,13 @@ class OwnershipDB(DDB):
elif settings.storage_engine == StorageEngines.local: elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path / "owners" self.path_prefix = settings.path / "owners"
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None): def set_ownership(
self,
user_email: str,
sid: str,
session_name: str,
record_id: UUID | None = None,
):
now = datetime.now().isoformat() now = datetime.now().isoformat()
record_id = uuid4() if record_id is None else record_id record_id = uuid4() if record_id is None else record_id
query = f""" query = f"""
@ -73,6 +79,7 @@ class OwnershipDB(DDB):
SELECT SELECT
'{user_email}' as email, '{user_email}' as email,
'{sid}' as sid, '{sid}' as sid,
'{session_name}' as session_name,
'{now}' as timestamp '{now}' as timestamp
) )
TO '{self.path_prefix}/{record_id}.csv'""" TO '{self.path_prefix}/{record_id}.csv'"""
@ -85,12 +92,12 @@ class OwnershipDB(DDB):
return sid return sid
def sessions(self, user_email: str) -> List[str]: def sessions(self, user_email: str) -> List[Dict[str, str]]:
try: try:
return ( return (
self.db.sql(f""" self.db.sql(f"""
SELECT SELECT
sid sid, session_name
FROM FROM
'{self.path_prefix}/*.csv' '{self.path_prefix}/*.csv'
WHERE WHERE
@ -100,8 +107,7 @@ class OwnershipDB(DDB):
LIMIT 10 LIMIT 10
""") """)
.pl() .pl()
.to_series() .to_dicts()
.to_list()
) )
# If the table does not exist # If the table does not exist
except duckdb.duckdb.IOException: except duckdb.duckdb.IOException:

View file

@ -1,27 +1,13 @@
from typing import Literal from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.config import settings from hellocomputer.nodes import (
from hellocomputer.extraction import initial_intent_parser intent,
from hellocomputer.models import AvailableModels answer_general,
from hellocomputer.prompts import Prompts answer_query,
answer_visualization,
)
async def intent(state: MessagesState):
messages = state["messages"]
query = messages[-1]
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.intent()
chain = prompt | llm | initial_intent_parser
return {"messages": [await chain.ainvoke({"query", query.content})]}
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]: def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualiza
return last_message.content return last_message.content
async def answer_general(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.general()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_query(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.sql()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_visualization(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.visualization()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
workflow = StateGraph(MessagesState) workflow = StateGraph(MessagesState)
# Nodes
workflow.add_node("intent", intent) workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general) workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query) workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization) workflow.add_node("answer_visualization", answer_visualization)
# Edges
workflow.add_edge(START, "intent") workflow.add_edge(START, "intent")
workflow.add_conditional_edges( workflow.add_conditional_edges(
"intent", "intent",

View file

@ -0,0 +1,61 @@
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts
async def intent(state: MessagesState):
messages = state["messages"]
query = messages[-1]
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.intent()
chain = prompt | llm | initial_intent_parser
return {"messages": [await chain.ainvoke({"query", query.content})]}
async def answer_general(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.general()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_query(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.sql()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_visualization(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.visualization()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}

View file

@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.requests import Request from starlette.requests import Request
from ..config import StorageEngines, settings from ..config import settings
from ..db.sessions import SessionDB from ..db.sessions import SessionDB
from ..db.users import OwnershipDB from ..db.users import OwnershipDB
from ..auth import get_user_email
router = APIRouter() router = APIRouter()
@ -22,7 +23,12 @@ router = APIRouter()
@router.post("/upload", tags=["files"]) @router.post("/upload", tags=["files"])
async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""): async def upload_file(
request: Request,
file: UploadFile = File(...),
sid: str = "",
session_name: str = "",
):
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f: async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
content = await file.read() content = await file.read()
await f.write(content) await f.write(content)
@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str =
( (
SessionDB( SessionDB(
StorageEngines.gcs, settings,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid, sid=sid,
) )
.load_xls(f.name) .load_xls(f.name)
.dump() .dump()
) )
OwnershipDB( OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name)
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
).set_ownersip(request.session.get("user").get("email"), sid)
return JSONResponse( return JSONResponse(
content={"message": "File uploaded successfully"}, status_code=200 content={"message": "File uploaded successfully"}, status_code=200

View file

@ -1,11 +1,10 @@
from typing import List from typing import List, Dict
from uuid import uuid4 from uuid import uuid4
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from starlette.requests import Request from starlette.requests import Request
from hellocomputer.config import StorageEngines
from hellocomputer.db.users import OwnershipDB from hellocomputer.db.users import OwnershipDB
from ..auth import get_user_email from ..auth import get_user_email
@ -30,12 +29,7 @@ async def get_greeting() -> str:
@router.get("/sessions") @router.get("/sessions")
async def get_sessions(request: Request) -> List[str]: async def get_sessions(request: Request) -> List[Dict[str, str]]:
user_email = get_user_email(request) user_email = get_user_email(request)
ownership = OwnershipDB( ownership = OwnershipDB(settings)
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
)
return ownership.sessions(user_email) return ownership.sessions(user_email)

View file

@ -106,6 +106,8 @@
<ul id="userSessions"> <ul id="userSessions">
</ul> </ul>
</div> </div>
<div class="modal-body" id="loadResultDiv">
</div>
<div class="modal-footer"> <div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal" <button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
id="sessionCloseButton">Close</button> id="sessionCloseButton">Close</button>

View file

@ -155,7 +155,9 @@ document.addEventListener("DOMContentLoaded", function () {
formData.append('file', file); formData.append('file', file);
try { try {
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), { const sid = sessionStorage.getItem('helloComputerSession');
const session_name = document.getElementById('datasetLabel').value;
const response = await fetch(`/upload?sid=${sid}&session_name=${session_name}`, {
method: 'POST', method: 'POST',
body: formData body: formData
}); });
@ -179,6 +181,7 @@ document.addEventListener("DOMContentLoaded", function () {
document.addEventListener("DOMContentLoaded", function () { document.addEventListener("DOMContentLoaded", function () {
const sessionsButton = document.getElementById('loadSessionsButton'); const sessionsButton = document.getElementById('loadSessionsButton');
const sessions = document.getElementById('userSessions'); const sessions = document.getElementById('userSessions');
const loadResultDiv = document.getElementById('loadResultDiv');
sessionsButton.addEventListener('click', async function fetchSessions() { sessionsButton.addEventListener('click', async function fetchSessions() {
try { try {
@ -191,8 +194,12 @@ document.addEventListener("DOMContentLoaded", function () {
data.forEach(item => { data.forEach(item => {
const listItem = document.createElement('li'); const listItem = document.createElement('li');
const button = document.createElement('button'); const button = document.createElement('button');
button.textContent = item; button.textContent = item.session_name;
button.addEventListener("click", function () { alert(`You clicked on ${item}`); }); button.addEventListener("click", function () {
sessionStorage.setItem("helloComputerSession", item.sid);
sessionStorage.setItem("helloComputerSessionLoaded", true);
loadResultDiv.textContent = 'Session loaded';
});
listItem.appendChild(button); listItem.appendChild(button);
sessions.appendChild(listItem); sessions.appendChild(listItem);
}); });

View file

@ -6,4 +6,6 @@ from langchain.prompts import PromptTemplate
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_general_prompt(): async def test_get_general_prompt():
general: PromptTemplate = await Prompts.general() general: PromptTemplate = await Prompts.general()
assert general.format(query="whatever").startswith("You're a helpful assistant") assert general.format(query="whatever").startswith(
"You've been asked to do a task you can't do"
)

View file

@ -29,14 +29,14 @@ def test_user_exists():
def test_assign_owner(): def test_assign_owner():
assert ( assert (
OwnershipDB(settings).set_ownersip( OwnershipDB(settings).set_ownership(
"something.something@something", "testsession", "test" "test@test.com", "sid", "session_name", "record_id"
) )
== "testsession" == "sid"
) )
def test_get_sessions(): def test_get_sessions():
assert OwnershipDB(settings).sessions("something.something@something") == [ assert OwnershipDB(settings).sessions("test@test.com") == [
"testsession" {"sid": "sid", "session_name": "session_name"}
] ]