Better tests and docs

This commit is contained in:
Guillem Borrell 2024-06-16 08:56:45 +02:00
parent d642bb0a39
commit e8c7600ed7
10 changed files with 78 additions and 29 deletions

View file

@ -6,9 +6,11 @@ pydantic-settings
s3fs
aiofiles
duckdb
duckdb-engine
polars
pyarrow
pyjwt[crypto]
python-multipart
authlib
itsdangerous
sqlalchemy

View file

@ -1,8 +1,7 @@
from enum import StrEnum
from sqlalchemy import create_engine, text
from pathlib import Path
import duckdb
class StorageEngines(StrEnum):
local = "Local"
@ -19,11 +18,13 @@ class DDB:
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.engine = create_engine(
"duckdb:///:memory:",
connect_args={
"preload_extensions": ["https", "spatial"],
"config": {"memory_limit": "300mb"},
},
)
self.sheets = tuple()
self.loaded = False
@ -35,12 +36,18 @@ class DDB:
bucket is not None,
)
):
self.db.sql(f"""
with self.engine.connect() as conn:
conn.execute(
text(
f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
"""
)
)
self.path_prefix = f"gcs://{bucket}"
else:
raise ValueError(
@ -55,3 +62,7 @@ class DDB:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
@property
def db(self):
    """Return a raw DBAPI connection obtained from the SQLAlchemy engine.

    Every access calls ``self.engine.raw_connection()`` again; the result
    is not cached on the instance.
    """
    raw_conn = self.engine.raw_connection()
    return raw_conn

View file

@ -1,6 +1,7 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block
from hellocomputer.sessions import SessionDB
@ -13,7 +14,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str:
sid=sid,
).load_folder()
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result = str(db.query(query))
print(result)
return result

View file

@ -149,7 +149,7 @@ class SessionDB(DDB):
)
@property
def schema(self):
def schema(self) -> str:
return os.linesep.join(
[
"The schema of the database is the following:",

View file

@ -37,7 +37,7 @@
<p class="techie-font">
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
and more efficiently than Excel.
than Excel.
</p>
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
</div>

View file

@ -25,9 +25,11 @@
</p>
<a href="#" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-question-circle"></i> How to</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i>
<a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-file-ruled"></i>
File templates</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i>
<a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-info-circle"></i>
About</a>
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
Config</a>

View file

@ -71,9 +71,7 @@ function addAIManualMessage(m) {
chatMessages.prepend(newMessage); // Add new message at the top
}
function addUserMessage() {
const messageContent = textarea.value.trim();
if (messageContent) {
function addUserMessageBlock(messageContent) {
const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
newMessage.textContent = messageContent;
@ -81,8 +79,20 @@ function addUserMessage() {
textarea.value = ''; // Clear the textarea
textarea.style.height = 'auto'; // Reset the textarea height
textarea.style.overflowY = 'hidden';
};
/**
 * Handle a "send" action from the chat box.
 *
 * If the session is flagged as not loaded, the textarea is cleared and the
 * user is told to upload a file or pick a session. Otherwise a non-empty,
 * trimmed message is rendered as a user message and forwarded to the AI.
 */
function addUserMessage() {
    const typedText = textarea.value.trim();
    const sessionLoaded = sessionStorage.getItem("helloComputerSessionLoaded");

    // sessionStorage only stores strings, so the flag is compared as text.
    if (sessionLoaded === 'false') {
        textarea.value = '';
        addAIManualMessage('Please upload a data file or select a session first!');
        return;
    }

    // Nothing to send for an empty or whitespace-only message.
    if (!typedText) {
        return;
    }

    addUserMessageBlock(typedText);
    addAIMessage(typedText);
};
sendButton.addEventListener('click', addUserMessage);
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
try {
const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
sessionStorage.setItem("helloComputerSessionLoaded", false);
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
const data = await response.text();
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
sessionStorage.setItem("helloComputerSessionLoaded", true);
addAIManualMessage('File uploaded and processed!');
} catch (error) {

View file

@ -99,10 +99,10 @@ class OwnershipDB(DDB):
FROM
'{self.path_prefix}/*.csv'
WHERE
email = '{user_email}
email = '{user_email}'
ORDER BY
timestamp ASC
LIMIT 10'
LIMIT 10
""")
.pl()
.to_series()

View file

@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
from hellocomputer.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
SID = "test"
@pytest.mark.asyncio
@ -35,9 +38,28 @@ async def test_simple_data_query():
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
    """Ask the LLM for a query, run it, and check that one row comes back.

    Skipped when the API key is still the "Awesome API" placeholder.
    The session is loaded from TEST_OUTPUT_FOLDER via ``load_folder()``.
    NOTE(review): this presumably relies on an earlier test having written
    the session data into that folder -- confirm the test ordering.
    """
    q = "find the average score of all the students"

    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # Use the shared SID constant so this test and the other session tests
    # in this module always refer to the same session id.
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()

    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
    query = extract_code_block(chat.last_response_content())
    result = db.query(query).pl()

    # An average over the whole table is expected to yield a single row.
    assert result.shape[0] == 1