File upload and session management work
	
		
			
	
		
	
	
		
	
		
Some checks failed
ci/woodpecker/push/woodpecker: Pipeline failed

parent edd64d468b
commit e1ffeef646
@@ -1,7 +1,7 @@
 import json
 import os
 from datetime import datetime
-from typing import List
+from typing import List, Dict
 from uuid import UUID, uuid4

 import duckdb
@@ -64,7 +64,13 @@ class OwnershipDB(DDB):
         elif settings.storage_engine == StorageEngines.local:
             self.path_prefix = settings.path / "owners"

-    def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
+    def set_ownership(
+        self,
+        user_email: str,
+        sid: str,
+        session_name: str,
+        record_id: UUID | None = None,
+    ):
         now = datetime.now().isoformat()
         record_id = uuid4() if record_id is None else record_id
         query = f"""
@@ -73,6 +79,7 @@ class OwnershipDB(DDB):
             SELECT
               '{user_email}' as email,
               '{sid}' as sid,
+              '{session_name}' as session_name,
               '{now}' as timestamp
           )
         TO '{self.path_prefix}/{record_id}.csv'"""
@@ -85,12 +92,12 @@ class OwnershipDB(DDB):

         return sid

-    def sessions(self, user_email: str) -> List[str]:
+    def sessions(self, user_email: str) -> List[Dict[str, str]]:
         try:
             return (
                 self.db.sql(f"""
             SELECT
-                sid
+                sid, session_name
             FROM
                 '{self.path_prefix}/*.csv'
             WHERE
@@ -100,8 +107,7 @@ class OwnershipDB(DDB):
             LIMIT 10
         """)
                 .pl()
-                .to_series()
-                .to_list()
+                .to_dicts()
             )
         # If the table does not exist
         except duckdb.duckdb.IOException:
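For orientation, here is a minimal sketch of how the updated OwnershipDB API could be exercised once this change lands. The email address and session label are placeholders, and it assumes settings points at a reachable storage engine; this illustrates the new signature, it is not code from the commit.

from uuid import uuid4

from hellocomputer.config import settings
from hellocomputer.db.users import OwnershipDB

ownership = OwnershipDB(settings)

# set_ownership now records a human-readable session_name next to email, sid and timestamp.
sid = ownership.set_ownership(
    user_email="someone@example.com",  # placeholder address
    sid=str(uuid4()),
    session_name="Quarterly sales",    # placeholder label surfaced in the session picker
)

# sessions() now returns a list of dicts rather than a flat list of sids.
for record in ownership.sessions("someone@example.com"):
    print(record["sid"], record["session_name"])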
@@ -1,27 +1,13 @@
 from typing import Literal

-from langchain_openai import ChatOpenAI
 from langgraph.graph import END, START, MessagesState, StateGraph

-from hellocomputer.config import settings
-from hellocomputer.extraction import initial_intent_parser
-from hellocomputer.models import AvailableModels
-from hellocomputer.prompts import Prompts
-
-
-async def intent(state: MessagesState):
-    messages = state["messages"]
-    query = messages[-1]
-    llm = ChatOpenAI(
-        base_url=settings.llm_base_url,
-        api_key=settings.llm_api_key,
-        model=AvailableModels.llama_small,
-        temperature=0,
-    )
-    prompt = await Prompts.intent()
-    chain = prompt | llm | initial_intent_parser
-
-    return {"messages": [await chain.ainvoke({"query", query.content})]}
+from hellocomputer.nodes import (
+    intent,
+    answer_general,
+    answer_query,
+    answer_visualization,
+)


 def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
@@ -30,52 +16,17 @@ def route_intent(state: MessagesState) -> Literal["general", "query", "visualization"]:
     return last_message.content


-async def answer_general(state: MessagesState):
-    llm = ChatOpenAI(
-        base_url=settings.llm_base_url,
-        api_key=settings.llm_api_key,
-        model=AvailableModels.llama_small,
-        temperature=0,
-    )
-    prompt = await Prompts.general()
-    chain = prompt | llm
-
-    return {"messages": [await chain.ainvoke({})]}
-
-
-async def answer_query(state: MessagesState):
-    llm = ChatOpenAI(
-        base_url=settings.llm_base_url,
-        api_key=settings.llm_api_key,
-        model=AvailableModels.llama_small,
-        temperature=0,
-    )
-    prompt = await Prompts.sql()
-    chain = prompt | llm
-
-    return {"messages": [await chain.ainvoke({})]}
-
-
-async def answer_visualization(state: MessagesState):
-    llm = ChatOpenAI(
-        base_url=settings.llm_base_url,
-        api_key=settings.llm_api_key,
-        model=AvailableModels.llama_small,
-        temperature=0,
-    )
-    prompt = await Prompts.visualization()
-    chain = prompt | llm
-
-    return {"messages": [await chain.ainvoke({})]}
-
-
 workflow = StateGraph(MessagesState)

 # Nodes

 workflow.add_node("intent", intent)
 workflow.add_node("answer_general", answer_general)
 workflow.add_node("answer_query", answer_query)
 workflow.add_node("answer_visualization", answer_visualization)

 # Edges

+workflow.add_edge(START, "intent")
+workflow.add_conditional_edges(
+    "intent",
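The diff above cuts off inside the add_conditional_edges call, so the full wiring is not visible. The sketch below shows one plausible way the graph could be assembled and invoked end to end; the intent-to-node mapping, the END edges, the hellocomputer.graph module path, and the compile/invoke calls are assumptions based on standard LangGraph usage, not lines from this commit.

import asyncio

from langchain_core.messages import HumanMessage
from langgraph.graph import END, START, MessagesState, StateGraph

from hellocomputer.graph import route_intent  # assumed module path for the file above
from hellocomputer.nodes import (
    answer_general,
    answer_query,
    answer_visualization,
    intent,
)

workflow = StateGraph(MessagesState)
workflow.add_node("intent", intent)
workflow.add_node("answer_general", answer_general)
workflow.add_node("answer_query", answer_query)
workflow.add_node("answer_visualization", answer_visualization)

workflow.add_edge(START, "intent")
# Assumed mapping: route_intent returns one of the three labels below.
workflow.add_conditional_edges(
    "intent",
    route_intent,
    {
        "general": "answer_general",
        "query": "answer_query",
        "visualization": "answer_visualization",
    },
)
for node in ("answer_general", "answer_query", "answer_visualization"):
    workflow.add_edge(node, END)

app = workflow.compile()


async def main() -> None:
    state = {"messages": [HumanMessage(content="Plot monthly totals")]}
    result = await app.ainvoke(state)
    print(result["messages"][-1].content)


asyncio.run(main())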
							
								
								
									
src/hellocomputer/nodes.py (new file, 61 lines added)
@@ -0,0 +1,61 @@
+from langchain_openai import ChatOpenAI
+from langgraph.graph import MessagesState
+
+from hellocomputer.config import settings
+from hellocomputer.extraction import initial_intent_parser
+from hellocomputer.models import AvailableModels
+from hellocomputer.prompts import Prompts
+
+
+async def intent(state: MessagesState):
+    messages = state["messages"]
+    query = messages[-1]
+    llm = ChatOpenAI(
+        base_url=settings.llm_base_url,
+        api_key=settings.llm_api_key,
+        model=AvailableModels.llama_small,
+        temperature=0,
+    )
+    prompt = await Prompts.intent()
+    chain = prompt | llm | initial_intent_parser
+
+    return {"messages": [await chain.ainvoke({"query", query.content})]}
+
+
+async def answer_general(state: MessagesState):
+    llm = ChatOpenAI(
+        base_url=settings.llm_base_url,
+        api_key=settings.llm_api_key,
+        model=AvailableModels.llama_small,
+        temperature=0,
+    )
+    prompt = await Prompts.general()
+    chain = prompt | llm
+
+    return {"messages": [await chain.ainvoke({})]}
+
+
+async def answer_query(state: MessagesState):
+    llm = ChatOpenAI(
+        base_url=settings.llm_base_url,
+        api_key=settings.llm_api_key,
+        model=AvailableModels.llama_small,
+        temperature=0,
+    )
+    prompt = await Prompts.sql()
+    chain = prompt | llm
+
+    return {"messages": [await chain.ainvoke({})]}
+
+
+async def answer_visualization(state: MessagesState):
+    llm = ChatOpenAI(
+        base_url=settings.llm_base_url,
+        api_key=settings.llm_api_key,
+        model=AvailableModels.llama_small,
+        temperature=0,
+    )
+    prompt = await Prompts.visualization()
+    chain = prompt | llm
+
+    return {"messages": [await chain.ainvoke({})]}
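Each node is a small async function over MessagesState, so it can be exercised on its own before being wired into the graph. A minimal sketch follows, assuming settings points at a reachable OpenAI-compatible endpoint; note that in intent the call ainvoke({"query", query.content}) passes a set literal, whereas LangChain prompt templates normally take a dict such as {"query": query.content}, so those braces are worth double-checking.

import asyncio

from langchain_core.messages import HumanMessage

from hellocomputer.nodes import intent


async def main() -> None:
    # MessagesState is a TypedDict, so a plain dict with a "messages" key is enough here.
    state = {"messages": [HumanMessage(content="How many rows does the uploaded sheet have?")]}
    update = await intent(state)  # expected shape: {"messages": [<intent label message>]}
    print(update["messages"][-1])


asyncio.run(main())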
@@ -5,9 +5,10 @@ from fastapi import APIRouter, File, UploadFile
 from fastapi.responses import JSONResponse
 from starlette.requests import Request

-from ..config import StorageEngines, settings
+from ..config import settings
 from ..db.sessions import SessionDB
 from ..db.users import OwnershipDB
+from ..auth import get_user_email

 router = APIRouter()

@@ -22,7 +23,12 @@


 @router.post("/upload", tags=["files"])
-async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""):
+async def upload_file(
+    request: Request,
+    file: UploadFile = File(...),
+    sid: str = "",
+    session_name: str = "",
+):
     async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
         content = await file.read()
         await f.write(content)
@@ -30,22 +36,14 @@ async def upload_file(request: Request, file: UploadFile = File(...), sid: str = ""):

         (
             SessionDB(
-                StorageEngines.gcs,
-                gcs_access=settings.gcs_access,
-                gcs_secret=settings.gcs_secret,
-                bucket=settings.gcs_bucketname,
+                settings,
                 sid=sid,
             )
             .load_xls(f.name)
             .dump()
         )

-        OwnershipDB(
-            StorageEngines.gcs,
-            gcs_access=settings.gcs_access,
-            gcs_secret=settings.gcs_secret,
-            bucket=settings.gcs_bucketname,
-        ).set_ownersip(request.session.get("user").get("email"), sid)
+        OwnershipDB(settings).set_ownership(get_user_email(request), sid, session_name)

         return JSONResponse(
             content={"message": "File uploaded successfully"}, status_code=200
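The upload route now takes a session_name query parameter alongside sid and records ownership through the simplified OwnershipDB(settings) constructor. A hedged sketch of calling it from a test client follows; the hellocomputer.main:app import path and the example spreadsheet are assumptions, and the session middleware behind get_user_email has to be satisfied for the request to succeed.

from uuid import uuid4

from fastapi.testclient import TestClient

from hellocomputer.main import app  # assumed location of the FastAPI application

client = TestClient(app)

with open("example.xlsx", "rb") as fh:  # placeholder spreadsheet
    response = client.post(
        "/upload",
        params={"sid": str(uuid4()), "session_name": "Example dataset"},
        files={"file": ("example.xlsx", fh)},
    )

assert response.status_code == 200
print(response.json())  # {"message": "File uploaded successfully"}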
@@ -1,11 +1,10 @@
-from typing import List
+from typing import List, Dict
 from uuid import uuid4

 from fastapi import APIRouter
 from fastapi.responses import PlainTextResponse
 from starlette.requests import Request

-from hellocomputer.config import StorageEngines
 from hellocomputer.db.users import OwnershipDB

 from ..auth import get_user_email
@@ -30,12 +29,7 @@ async def get_greeting() -> str:


 @router.get("/sessions")
-async def get_sessions(request: Request) -> List[str]:
+async def get_sessions(request: Request) -> List[Dict[str, str]]:
     user_email = get_user_email(request)
-    ownership = OwnershipDB(
-        StorageEngines.gcs,
-        gcs_access=settings.gcs_access,
-        gcs_secret=settings.gcs_secret,
-        bucket=settings.gcs_bucketname,
-    )
+    ownership = OwnershipDB(settings)
     return ownership.sessions(user_email)
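With sessions() returning dictionaries, GET /sessions now responds with a list of {sid, session_name} objects instead of bare sid strings, which is what the front-end changes below rely on. A short sketch of a client reading that shape, under the same assumed hellocomputer.main:app import as above:

from fastapi.testclient import TestClient

from hellocomputer.main import app  # assumed location of the FastAPI application

client = TestClient(app)

response = client.get("/sessions")
for session in response.json():
    # each entry looks like {"sid": "...", "session_name": "..."}
    print(session["sid"], session["session_name"])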
@@ -106,6 +106,8 @@
                                 <ul id="userSessions">
                                 </ul>
                             </div>
+                            <div class="modal-body" id="loadResultDiv">
+                            </div>
                             <div class="modal-footer">
                                 <button type="button" class="btn btn-secondary" data-bs-dismiss="modal"
                                     id="sessionCloseButton">Close</button>
@@ -155,7 +155,9 @@ document.addEventListener("DOMContentLoaded", function () {
         formData.append('file', file);

         try {
-            const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
+            const sid = sessionStorage.getItem('helloComputerSession');
+            const session_name = document.getElementById('datasetLabel').value;
+            const response = await fetch(`/upload?sid=${sid}&session_name=${session_name}`, {
                 method: 'POST',
                 body: formData
             });
@@ -179,6 +181,7 @@ document.addEventListener("DOMContentLoaded", function () {
 document.addEventListener("DOMContentLoaded", function () {
     const sessionsButton = document.getElementById('loadSessionsButton');
     const sessions = document.getElementById('userSessions');
+    const loadResultDiv = document.getElementById('loadResultDiv');

     sessionsButton.addEventListener('click', async function fetchSessions() {
         try {
@@ -191,8 +194,12 @@ document.addEventListener("DOMContentLoaded", function () {
             data.forEach(item => {
                 const listItem = document.createElement('li');
                 const button = document.createElement('button');
-                button.textContent = item;
-                button.addEventListener("click", function () { alert(`You clicked on ${item}`); });
+                button.textContent = item.session_name;
+                button.addEventListener("click", function () {
+                    sessionStorage.setItem("helloComputerSession", item.sid);
+                    sessionStorage.setItem("helloComputerSessionLoaded", true);
+                    loadResultDiv.textContent = 'Session loaded';
+                });
                 listItem.appendChild(button);
                 sessions.appendChild(listItem);
             });
@@ -6,4 +6,6 @@ from langchain.prompts import PromptTemplate
 @pytest.mark.asyncio
 async def test_get_general_prompt():
     general: PromptTemplate = await Prompts.general()
-    assert general.format(query="whatever").startswith("You're a helpful assistant")
+    assert general.format(query="whatever").startswith(
+        "You've been asked to do a task you can't do"
+    )
@@ -29,14 +29,14 @@ def test_user_exists():

 def test_assign_owner():
     assert (
-        OwnershipDB(settings).set_ownersip(
-            "something.something@something", "testsession", "test"
+        OwnershipDB(settings).set_ownership(
+            "test@test.com", "sid", "session_name", "record_id"
         )
-        == "testsession"
+        == "sid"
     )


 def test_get_sessions():
-    assert OwnershipDB(settings).sessions("something.something@something") == [
-        "testsession"
+    assert OwnershipDB(settings).sessions("test@test.com") == [
+        {"sid": "sid", "session_name": "session_name"}
     ]