Better tests and docs

This commit is contained in:
Guillem Borrell 2024-06-16 08:56:45 +02:00
parent d642bb0a39
commit e8c7600ed7
10 changed files with 78 additions and 29 deletions

View file

@ -6,9 +6,11 @@ pydantic-settings
s3fs s3fs
aiofiles aiofiles
duckdb duckdb
duckdb-engine
polars polars
pyarrow pyarrow
pyjwt[crypto] pyjwt[crypto]
python-multipart python-multipart
authlib authlib
itsdangerous itsdangerous
sqlalchemy

View file

@ -1,8 +1,7 @@
from enum import StrEnum from enum import StrEnum
from sqlalchemy import create_engine, text
from pathlib import Path from pathlib import Path
import duckdb
class StorageEngines(StrEnum): class StorageEngines(StrEnum):
local = "Local" local = "Local"
@ -19,11 +18,13 @@ class DDB:
bucket: str | None = None, bucket: str | None = None,
**kwargs, **kwargs,
): ):
self.db = duckdb.connect() self.engine = create_engine(
self.db.install_extension("spatial") "duckdb:///:memory:",
self.db.install_extension("httpfs") connect_args={
self.db.load_extension("spatial") "preload_extensions": ["https", "spatial"],
self.db.load_extension("httpfs") "config": {"memory_limit": "300mb"},
},
)
self.sheets = tuple() self.sheets = tuple()
self.loaded = False self.loaded = False
@ -35,12 +36,18 @@ class DDB:
bucket is not None, bucket is not None,
) )
): ):
self.db.sql(f""" with self.engine.connect() as conn:
conn.execute(
text(
f"""
CREATE SECRET ( CREATE SECRET (
TYPE GCS, TYPE GCS,
KEY_ID '{gcs_access}', KEY_ID '{gcs_access}',
SECRET '{gcs_secret}') SECRET '{gcs_secret}')
""") """
)
)
self.path_prefix = f"gcs://{bucket}" self.path_prefix = f"gcs://{bucket}"
else: else:
raise ValueError( raise ValueError(
@ -55,3 +62,7 @@ class DDB:
raise ValueError( raise ValueError(
"With local storage you need to provide the path keyword argument" "With local storage you need to provide the path keyword argument"
) )
@property
def db(self):
return self.engine.raw_connection()

View file

@ -1,6 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.sessions import SessionDB from hellocomputer.sessions import SessionDB
@ -13,7 +14,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB( db = SessionDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str:
sid=sid, sid=sid,
).load_folder() ).load_folder()
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
result = str(db.query(query)) result = str(db.query(query))
print(result)
return result return result

View file

@ -149,7 +149,7 @@ class SessionDB(DDB):
) )
@property @property
def schema(self): def schema(self) -> str:
return os.linesep.join( return os.linesep.join(
[ [
"The schema of the database is the following:", "The schema of the database is the following:",

View file

@ -37,7 +37,7 @@
<p class="techie-font"> <p class="techie-font">
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
not be as powerful as Excel, but it has an efficient query backend that can process your data faster not be as powerful as Excel, but it has an efficient query backend that can process your data faster
and more efficiently than Excel. than Excel.
</p> </p>
<a href="/"><button class="btn btn-secondary w-100">Back</button></a> <a href="/"><button class="btn btn-secondary w-100">Back</button></a>
</div> </div>

View file

@ -25,9 +25,11 @@
</p> </p>
<a href="#" class="list-group-item list-group-item-action bg-light"><i <a href="#" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-question-circle"></i> How to</a> class="bi bi-question-circle"></i> How to</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i> <a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-file-ruled"></i>
File templates</a> File templates</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i> <a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-info-circle"></i>
About</a> About</a>
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i> <a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
Config</a> Config</a>

View file

@ -71,9 +71,7 @@ function addAIManualMessage(m) {
chatMessages.prepend(newMessage); // Add new message at the top chatMessages.prepend(newMessage); // Add new message at the top
} }
function addUserMessage() { function addUserMessageBlock(messageContent) {
const messageContent = textarea.value.trim();
if (messageContent) {
const newMessage = document.createElement('div'); const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
newMessage.textContent = messageContent; newMessage.textContent = messageContent;
@ -81,8 +79,20 @@ function addUserMessage() {
textarea.value = ''; // Clear the textarea textarea.value = ''; // Clear the textarea
textarea.style.height = 'auto'; // Reset the textarea height textarea.style.height = 'auto'; // Reset the textarea height
textarea.style.overflowY = 'hidden'; textarea.style.overflowY = 'hidden';
};
function addUserMessage() {
const messageContent = textarea.value.trim();
if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') {
textarea.value = '';
addAIManualMessage('Please upload a data file or select a session first!');
}
else {
if (messageContent) {
addUserMessageBlock(messageContent);
addAIMessage(messageContent); addAIMessage(messageContent);
} }
}
}; };
sendButton.addEventListener('click', addUserMessage); sendButton.addEventListener('click', addUserMessage);
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
try { try {
const session_response = await fetch('/new_session'); const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
sessionStorage.setItem("helloComputerSessionLoaded", false);
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
const data = await response.text(); const data = await response.text();
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
sessionStorage.setItem("helloComputerSessionLoaded", true);
addAIManualMessage('File uploaded and processed!'); addAIManualMessage('File uploaded and processed!');
} catch (error) { } catch (error) {

View file

@ -99,10 +99,10 @@ class OwnershipDB(DDB):
FROM FROM
'{self.path_prefix}/*.csv' '{self.path_prefix}/*.csv'
WHERE WHERE
email = '{user_email} email = '{user_email}'
ORDER BY ORDER BY
timestamp ASC timestamp ASC
LIMIT 10' LIMIT 10
""") """)
.pl() .pl()
.to_series() .to_series()

View file

@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
from hellocomputer.sessions import SessionDB from hellocomputer.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = ( TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2] Path(hellocomputer.__file__).parents[2]
/ "test" / "test"
/ "data" / "data"
/ "TestExcelHelloComputer.xlsx" / "TestExcelHelloComputer.xlsx"
) )
SID = "test"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -35,9 +38,28 @@ async def test_simple_data_query():
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB( db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH) ).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT") assert query.startswith("SELECT")
@pytest.mark.asyncio
@pytest.mark.skipif(
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
q = "find the average score of all the sudents"
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result = db.query(query).pl()
assert result.shape[0] == 1