File upload and session management work
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
edd64d468b
commit
e1ffeef646
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import duckdb
|
||||
|
@ -64,7 +64,13 @@ class OwnershipDB(DDB):
|
|||
elif settings.storage_engine == StorageEngines.local:
|
||||
self.path_prefix = settings.path / "owners"
|
||||
|
||||
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
||||
def set_ownership(
|
||||
self,
|
||||
user_email: str,
|
||||
sid: str,
|
||||
session_name: str,
|
||||
record_id: UUID | None = None,
|
||||
):
|
||||
now = datetime.now().isoformat()
|
||||
record_id = uuid4() if record_id is None else record_id
|
||||
query = f"""
|
||||
|
@ -73,6 +79,7 @@ class OwnershipDB(DDB):
|
|||
SELECT
|
||||
'{user_email}' as email,
|
||||
'{sid}' as sid,
|
||||
'{session_name}' as session_name,
|
||||
'{now}' as timestamp
|
||||
)
|
||||
TO '{self.path_prefix}/{record_id}.csv'"""
|
||||
|
@ -85,12 +92,12 @@ class OwnershipDB(DDB):
|
|||
|
||||
return sid
|
||||
|
||||
def sessions(self, user_email: str) -> List[str]:
|
||||
def sessions(self, user_email: str) -> List[Dict[str, str]]:
|
||||
try:
|
||||
return (
|
||||
self.db.sql(f"""
|
||||
SELECT
|
||||
sid
|
||||
sid, session_name
|
||||
FROM
|
||||
'{self.path_prefix}/*.csv'
|
||||
WHERE
|
||||
|
@ -100,8 +107,7 @@ class OwnershipDB(DDB):
|
|||
LIMIT 10
|
||||
""")
|
||||
.pl()
|
||||
.to_series()
|
||||
.to_list()
|
||||
.to_dicts()
|
||||
)
|
||||
# If the table does not exist
|
||||
except duckdb.duckdb.IOException:
|
||||
|
|
|
@ -1,27 +1,13 @@
|
|||
from typing import Literal
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from hellocomputer.models import AvailableModels
|
||||
from hellocomputer.prompts import Prompts
|
||||
|
||||
|
||||
async def intent(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
query = messages[-1]
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.intent()
|
||||
chain = prompt | llm | initial_intent_parser
|
||||
|
||||
return {"messages": [await chain.ainvoke({"query", query.content})]}
|
||||
from hellocomputer.nodes import (
|
||||
intent,
|
||||
answer_general,
|
||||
answer_query,
|
||||
answer_visualization,
|
||||
)
|
||||
|
||||
|
||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||
|
@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualiza
|
|||
return last_message.content
|
||||
|
||||
|
||||
async def answer_general(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.general()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_query(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.sql()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_visualization(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.visualization()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
workflow = StateGraph(MessagesState)
|
||||
|
||||
# Nodes
|
||||
|
||||
workflow.add_node("intent", intent)
|
||||
workflow.add_node("answer_general", answer_general)
|
||||
workflow.add_node("answer_query", answer_query)
|
||||
workflow.add_node("answer_visualization", answer_visualization)
|
||||
|
||||
# Edges
|
||||
|
||||
workflow.add_edge(START, "intent")
|
||||
workflow.add_conditional_edges(
|
||||
"intent",
|
||||
|
|
61
src/hellocomputer/nodes.py
Normal file
61
src/hellocomputer/nodes.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph import MessagesState
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from hellocomputer.models import AvailableModels
|
||||
from hellocomputer.prompts import Prompts
|
||||
|
||||
|
||||
async def intent(state: MessagesState):
|
||||
messages = state["messages"]
|
||||
query = messages[-1]
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.intent()
|
||||
chain = prompt | llm | initial_intent_parser
|
||||
|
||||
return {"messages": [await chain.ainvoke({"query", query.content})]}
|
||||
|
||||
|
||||
async def answer_general(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.general()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_query(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.sql()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
||||
|
||||
|
||||
async def answer_visualization(state: MessagesState):
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.visualization()
|
||||
chain = prompt | llm
|
||||
|
||||
return {"messages": [await chain.ainvoke({})]}
|
|
@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile
|
|||
from fastapi.responses import JSONResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..config import StorageEngines, settings
|
||||
from ..config import settings
|
||||
from ..db.sessions import SessionDB
|
||||
from ..db.users import OwnershipDB
|
||||
from ..auth import get_user_email
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -22,7 +23,12 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.post("/upload", tags=["files"])
|
||||
async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""):
|
||||
async def upload_file(
|
||||
request: Request,
|
||||
file: UploadFile = File(...),
|
||||
sid: str = "",
|
||||
session_name: str = "",
|
||||
):
|
||||
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
|
||||
content = await file.read()
|
||||
await f.write(content)
|
||||
|
@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str =
|
|||
|
||||
(
|
||||
SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
settings,
|
||||
sid=sid,
|
||||
)
|
||||
.load_xls(f.name)
|
||||
.dump()
|
||||
)
|
||||
|
||||
OwnershipDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
).set_ownersip(request.session.get("user").get("email"), sid)
|
||||
OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name)
|
||||
|
||||
return JSONResponse(
|
||||
content={"message": "File uploaded successfully"}, status_code=200
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from typing import List
|
||||
from typing import List, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.db.users import OwnershipDB
|
||||
|
||||
from ..auth import get_user_email
|
||||
|
@ -30,12 +29,7 @@ async def get_greeting() -> str:
|
|||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def get_sessions(request: Request) -> List[str]:
|
||||
async def get_sessions(request: Request) -> List[Dict[str, str]]:
|
||||
user_email = get_user_email(request)
|
||||
ownership = OwnershipDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
)
|
||||
ownership = OwnershipDB(settings)
|
||||
return ownership.sessions(user_email)
|
||||
|
|
|
@ -106,6 +106,8 @@
|
|||
<ul id="userSessions">
|
||||
</ul>
|
||||
</div>
|
||||
<div class="modal-body" id="loadResultDiv">
|
||||
</div>
|
||||
<div class="modal-footer">
|
||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
|
||||
id="sessionCloseButton">Close</button>
|
||||
|
|
|
@ -155,7 +155,9 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
formData.append('file', file);
|
||||
|
||||
try {
|
||||
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
|
||||
const sid = sessionStorage.getItem('helloComputerSession');
|
||||
const session_name = document.getElementById('datasetLabel').value;
|
||||
const response = await fetch(`/upload?sid=${sid}&session_name=${session_name}`, {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
});
|
||||
|
@ -179,6 +181,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
document.addEventListener("DOMContentLoaded", function () {
|
||||
const sessionsButton = document.getElementById('loadSessionsButton');
|
||||
const sessions = document.getElementById('userSessions');
|
||||
const loadResultDiv = document.getElementById('loadResultDiv');
|
||||
|
||||
sessionsButton.addEventListener('click', async function fetchSessions() {
|
||||
try {
|
||||
|
@ -191,8 +194,12 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
data.forEach(item => {
|
||||
const listItem = document.createElement('li');
|
||||
const button = document.createElement('button');
|
||||
button.textContent = item;
|
||||
button.addEventListener("click", function () { alert(`You clicked on ${item}`); });
|
||||
button.textContent = item.session_name;
|
||||
button.addEventListener("click", function () {
|
||||
sessionStorage.setItem("helloComputerSession", item.sid);
|
||||
sessionStorage.setItem("helloComputerSessionLoaded", true);
|
||||
loadResultDiv.textContent = 'Session loaded';
|
||||
});
|
||||
listItem.appendChild(button);
|
||||
sessions.appendChild(listItem);
|
||||
});
|
||||
|
|
|
@ -6,4 +6,6 @@ from langchain.prompts import PromptTemplate
|
|||
@pytest.mark.asyncio
|
||||
async def test_get_general_prompt():
|
||||
general: PromptTemplate = await Prompts.general()
|
||||
assert general.format(query="whatever").startswith("You're a helpful assistant")
|
||||
assert general.format(query="whatever").startswith(
|
||||
"You've been asked to do a task you can't do"
|
||||
)
|
||||
|
|
|
@ -29,14 +29,14 @@ def test_user_exists():
|
|||
|
||||
def test_assign_owner():
|
||||
assert (
|
||||
OwnershipDB(settings).set_ownersip(
|
||||
"something.something@something", "testsession", "test"
|
||||
OwnershipDB(settings).set_ownership(
|
||||
"test@test.com", "sid", "session_name", "record_id"
|
||||
)
|
||||
== "testsession"
|
||||
== "sid"
|
||||
)
|
||||
|
||||
|
||||
def test_get_sessions():
|
||||
assert OwnershipDB(settings).sessions("something.something@something") == [
|
||||
"testsession"
|
||||
assert OwnershipDB(settings).sessions("test@test.com") == [
|
||||
{"sid": "sid", "session_name": "session_name"}
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue