Better tests and docs
This commit is contained in:
		
							parent
							
								
									d642bb0a39
								
							
						
					
					
						commit
						e8c7600ed7
					
				|  | @ -6,9 +6,11 @@ pydantic-settings | ||||||
| s3fs | s3fs | ||||||
| aiofiles | aiofiles | ||||||
| duckdb | duckdb | ||||||
|  | duckdb-engine | ||||||
| polars | polars | ||||||
| pyarrow | pyarrow | ||||||
| pyjwt[crypto] | pyjwt[crypto] | ||||||
| python-multipart | python-multipart | ||||||
| authlib | authlib | ||||||
| itsdangerous | itsdangerous | ||||||
|  | sqlalchemy | ||||||
|  | @ -1,8 +1,7 @@ | ||||||
| from enum import StrEnum | from enum import StrEnum | ||||||
|  | from sqlalchemy import create_engine, text | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| import duckdb |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class StorageEngines(StrEnum): | class StorageEngines(StrEnum): | ||||||
|     local = "Local" |     local = "Local" | ||||||
|  | @ -19,11 +18,13 @@ class DDB: | ||||||
|         bucket: str | None = None, |         bucket: str | None = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         self.db = duckdb.connect() |         self.engine = create_engine( | ||||||
|         self.db.install_extension("spatial") |             "duckdb:///:memory:", | ||||||
|         self.db.install_extension("httpfs") |             connect_args={ | ||||||
|         self.db.load_extension("spatial") |                 "preload_extensions": ["https", "spatial"], | ||||||
|         self.db.load_extension("httpfs") |                 "config": {"memory_limit": "300mb"}, | ||||||
|  |             }, | ||||||
|  |         ) | ||||||
|         self.sheets = tuple() |         self.sheets = tuple() | ||||||
|         self.loaded = False |         self.loaded = False | ||||||
| 
 | 
 | ||||||
|  | @ -35,12 +36,18 @@ class DDB: | ||||||
|                     bucket is not None, |                     bucket is not None, | ||||||
|                 ) |                 ) | ||||||
|             ): |             ): | ||||||
|                 self.db.sql(f""" |                 with self.engine.connect() as conn: | ||||||
|  |                     conn.execute( | ||||||
|  |                         text( | ||||||
|  |                             f""" | ||||||
|                     CREATE SECRET ( |                     CREATE SECRET ( | ||||||
|                     TYPE GCS, |                     TYPE GCS, | ||||||
|                     KEY_ID '{gcs_access}', |                     KEY_ID '{gcs_access}', | ||||||
|                     SECRET '{gcs_secret}') |                     SECRET '{gcs_secret}') | ||||||
|                     """) |                     """ | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|  | 
 | ||||||
|                 self.path_prefix = f"gcs://{bucket}" |                 self.path_prefix = f"gcs://{bucket}" | ||||||
|             else: |             else: | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|  | @ -55,3 +62,7 @@ class DDB: | ||||||
|                 raise ValueError( |                 raise ValueError( | ||||||
|                     "With local storage you need to provide the path keyword argument" |                     "With local storage you need to provide the path keyword argument" | ||||||
|                 ) |                 ) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def db(self): | ||||||
|  |         return self.engine.raw_connection() | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| from fastapi import APIRouter | from fastapi import APIRouter | ||||||
| from fastapi.responses import PlainTextResponse | from fastapi.responses import PlainTextResponse | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
| from hellocomputer.db import StorageEngines | from hellocomputer.db import StorageEngines | ||||||
| from hellocomputer.extraction import extract_code_block | from hellocomputer.extraction import extract_code_block | ||||||
| from hellocomputer.sessions import SessionDB | from hellocomputer.sessions import SessionDB | ||||||
|  | @ -13,7 +14,7 @@ router = APIRouter() | ||||||
| 
 | 
 | ||||||
| @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | ||||||
| async def query(sid: str = "", q: str = "") -> str: | async def query(sid: str = "", q: str = "") -> str: | ||||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) |     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||||
|     db = SessionDB( |     db = SessionDB( | ||||||
|         StorageEngines.gcs, |         StorageEngines.gcs, | ||||||
|         gcs_access=settings.gcs_access, |         gcs_access=settings.gcs_access, | ||||||
|  | @ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str: | ||||||
|         sid=sid, |         sid=sid, | ||||||
|     ).load_folder() |     ).load_folder() | ||||||
| 
 | 
 | ||||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) |     chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) | ||||||
