diff --git a/requirements.in b/requirements.in index 640ff7d..9521c80 100644 --- a/requirements.in +++ b/requirements.in @@ -6,9 +6,11 @@ pydantic-settings s3fs aiofiles duckdb +duckdb-engine polars pyarrow pyjwt[crypto] python-multipart authlib -itsdangerous \ No newline at end of file +itsdangerous +sqlalchemy \ No newline at end of file diff --git a/src/hellocomputer/db.py b/src/hellocomputer/db.py index 3630616..528ac00 100644 --- a/src/hellocomputer/db.py +++ b/src/hellocomputer/db.py @@ -1,8 +1,7 @@ from enum import StrEnum +from sqlalchemy import create_engine, text from pathlib import Path -import duckdb - class StorageEngines(StrEnum): local = "Local" @@ -19,11 +18,13 @@ class DDB: bucket: str | None = None, **kwargs, ): - self.db = duckdb.connect() - self.db.install_extension("spatial") - self.db.install_extension("httpfs") - self.db.load_extension("spatial") - self.db.load_extension("httpfs") + self.engine = create_engine( + "duckdb:///:memory:", + connect_args={ + "preload_extensions": ["https", "spatial"], + "config": {"memory_limit": "300mb"}, + }, + ) self.sheets = tuple() self.loaded = False @@ -35,12 +36,18 @@ class DDB: bucket is not None, ) ): - self.db.sql(f""" + with self.engine.connect() as conn: + conn.execute( + text( + f""" CREATE SECRET ( TYPE GCS, KEY_ID '{gcs_access}', SECRET '{gcs_secret}') - """) + """ + ) + ) + self.path_prefix = f"gcs://{bucket}" else: raise ValueError( @@ -55,3 +62,7 @@ class DDB: raise ValueError( "With local storage you need to provide the path keyword argument" ) + + @property + def db(self): + return self.engine.raw_connection() diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 13e9331..ab39b98 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse + from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block from hellocomputer.sessions import SessionDB @@ -13,7 +14,7 @@ router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) async def query(sid: str = "", q: str = "") -> str: - chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) db = SessionDB( StorageEngines.gcs, gcs_access=settings.gcs_access, @@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str: sid=sid, ).load_folder() - chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) + chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) query = extract_code_block(chat.last_response_content()) result = str(db.query(query)) - print(result) return result diff --git a/src/hellocomputer/sessions.py b/src/hellocomputer/sessions.py index bb1c0d9..dd96104 100644 --- a/src/hellocomputer/sessions.py +++ b/src/hellocomputer/sessions.py @@ -149,7 +149,7 @@ class SessionDB(DDB): ) @property - def schema(self): + def schema(self) -> str: return os.linesep.join( [ "The schema of the database is the following:", diff --git a/src/hellocomputer/static/about.html b/src/hellocomputer/static/about.html index 04ef787..ae7ccdf 100644 --- a/src/hellocomputer/static/about.html +++ b/src/hellocomputer/static/about.html @@ -37,7 +37,7 @@

Hola, computer! is a web assistant that allows you to query excel files using natural language. It may not be as powerful as Excel, but it has an efficient query backend that can process your data faster - and more efficiently than Excel. + than Excel.

diff --git a/src/hellocomputer/static/index.html b/src/hellocomputer/static/index.html index 30bbe4f..1e1c522 100644 --- a/src/hellocomputer/static/index.html +++ b/src/hellocomputer/static/index.html @@ -25,9 +25,11 @@

How to - + File templates - + About Config diff --git a/src/hellocomputer/static/script.js b/src/hellocomputer/static/script.js index 81f15d7..3c3bc97 100644 --- a/src/hellocomputer/static/script.js +++ b/src/hellocomputer/static/script.js @@ -71,17 +71,27 @@ function addAIManualMessage(m) { chatMessages.prepend(newMessage); // Add new message at the top } +function addUserMessageBlock(messageContent) { + const newMessage = document.createElement('div'); + newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); + newMessage.textContent = messageContent; + chatMessages.prepend(newMessage); // Add new message at the top + textarea.value = ''; // Clear the textarea + textarea.style.height = 'auto'; // Reset the textarea height + textarea.style.overflowY = 'hidden'; +}; + function addUserMessage() { const messageContent = textarea.value.trim(); - if (messageContent) { - const newMessage = document.createElement('div'); - newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); - newMessage.textContent = messageContent; - chatMessages.prepend(newMessage); // Add new message at the top - textarea.value = ''; // Clear the textarea - textarea.style.height = 'auto'; // Reset the textarea height - textarea.style.overflowY = 'hidden'; - addAIMessage(messageContent); + if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') { + textarea.value = ''; + addAIManualMessage('Please upload a data file or select a session first!'); + } + else { + if (messageContent) { + addUserMessageBlock(messageContent); + addAIMessage(messageContent); + } } }; @@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () { try { const session_response = await fetch('/new_session'); sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); + sessionStorage.setItem("helloComputerSessionLoaded", false); const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); @@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () { const data = await response.text(); uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; + sessionStorage.setItem("helloComputerSessionLoaded", true); addAIManualMessage('File uploaded and processed!'); } catch (error) { diff --git a/src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx b/src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx new file mode 100644 index 0000000..5568616 Binary files /dev/null and b/src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx differ diff --git a/src/hellocomputer/users.py b/src/hellocomputer/users.py index 2f65e94..a5cf93d 100644 --- a/src/hellocomputer/users.py +++ b/src/hellocomputer/users.py @@ -99,10 +99,10 @@ class OwnershipDB(DDB): FROM '{self.path_prefix}/*.csv' WHERE - email = '{user_email} + email = '{user_email}' ORDER BY timestamp ASC - LIMIT 10' + LIMIT 10 """) .pl() .to_series() diff --git a/test/test_query.py b/test/test_query.py index d4d56a2..65b306d 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat from hellocomputer.sessions import SessionDB +TEST_STORAGE = StorageEngines.local +TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_XLS_PATH = ( Path(hellocomputer.__file__).parents[2] / "test" / "data" / "TestExcelHelloComputer.xlsx" ) +SID = "test" @pytest.mark.asyncio @@ -35,9 +38,28 @@ async def test_simple_data_query(): chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) db = SessionDB( - storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent + storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID ).load_xls(TEST_XLS_PATH) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) query = extract_code_block(chat.last_response_content()) assert query.startswith("SELECT") + + +@pytest.mark.asyncio +@pytest.mark.skipif( + settings.anyscale_api_key == "Awesome API", reason="API Key not set" +) +async def test_data_query(): + q = "find the average score of all the sudents" + + llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + db = SessionDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() + + chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) + query = extract_code_block(chat.last_response_content()) + result = db.query(query).pl() + + assert result.shape[0] == 1