Better tests and docs
This commit is contained in:
		
							parent
							
								
									d642bb0a39
								
							
						
					
					
						commit
						e8c7600ed7
					
				|  | @ -6,9 +6,11 @@ pydantic-settings | |||
| s3fs | ||||
| aiofiles | ||||
| duckdb | ||||
| duckdb-engine | ||||
| polars | ||||
| pyarrow | ||||
| pyjwt[crypto] | ||||
| python-multipart | ||||
| authlib | ||||
| itsdangerous | ||||
| itsdangerous | ||||
| sqlalchemy | ||||
|  | @ -1,8 +1,7 @@ | |||
| from enum import StrEnum | ||||
| from sqlalchemy import create_engine, text | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import duckdb | ||||
| 
 | ||||
| 
 | ||||
| class StorageEngines(StrEnum): | ||||
|     local = "Local" | ||||
|  | @ -19,11 +18,13 @@ class DDB: | |||
|         bucket: str | None = None, | ||||
|         **kwargs, | ||||
|     ): | ||||
|         self.db = duckdb.connect() | ||||
|         self.db.install_extension("spatial") | ||||
|         self.db.install_extension("httpfs") | ||||
|         self.db.load_extension("spatial") | ||||
|         self.db.load_extension("httpfs") | ||||
|         self.engine = create_engine( | ||||
|             "duckdb:///:memory:", | ||||
|             connect_args={ | ||||
|                 "preload_extensions": ["https", "spatial"], | ||||
|                 "config": {"memory_limit": "300mb"}, | ||||
|             }, | ||||
|         ) | ||||
|         self.sheets = tuple() | ||||
|         self.loaded = False | ||||
| 
 | ||||
|  | @ -35,12 +36,18 @@ class DDB: | |||
|                     bucket is not None, | ||||
|                 ) | ||||
|             ): | ||||
|                 self.db.sql(f""" | ||||
|                 with self.engine.connect() as conn: | ||||
|                     conn.execute( | ||||
|                         text( | ||||
|                             f""" | ||||
|                     CREATE SECRET ( | ||||
|                     TYPE GCS, | ||||
|                     KEY_ID '{gcs_access}', | ||||
|                     SECRET '{gcs_secret}') | ||||
|                     """) | ||||
|                     """ | ||||
|                         ) | ||||
|                     ) | ||||
| 
 | ||||
|                 self.path_prefix = f"gcs://{bucket}" | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|  | @ -55,3 +62,7 @@ class DDB: | |||
|                 raise ValueError( | ||||
|                     "With local storage you need to provide the path keyword argument" | ||||
|                 ) | ||||
| 
 | ||||
|     @property | ||||
|     def db(self): | ||||
|         return self.engine.raw_connection() | ||||
|  |  | |||
|  | @ -1,6 +1,7 @@ | |||
| from fastapi import APIRouter | ||||
| from fastapi.responses import PlainTextResponse | ||||
| 
 | ||||
| 
 | ||||
| from hellocomputer.db import StorageEngines | ||||
| from hellocomputer.extraction import extract_code_block | ||||
| from hellocomputer.sessions import SessionDB | ||||
|  | @ -13,7 +14,7 @@ router = APIRouter() | |||
| 
 | ||||
| @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | ||||
| async def query(sid: str = "", q: str = "") -> str: | ||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     db = SessionDB( | ||||
|         StorageEngines.gcs, | ||||
|         gcs_access=settings.gcs_access, | ||||
|  | @ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str: | |||
|         sid=sid, | ||||
|     ).load_folder() | ||||
| 
 | ||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) | ||||
|     chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) | ||||
|     query = extract_code_block(chat.last_response_content()) | ||||
|     result = str(db.query(query)) | ||||
|     print(result) | ||||
| 
 | ||||
|     return result | ||||
|  |  | |||
|  | @ -149,7 +149,7 @@ class SessionDB(DDB): | |||
|         ) | ||||
| 
 | ||||