|     query = extract_code_block(chat.last_response_content()) |     query = extract_code_block(chat.last_response_content()) | ||||||
|     result = str(db.query(query)) |     result = str(db.query(query)) | ||||||
|     print(result) |  | ||||||
| 
 | 
 | ||||||
|     return result |     return result | ||||||
|  |  | ||||||
|  | @ -149,7 +149,7 @@ class SessionDB(DDB): | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def schema(self): |     def schema(self) -> str: | ||||||
|         return os.linesep.join( |         return os.linesep.join( | ||||||
|             [ |             [ | ||||||
|                 "The schema of the database is the following:", |                 "The schema of the database is the following:", | ||||||
|  |  | ||||||
|  | @ -37,7 +37,7 @@ | ||||||
|             <p class="techie-font"> |             <p class="techie-font"> | ||||||
|                 Hola, computer! is a web assistant that allows you to query excel files using natural language. It may |                 Hola, computer! is a web assistant that allows you to query excel files using natural language. It may | ||||||
|                 not be as powerful as Excel, but it has an efficient query backend that can process your data faster |                 not be as powerful as Excel, but it has an efficient query backend that can process your data faster | ||||||
|                 and more efficiently than Excel. |                 than Excel. | ||||||
|             </p> |             </p> | ||||||
|             <a href="/"><button class="btn btn-secondary w-100">Back</button></a> |             <a href="/"><button class="btn btn-secondary w-100">Back</button></a> | ||||||
|         </div> |         </div> | ||||||
|  |  | ||||||
|  | @ -25,9 +25,11 @@ | ||||||
|                 </p> |                 </p> | ||||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i |                 <a href="#" class="list-group-item list-group-item-action bg-light"><i | ||||||
|                         class="bi bi-question-circle"></i> How to</a> |                         class="bi bi-question-circle"></i> How to</a> | ||||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i> |                 <a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i | ||||||
|  |                         class="bi bi-file-ruled"></i> | ||||||
|                     File templates</a> |                     File templates</a> | ||||||
|                 <a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i> |                 <a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i | ||||||
|  |                         class="bi bi-info-circle"></i> | ||||||
|                     About</a> |                     About</a> | ||||||
|                 <a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i> |                 <a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i> | ||||||
|                     Config</a> |                     Config</a> | ||||||
|  |  | ||||||
|  | @ -71,9 +71,7 @@ function addAIManualMessage(m) { | ||||||
|     chatMessages.prepend(newMessage); // Add new message at the top
 |     chatMessages.prepend(newMessage); // Add new message at the top
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| function addUserMessage() { | function addUserMessageBlock(messageContent) { | ||||||
|     const messageContent = textarea.value.trim(); |  | ||||||
|     if (messageContent) { |  | ||||||
|     const newMessage = document.createElement('div'); |     const newMessage = document.createElement('div'); | ||||||
|     newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); |     newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); | ||||||
|     newMessage.textContent = messageContent; |     newMessage.textContent = messageContent; | ||||||
|  | @ -81,8 +79,20 @@ function addUserMessage() { | ||||||
|     textarea.value = ''; // Clear the textarea
 |     textarea.value = ''; // Clear the textarea
 | ||||||
|     textarea.style.height = 'auto'; // Reset the textarea height
 |     textarea.style.height = 'auto'; // Reset the textarea height
 | ||||||
|     textarea.style.overflowY = 'hidden'; |     textarea.style.overflowY = 'hidden'; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | function addUserMessage() { | ||||||
|  |     const messageContent = textarea.value.trim(); | ||||||
|  |     if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') { | ||||||
|  |         textarea.value = ''; | ||||||
|  |         addAIManualMessage('Please upload a data file or select a session first!'); | ||||||
|  |     } | ||||||
|  |     else { | ||||||
|  |         if (messageContent) { | ||||||
|  |             addUserMessageBlock(messageContent); | ||||||
|             addAIMessage(messageContent); |             addAIMessage(messageContent); | ||||||
|         } |         } | ||||||
|  |     } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| sendButton.addEventListener('click', addUserMessage); | sendButton.addEventListener('click', addUserMessage); | ||||||
|  | @ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () { | ||||||
|         try { |         try { | ||||||
|             const session_response = await fetch('/new_session'); |             const session_response = await fetch('/new_session'); | ||||||
|             sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); |             sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); | ||||||
|  |             sessionStorage.setItem("helloComputerSessionLoaded", false); | ||||||
| 
 | 
 | ||||||
|             const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); |             const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); | ||||||
| 
 | 
 | ||||||
|  | @ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () { | ||||||
| 
 | 
 | ||||||
|             const data = await response.text(); |             const data = await response.text(); | ||||||
|             uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; |             uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; | ||||||
|  |             sessionStorage.setItem("helloComputerSessionLoaded", true); | ||||||
| 
 | 
 | ||||||
|             addAIManualMessage('File uploaded and processed!'); |             addAIManualMessage('File uploaded and processed!'); | ||||||
|         } catch (error) { |         } catch (error) { | ||||||
|  |  | ||||||
							
								
								
									
										
											BIN
										
									
								
								src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							|  | @ -99,10 +99,10 @@ class OwnershipDB(DDB): | ||||||
|             FROM |             FROM | ||||||
|                 '{self.path_prefix}/*.csv' |                 '{self.path_prefix}/*.csv' | ||||||
|             WHERE |             WHERE | ||||||
|                 email = '{user_email} |                 email = '{user_email}' | ||||||
|             ORDER BY |             ORDER BY | ||||||
|                 timestamp ASC |                 timestamp ASC | ||||||
|             LIMIT 10' |             LIMIT 10 | ||||||
|         """) |         """) | ||||||
|                 .pl() |                 .pl() | ||||||
|                 .to_series() |                 .to_series() | ||||||
|  |  | ||||||
|  | @ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block | ||||||
| from hellocomputer.models import Chat | from hellocomputer.models import Chat | ||||||
| from hellocomputer.sessions import SessionDB | from hellocomputer.sessions import SessionDB | ||||||
| 
 | 
 | ||||||
|  | TEST_STORAGE = StorageEngines.local | ||||||
|  | TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" | ||||||
| TEST_XLS_PATH = ( | TEST_XLS_PATH = ( | ||||||
|     Path(hellocomputer.__file__).parents[2] |     Path(hellocomputer.__file__).parents[2] | ||||||
|     / "test" |     / "test" | ||||||
|     / "data" |     / "data" | ||||||
|     / "TestExcelHelloComputer.xlsx" |     / "TestExcelHelloComputer.xlsx" | ||||||
| ) | ) | ||||||
|  | SID = "test" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
|  | @ -35,9 +38,28 @@ async def test_simple_data_query(): | ||||||
| 
 | 
 | ||||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) |     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||||
|     db = SessionDB( |     db = SessionDB( | ||||||
|         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent |         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID | ||||||
|     ).load_xls(TEST_XLS_PATH) |     ).load_xls(TEST_XLS_PATH) | ||||||
| 
 | 
 | ||||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) |     chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) | ||||||
|     query = extract_code_block(chat.last_response_content()) |     query = extract_code_block(chat.last_response_content()) | ||||||
|     assert query.startswith("SELECT") |     assert query.startswith("SELECT") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.skipif( | ||||||
|  |     settings.anyscale_api_key == "Awesome API", reason="API Key not set" | ||||||
|  | ) | ||||||
|  | async def test_data_query(): | ||||||
|  |     q = "find the average score of all the sudents" | ||||||
|  | 
 | ||||||
|  |     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||||
|  |     db = SessionDB( | ||||||
|  |         storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" | ||||||
|  |     ).load_folder() | ||||||
|  | 
 | ||||||
|  |     chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) | ||||||
|  |     query = extract_code_block(chat.last_response_content()) | ||||||
|  |     result = db.query(query).pl() | ||||||
|  | 
 | ||||||
|  |     assert result.shape[0] == 1 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue