File upload and session management work
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-27 22:05:29 +02:00
parent edd64d468b
commit e1ffeef646
9 changed files with 116 additions and 95 deletions

View file

@ -1,7 +1,7 @@
import json
import os
from datetime import datetime
from typing import List
from typing import List, Dict
from uuid import UUID, uuid4
import duckdb
@ -64,7 +64,13 @@ class OwnershipDB(DDB):
elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path / "owners"
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
def set_ownership(
self,
user_email: str,
sid: str,
session_name: str,
record_id: UUID | None = None,
):
now = datetime.now().isoformat()
record_id = uuid4() if record_id is None else record_id
query = f"""
@ -73,6 +79,7 @@ class OwnershipDB(DDB):
SELECT
'{user_email}' as email,
'{sid}' as sid,
'{session_name}' as session_name,
'{now}' as timestamp
)
TO '{self.path_prefix}/{record_id}.csv'"""
@ -85,12 +92,12 @@ class OwnershipDB(DDB):
return sid
def sessions(self, user_email: str) -> List[str]:
def sessions(self, user_email: str) -> List[Dict[str, str]]:
try:
return (
self.db.sql(f"""
SELECT
sid
sid, session_name
FROM
'{self.path_prefix}/*.csv'
WHERE
@ -100,8 +107,7 @@ class OwnershipDB(DDB):
LIMIT 10
""")
.pl()
.to_series()
.to_list()
.to_dicts()
)
# If the table does not exist
except duckdb.duckdb.IOException:

View file

@ -1,27 +1,13 @@
from typing import Literal
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts
async def intent(state: MessagesState):
messages = state["messages"]
query = messages[-1]
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.intent()
chain = prompt | llm | initial_intent_parser
return {"messages": [await chain.ainvoke({"query", query.content})]}
from hellocomputer.nodes import (
intent,
answer_general,
answer_query,
answer_visualization,
)
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualiza
return last_message.content
async def answer_general(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.general()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_query(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.sql()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_visualization(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.visualization()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
workflow = StateGraph(MessagesState)
# Nodes
workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization)
# Edges
workflow.add_edge(START, "intent")
workflow.add_conditional_edges(
"intent",

View file

@ -0,0 +1,61 @@
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState
from hellocomputer.config import settings
from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts
async def intent(state: MessagesState):
messages = state["messages"]
query = messages[-1]
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.intent()
chain = prompt | llm | initial_intent_parser
return {"messages": [await chain.ainvoke({"query", query.content})]}
async def answer_general(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.general()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_query(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.sql()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}
async def answer_visualization(state: MessagesState):
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_small,
temperature=0,
)
prompt = await Prompts.visualization()
chain = prompt | llm
return {"messages": [await chain.ainvoke({})]}

View file

@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from starlette.requests import Request
from ..config import StorageEngines, settings
from ..config import settings
from ..db.sessions import SessionDB
from ..db.users import OwnershipDB
from ..auth import get_user_email
router = APIRouter()
@ -22,7 +23,12 @@ router = APIRouter()
@router.post("/upload", tags=["files"])
async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""):
async def upload_file(
request: Request,
file: UploadFile = File(...),
sid: str = "",
session_name: str = "",
):
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
content = await file.read()
await f.write(content)
@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str =
(
SessionDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
settings,
sid=sid,
)
.load_xls(f.name)
.dump()
)
OwnershipDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
).set_ownersip(request.session.get("user").get("email"), sid)
OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name)
return JSONResponse(
content={"message": "File uploaded successfully"}, status_code=200

View file

@ -1,11 +1,10 @@
from typing import List
from typing import List, Dict
from uuid import uuid4
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from starlette.requests import Request
from hellocomputer.config import StorageEngines
from hellocomputer.db.users import OwnershipDB
from ..auth import get_user_email
@ -30,12 +29,7 @@ async def get_greeting() -> str:
@router.get("/sessions")
async def get_sessions(request: Request) -> List[str]:
async def get_sessions(request: Request) -> List[Dict[str, str]]:
user_email = get_user_email(request)
ownership = OwnershipDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
)
ownership = OwnershipDB(settings)
return ownership.sessions(user_email)

View file

@ -106,6 +106,8 @@
<ul id="userSessions">
</ul>
</div>
<div class="modal-body" id="loadResultDiv">
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
id="sessionCloseButton">Close</button>

View file

@ -155,7 +155,9 @@ document.addEventListener("DOMContentLoaded", function () {
formData.append('file', file);
try {
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
const sid = sessionStorage.getItem('helloComputerSession');
const session_name = document.getElementById('datasetLabel').value;
const response = await fetch(`/upload?sid=${sid}&session_name=${session_name}`, {
method: 'POST',
body: formData
});
@ -179,6 +181,7 @@ document.addEventListener("DOMContentLoaded", function () {
document.addEventListener("DOMContentLoaded", function () {
const sessionsButton = document.getElementById('loadSessionsButton');
const sessions = document.getElementById('userSessions');
const loadResultDiv = document.getElementById('loadResultDiv');
sessionsButton.addEventListener('click', async function fetchSessions() {
try {
@ -191,8 +194,12 @@ document.addEventListener("DOMContentLoaded", function () {
data.forEach(item => {
const listItem = document.createElement('li');
const button = document.createElement('button');
button.textContent = item;
button.addEventListener("click", function () { alert(`You clicked on ${item}`); });
button.textContent = item.session_name;
button.addEventListener("click", function () {
sessionStorage.setItem("helloComputerSession", item.sid);
sessionStorage.setItem("helloComputerSessionLoaded", true);
loadResultDiv.textContent = 'Session loaded';
});
listItem.appendChild(button);
sessions.appendChild(listItem);
});

View file

@ -6,4 +6,6 @@ from langchain.prompts import PromptTemplate
@pytest.mark.asyncio
async def test_get_general_prompt():
general: PromptTemplate = await Prompts.general()
assert general.format(query="whatever").startswith("You're a helpful assistant")
assert general.format(query="whatever").startswith(
"You've been asked to do a task you can't do"
)

View file

@ -29,14 +29,14 @@ def test_user_exists():
def test_assign_owner():
assert (
OwnershipDB(settings).set_ownersip(
"something.something@something", "testsession", "test"
OwnershipDB(settings).set_ownership(
"test@test.com", "sid", "session_name", "record_id"
)
== "testsession"
== "sid"
)
def test_get_sessions():
assert OwnershipDB(settings).sessions("something.something@something") == [
"testsession"
assert OwnershipDB(settings).sessions("test@test.com") == [
{"sid": "sid", "session_name": "session_name"}
]