|     @property | ||||
|     def schema(self): | ||||
|     def schema(self) -> str: | ||||
|         return os.linesep.join( | ||||
|             [ | ||||
|                 "The schema of the database is the following:", | ||||
|  |  | |||
|  | @ -37,7 +37,7 @@ | |||
|             <p class="techie-font"> | ||||
|                 Hola, computer! is a web assistant that allows you to query excel files using natural language. It may | ||||
|                 not be as powerful as Excel, but it has an efficient query backend that can process your data faster | ||||
|                 and more efficiently than Excel. | ||||
|                 than Excel. | ||||
|             </p> | ||||
|             <a href="/"><button class="btn btn-secondary w-100">Back</button></a> | ||||
|         </div> | ||||
|  |  | |||
|  | @ -25,9 +25,11 @@ | |||
|                 </p> | ||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i | ||||
|                         class="bi bi-question-circle"></i> How to</a> | ||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i> | ||||
|                 <a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i | ||||
|                         class="bi bi-file-ruled"></i> | ||||
|                     File templates</a> | ||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i> | ||||
|                 <a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i | ||||
|                         class="bi bi-info-circle"></i> | ||||
|                     About</a> | ||||
|                 <a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i> | ||||
|                     Config</a> | ||||
|  |  | |||
|  | @ -71,17 +71,27 @@ function addAIManualMessage(m) { | |||
|     chatMessages.prepend(newMessage); // Add new message at the top
 | ||||
| } | ||||
| 
 | ||||
| function addUserMessageBlock(messageContent) { | ||||
|     const newMessage = document.createElement('div'); | ||||
|     newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); | ||||
|     newMessage.textContent = messageContent; | ||||
|     chatMessages.prepend(newMessage); // Add new message at the top
 | ||||
|     textarea.value = ''; // Clear the textarea
 | ||||
|     textarea.style.height = 'auto'; // Reset the textarea height
 | ||||
|     textarea.style.overflowY = 'hidden'; | ||||
| }; | ||||
| 
 | ||||
| function addUserMessage() { | ||||
|     const messageContent = textarea.value.trim(); | ||||
|     if (messageContent) { | ||||
|         const newMessage = document.createElement('div'); | ||||
|         newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); | ||||
|         newMessage.textContent = messageContent; | ||||
|         chatMessages.prepend(newMessage); // Add new message at the top
 | ||||
|         textarea.value = ''; // Clear the textarea
 | ||||
|         textarea.style.height = 'auto'; // Reset the textarea height
 | ||||
|         textarea.style.overflowY = 'hidden'; | ||||
|         addAIMessage(messageContent); | ||||
|     if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') { | ||||
|         textarea.value = ''; | ||||
|         addAIManualMessage('Please upload a data file or select a session first!'); | ||||
|     } | ||||
|     else { | ||||
|         if (messageContent) { | ||||
|             addUserMessageBlock(messageContent); | ||||
|             addAIMessage(messageContent); | ||||
|         } | ||||
|     } | ||||
| }; | ||||
| 
 | ||||
|  | @ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () { | |||
|         try { | ||||
|             const session_response = await fetch('/new_session'); | ||||
|             sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); | ||||
|             sessionStorage.setItem("helloComputerSessionLoaded", false); | ||||
| 
 | ||||
|             const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); | ||||
| 
 | ||||
|  | @ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () { | |||
| 
 | ||||
|             const data = await response.text(); | ||||
|             uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; | ||||
|             sessionStorage.setItem("helloComputerSessionLoaded", true); | ||||
| 
 | ||||
|             addAIManualMessage('File uploaded and processed!'); | ||||
|         } catch (error) { | ||||
|  |  | |||
							
								
								
									
										
											BIN
										
									
								
								src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							|  | @ -99,10 +99,10 @@ class OwnershipDB(DDB): | |||
|             FROM | ||||
|                 '{self.path_prefix}/*.csv' | ||||
|             WHERE | ||||
|                 email = '{user_email} | ||||
|                 email = '{user_email}' | ||||
|             ORDER BY | ||||
|                 timestamp ASC | ||||
|             LIMIT 10' | ||||
|             LIMIT 10 | ||||
|         """) | ||||
|                 .pl() | ||||
|                 .to_series() | ||||
|  |  | |||
|  | @ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block | |||
| from hellocomputer.models import Chat | ||||
| from hellocomputer.sessions import SessionDB | ||||
| 
 | ||||
| TEST_STORAGE = StorageEngines.local | ||||
| TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" | ||||
| TEST_XLS_PATH = ( | ||||
|     Path(hellocomputer.__file__).parents[2] | ||||
|     / "test" | ||||
|     / "data" | ||||
|     / "TestExcelHelloComputer.xlsx" | ||||
| ) | ||||
| SID = "test" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
|  | @ -35,9 +38,28 @@ async def test_simple_data_query(): | |||
| 
 | ||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     db = SessionDB( | ||||
|         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent | ||||
|         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID | ||||
|     ).load_xls(TEST_XLS_PATH) | ||||
| 
 | ||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) | ||||
|     query = extract_code_block(chat.last_response_content()) | ||||
|     assert query.startswith("SELECT") | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.skipif( | ||||
|     settings.anyscale_api_key == "Awesome API", reason="API Key not set" | ||||
| ) | ||||
| async def test_data_query(): | ||||
|     q = "find the average score of all the sudents" | ||||
| 
 | ||||
|     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     db = SessionDB( | ||||
|         storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" | ||||
|     ).load_folder() | ||||
| 
 | ||||
|     chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) | ||||
|     query = extract_code_block(chat.last_response_content()) | ||||
|     result = db.query(query).pl() | ||||
| 
 | ||||
|     assert result.shape[0] == 1 | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue