File upload and session management work
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
edd64d468b
commit
e1ffeef646
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
|
@ -64,7 +64,13 @@ class OwnershipDB(DDB):
|
||||||
elif settings.storage_engine == StorageEngines.local:
|
elif settings.storage_engine == StorageEngines.local:
|
||||||
self.path_prefix = settings.path / "owners"
|
self.path_prefix = settings.path / "owners"
|
||||||
|
|
||||||
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
def set_ownership(
|
||||||
|
self,
|
||||||
|
user_email: str,
|
||||||
|
sid: str,
|
||||||
|
session_name: str,
|
||||||
|
record_id: UUID | None = None,
|
||||||
|
):
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
record_id = uuid4() if record_id is None else record_id
|
record_id = uuid4() if record_id is None else record_id
|
||||||
query = f"""
|
query = f"""
|
||||||
|
@ -73,6 +79,7 @@ class OwnershipDB(DDB):
|
||||||
SELECT
|
SELECT
|
||||||
'{user_email}' as email,
|
'{user_email}' as email,
|
||||||
'{sid}' as sid,
|
'{sid}' as sid,
|
||||||
|
'{session_name}' as session_name,
|
||||||
'{now}' as timestamp
|
'{now}' as timestamp
|
||||||
)
|
)
|
||||||
TO '{self.path_prefix}/{record_id}.csv'"""
|
TO '{self.path_prefix}/{record_id}.csv'"""
|
||||||
|
@ -85,12 +92,12 @@ class OwnershipDB(DDB):
|
||||||
|
|
||||||
return sid
|
return sid
|
||||||
|
|
||||||
def sessions(self, user_email: str) -> List[str]:
|
def sessions(self, user_email: str) -> List[Dict[str, str]]:
|
||||||
try:
|
try:
|
||||||
return (
|
return (
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
SELECT
|
SELECT
|
||||||
sid
|
sid, session_name
|
||||||
FROM
|
FROM
|
||||||
'{self.path_prefix}/*.csv'
|
'{self.path_prefix}/*.csv'
|
||||||
WHERE
|
WHERE
|
||||||
|
@ -100,8 +107,7 @@ class OwnershipDB(DDB):
|
||||||
LIMIT 10
|
LIMIT 10
|
||||||
""")
|
""")
|
||||||
.pl()
|
.pl()
|
||||||
.to_series()
|
.to_dicts()
|
||||||
.to_list()
|
|
||||||
)
|
)
|
||||||
# If the table does not exist
|
# If the table does not exist
|
||||||
except duckdb.duckdb.IOException:
|
except duckdb.duckdb.IOException:
|
||||||
|
|
|
@ -1,27 +1,13 @@
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.nodes import (
|
||||||
from hellocomputer.extraction import initial_intent_parser
|
intent,
|
||||||
from hellocomputer.models import AvailableModels
|
answer_general,
|
||||||
from hellocomputer.prompts import Prompts
|
answer_query,
|
||||||
|
answer_visualization,
|
||||||
|
|
||||||
async def intent(state: MessagesState):
|
|
||||||
messages = state["messages"]
|
|
||||||
query = messages[-1]
|
|
||||||
llm = ChatOpenAI(
|
|
||||||
base_url=settings.llm_base_url,
|
|
||||||
api_key=settings.llm_api_key,
|
|
||||||
model=AvailableModels.llama_small,
|
|
||||||
temperature=0,
|
|
||||||
)
|
)
|
||||||
prompt = await Prompts.intent()
|
|
||||||
chain = prompt | llm | initial_intent_parser
|
|
||||||
|
|
||||||
return {"messages": [await chain.ainvoke({"query", query.content})]}
|
|
||||||
|
|
||||||
|
|
||||||
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
|
||||||
|
@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualiza
|
||||||
return last_message.content
|
return last_message.content
|
||||||
|
|
||||||
|
|
||||||
async def answer_general(state: MessagesState):
|
|
||||||
llm = ChatOpenAI(
|
|
||||||
base_url=settings.llm_base_url,
|
|
||||||
api_key=settings.llm_api_key,
|
|
||||||
model=AvailableModels.llama_small,
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
prompt = await Prompts.general()
|
|
||||||
chain = prompt | llm
|
|
||||||
|
|
||||||
return {"messages": [await chain.ainvoke({})]}
|
|
||||||
|
|
||||||
|
|
||||||
async def answer_query(state: MessagesState):
|
|
||||||
llm = ChatOpenAI(
|
|
||||||
base_url=settings.llm_base_url,
|
|
||||||
api_key=settings.llm_api_key,
|
|
||||||
model=AvailableModels.llama_small,
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
prompt = await Prompts.sql()
|
|
||||||
chain = prompt | llm
|
|
||||||
|
|
||||||
return {"messages": [await chain.ainvoke({})]}
|
|
||||||
|
|
||||||
|
|
||||||
async def answer_visualization(state: MessagesState):
|
|
||||||
llm = ChatOpenAI(
|
|
||||||
base_url=settings.llm_base_url,
|
|
||||||
api_key=settings.llm_api_key,
|
|
||||||
model=AvailableModels.llama_small,
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
prompt = await Prompts.visualization()
|
|
||||||
chain = prompt | llm
|
|
||||||
|
|
||||||
return {"messages": [await chain.ainvoke({})]}
|
|
||||||
|
|
||||||
|
|
||||||
workflow = StateGraph(MessagesState)
|
workflow = StateGraph(MessagesState)
|
||||||
|
|
||||||
|
# Nodes
|
||||||
|
|
||||||
workflow.add_node("intent", intent)
|
workflow.add_node("intent", intent)
|
||||||
workflow.add_node("answer_general", answer_general)
|
workflow.add_node("answer_general", answer_general)
|
||||||
workflow.add_node("answer_query", answer_query)
|
workflow.add_node("answer_query", answer_query)
|
||||||
workflow.add_node("answer_visualization", answer_visualization)
|
workflow.add_node("answer_visualization", answer_visualization)
|
||||||
|
|
||||||
|
# Edges
|
||||||
|
|
||||||
workflow.add_edge(START, "intent")
|
workflow.add_edge(START, "intent")
|
||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"intent",
|
"intent",
|
||||||
|
|
61
src/hellocomputer/nodes.py
Normal file
61
src/hellocomputer/nodes.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langgraph.graph import MessagesState
|
||||||
|
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.extraction import initial_intent_parser
|
||||||
|
from hellocomputer.models import AvailableModels
|
||||||
|
from hellocomputer.prompts import Prompts
|
||||||
|
|
||||||
|
|
||||||
|
async def intent(state: MessagesState):
|
||||||
|
messages = state["messages"]
|
||||||
|
query = messages[-1]
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
base_url=settings.llm_base_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.llama_small,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
prompt = await Prompts.intent()
|
||||||
|
chain = prompt | llm | initial_intent_parser
|
||||||
|
|
||||||
|
return {"messages": [await chain.ainvoke({"query", query.content})]}
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_general(state: MessagesState):
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
base_url=settings.llm_base_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.llama_small,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
prompt = await Prompts.general()
|
||||||
|
chain = prompt | llm
|
||||||
|
|
||||||
|
return {"messages": [await chain.ainvoke({})]}
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_query(state: MessagesState):
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
base_url=settings.llm_base_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.llama_small,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
prompt = await Prompts.sql()
|
||||||
|
chain = prompt | llm
|
||||||
|
|
||||||
|
return {"messages": [await chain.ainvoke({})]}
|
||||||
|
|
||||||
|
|
||||||
|
async def answer_visualization(state: MessagesState):
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
base_url=settings.llm_base_url,
|
||||||
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.llama_small,
|
||||||
|
temperature=0,
|
||||||
|
)
|
||||||
|
prompt = await Prompts.visualization()
|
||||||
|
chain = prompt | llm
|
||||||
|
|
||||||
|
return {"messages": [await chain.ainvoke({})]}
|
|
@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from ..config import StorageEngines, settings
|
from ..config import settings
|
||||||
from ..db.sessions import SessionDB
|
from ..db.sessions import SessionDB
|
||||||
from ..db.users import OwnershipDB
|
from ..db.users import OwnershipDB
|
||||||
|
from ..auth import get_user_email
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -22,7 +23,12 @@ router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload", tags=["files"])
|
@router.post("/upload", tags=["files"])
|
||||||
async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""):
|
async def upload_file(
|
||||||
|
request: Request,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
sid: str = "",
|
||||||
|
session_name: str = "",
|
||||||
|
):
|
||||||
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
|
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
|
||||||
content = await file.read()
|
content = await file.read()
|
||||||
await f.write(content)
|
await f.write(content)
|
||||||
|
@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str =
|
||||||
|
|
||||||
(
|
(
|
||||||
SessionDB(
|
SessionDB(
|
||||||
StorageEngines.gcs,
|
settings,
|
||||||
gcs_access=settings.gcs_access,
|
|
||||||
gcs_secret=settings.gcs_secret,
|
|
||||||
bucket=settings.gcs_bucketname,
|
|
||||||
sid=sid,
|
sid=sid,
|
||||||
)
|
)
|
||||||
.load_xls(f.name)
|
.load_xls(f.name)
|
||||||
.dump()
|
.dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
OwnershipDB(
|
OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name)
|
||||||
StorageEngines.gcs,
|
|
||||||
gcs_access=settings.gcs_access,
|
|
||||||
gcs_secret=settings.gcs_secret,
|
|
||||||
bucket=settings.gcs_bucketname,
|
|
||||||
).set_ownersip(request.session.get("user").get("email"), sid)
|
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"message": "File uploaded successfully"}, status_code=200
|
content={"message": "File uploaded successfully"}, status_code=200
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
from typing import List
|
from typing import List, Dict
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.config import StorageEngines
|
|
||||||
from hellocomputer.db.users import OwnershipDB
|
from hellocomputer.db.users import OwnershipDB
|
||||||
|
|
||||||
from ..auth import get_user_email
|
from ..auth import get_user_email
|
||||||
|
@ -30,12 +29,7 @@ async def get_greeting() -> str:
|
||||||
|
|
||||||
|
|
||||||
@router.get("/sessions")
|
@router.get("/sessions")
|
||||||
async def get_sessions(request: Request) -> List[str]:
|
async def get_sessions(request: Request) -> List[Dict[str, str]]:
|
||||||
user_email = get_user_email(request)
|
user_email = get_user_email(request)
|
||||||
ownership = OwnershipDB(
|
ownership = OwnershipDB(settings)
|
||||||
StorageEngines.gcs,
|
|
||||||
gcs_access=settings.gcs_access,
|
|
||||||
gcs_secret=settings.gcs_secret,
|
|
||||||
bucket=settings.gcs_bucketname,
|
|
||||||
)
|
|
||||||
return ownership.sessions(user_email)
|
return ownership.sessions(user_email)
|
||||||
|
|
|
@ -106,6 +106,8 @@
|
||||||
<ul id="userSessions">
|
<ul id="userSessions">
|
||||||
</ul>
|
</ul>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="modal-body" id="loadResultDiv">
|
||||||
|
</div>
|
||||||
<div class="modal-footer">
|
<div class="modal-footer">
|
||||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
|
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
|
||||||
id="sessionCloseButton">Close</button>
|
id="sessionCloseButton">Close</button>
|
||||||
|
|
|
@ -155,7 +155,9 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
formData.append('file', file);
|
formData.append('file', file);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
|
const sid = sessionStorage.getItem('helloComputerSession');
|
||||||
|
const session_name = document.getElementById('datasetLabel').value;
|
||||||
|
const response = await fetch(`/upload?sid=${sid}&session_name=${session_name}`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: formData
|
body: formData
|
||||||
});
|
});
|
||||||
|
@ -179,6 +181,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
document.addEventListener("DOMContentLoaded", function () {
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
const sessionsButton = document.getElementById('loadSessionsButton');
|
const sessionsButton = document.getElementById('loadSessionsButton');
|
||||||
const sessions = document.getElementById('userSessions');
|
const sessions = document.getElementById('userSessions');
|
||||||
|
const loadResultDiv = document.getElementById('loadResultDiv');
|
||||||
|
|
||||||
sessionsButton.addEventListener('click', async function fetchSessions() {
|
sessionsButton.addEventListener('click', async function fetchSessions() {
|
||||||
try {
|
try {
|
||||||
|
@ -191,8 +194,12 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
data.forEach(item => {
|
data.forEach(item => {
|
||||||
const listItem = document.createElement('li');
|
const listItem = document.createElement('li');
|
||||||
const button = document.createElement('button');
|
const button = document.createElement('button');
|
||||||
button.textContent = item;
|
button.textContent = item.session_name;
|
||||||
button.addEventListener("click", function () { alert(`You clicked on ${item}`); });
|
button.addEventListener("click", function () {
|
||||||
|
sessionStorage.setItem("helloComputerSession", item.sid);
|
||||||
|
sessionStorage.setItem("helloComputerSessionLoaded", true);
|
||||||
|
loadResultDiv.textContent = 'Session loaded';
|
||||||
|
});
|
||||||
listItem.appendChild(button);
|
listItem.appendChild(button);
|
||||||
sessions.appendChild(listItem);
|
sessions.appendChild(listItem);
|
||||||
});
|
});
|
||||||
|
|
|
@ -6,4 +6,6 @@ from langchain.prompts import PromptTemplate
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_general_prompt():
|
async def test_get_general_prompt():
|
||||||
general: PromptTemplate = await Prompts.general()
|
general: PromptTemplate = await Prompts.general()
|
||||||
assert general.format(query="whatever").startswith("You're a helpful assistant")
|
assert general.format(query="whatever").startswith(
|
||||||
|
"You've been asked to do a task you can't do"
|
||||||
|
)
|
||||||
|
|
|
@ -29,14 +29,14 @@ def test_user_exists():
|
||||||
|
|
||||||
def test_assign_owner():
|
def test_assign_owner():
|
||||||
assert (
|
assert (
|
||||||
OwnershipDB(settings).set_ownersip(
|
OwnershipDB(settings).set_ownership(
|
||||||
"something.something@something", "testsession", "test"
|
"test@test.com", "sid", "session_name", "record_id"
|
||||||
)
|
)
|
||||||
== "testsession"
|
== "sid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_sessions():
|
def test_get_sessions():
|
||||||
assert OwnershipDB(settings).sessions("something.something@something") == [
|
assert OwnershipDB(settings).sessions("test@test.com") == [
|
||||||
"testsession"
|
{"sid": "sid", "session_name": "session_name"}
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in a new issue