Better tests and docs
This commit is contained in:
parent
d642bb0a39
commit
e8c7600ed7
|
@ -6,9 +6,11 @@ pydantic-settings
|
|||
s3fs
|
||||
aiofiles
|
||||
duckdb
|
||||
duckdb-engine
|
||||
polars
|
||||
pyarrow
|
||||
pyjwt[crypto]
|
||||
python-multipart
|
||||
authlib
|
||||
itsdangerous
|
||||
itsdangerous
|
||||
sqlalchemy
|
|
@ -1,8 +1,7 @@
|
|||
from enum import StrEnum
|
||||
from sqlalchemy import create_engine, text
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
local = "Local"
|
||||
|
@ -19,11 +18,13 @@ class DDB:
|
|||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.db = duckdb.connect()
|
||||
self.db.install_extension("spatial")
|
||||
self.db.install_extension("httpfs")
|
||||
self.db.load_extension("spatial")
|
||||
self.db.load_extension("httpfs")
|
||||
self.engine = create_engine(
|
||||
"duckdb:///:memory:",
|
||||
connect_args={
|
||||
"preload_extensions": ["https", "spatial"],
|
||||
"config": {"memory_limit": "300mb"},
|
||||
},
|
||||
)
|
||||
self.sheets = tuple()
|
||||
self.loaded = False
|
||||
|
||||
|
@ -35,12 +36,18 @@ class DDB:
|
|||
bucket is not None,
|
||||
)
|
||||
):
|
||||
self.db.sql(f"""
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE SECRET (
|
||||
TYPE GCS,
|
||||
KEY_ID '{gcs_access}',
|
||||
SECRET '{gcs_secret}')
|
||||
""")
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
self.path_prefix = f"gcs://{bucket}"
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -55,3 +62,7 @@ class DDB:
|
|||
raise ValueError(
|
||||
"With local storage you need to provide the path keyword argument"
|
||||
)
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.engine.raw_connection()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.sessions import SessionDB
|
||||
|
@ -13,7 +14,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||
async def query(sid: str = "", q: str = "") -> str:
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
|
@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str:
|
|||
sid=sid,
|
||||
).load_folder()
|
||||
|
||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
|
||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = str(db.query(query))
|
||||
print(result)
|
||||
|
||||
return result
|
||||
|
|
|
@ -149,7 +149,7 @@ class SessionDB(DDB):
|
|||
)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
def schema(self) -> str:
|
||||
return os.linesep.join(
|
||||
[
|
||||
"The schema of the database is the following:",
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
<p class="techie-font">
|
||||
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
|
||||
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
|
||||
and more efficiently than Excel.
|
||||
than Excel.
|
||||
</p>
|
||||
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
|
||||
</div>
|
||||
|
|
|
@ -25,9 +25,11 @@
|
|||
</p>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-question-circle"></i> How to</a>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i>
|
||||
<a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-file-ruled"></i>
|
||||
File templates</a>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i>
|
||||
<a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-info-circle"></i>
|
||||
About</a>
|
||||
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
|
||||
Config</a>
|
||||
|
|
|
@ -71,17 +71,27 @@ function addAIManualMessage(m) {
|
|||
chatMessages.prepend(newMessage); // Add new message at the top
|
||||
}
|
||||
|
||||
function addUserMessageBlock(messageContent) {
|
||||
const newMessage = document.createElement('div');
|
||||
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
|
||||
newMessage.textContent = messageContent;
|
||||
chatMessages.prepend(newMessage); // Add new message at the top
|
||||
textarea.value = ''; // Clear the textarea
|
||||
textarea.style.height = 'auto'; // Reset the textarea height
|
||||
textarea.style.overflowY = 'hidden';
|
||||
};
|
||||
|
||||
function addUserMessage() {
|
||||
const messageContent = textarea.value.trim();
|
||||
if (messageContent) {
|
||||
const newMessage = document.createElement('div');
|
||||
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
|
||||
newMessage.textContent = messageContent;
|
||||
chatMessages.prepend(newMessage); // Add new message at the top
|
||||
textarea.value = ''; // Clear the textarea
|
||||
textarea.style.height = 'auto'; // Reset the textarea height
|
||||
textarea.style.overflowY = 'hidden';
|
||||
addAIMessage(messageContent);
|
||||
if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') {
|
||||
textarea.value = '';
|
||||
addAIManualMessage('Please upload a data file or select a session first!');
|
||||
}
|
||||
else {
|
||||
if (messageContent) {
|
||||
addUserMessageBlock(messageContent);
|
||||
addAIMessage(messageContent);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
try {
|
||||
const session_response = await fetch('/new_session');
|
||||
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
|
||||
sessionStorage.setItem("helloComputerSessionLoaded", false);
|
||||
|
||||
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
|
||||
|
||||
|
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
|
||||
const data = await response.text();
|
||||
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
|
||||
sessionStorage.setItem("helloComputerSessionLoaded", true);
|
||||
|
||||
addAIManualMessage('File uploaded and processed!');
|
||||
} catch (error) {
|
||||
|
|
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
Binary file not shown.
|
@ -99,10 +99,10 @@ class OwnershipDB(DDB):
|
|||
FROM
|
||||
'{self.path_prefix}/*.csv'
|
||||
WHERE
|
||||
email = '{user_email}
|
||||
email = '{user_email}'
|
||||
ORDER BY
|
||||
timestamp ASC
|
||||
LIMIT 10'
|
||||
LIMIT 10
|
||||
""")
|
||||
.pl()
|
||||
.to_series()
|
||||
|
|
|
@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block
|
|||
from hellocomputer.models import Chat
|
||||
from hellocomputer.sessions import SessionDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
TEST_XLS_PATH = (
|
||||
Path(hellocomputer.__file__).parents[2]
|
||||
/ "test"
|
||||
/ "data"
|
||||
/ "TestExcelHelloComputer.xlsx"
|
||||
)
|
||||
SID = "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -35,9 +38,28 @@ async def test_simple_data_query():
|
|||
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
||||
).load_xls(TEST_XLS_PATH)
|
||||
|
||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
assert query.startswith("SELECT")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
|
||||
)
|
||||
async def test_data_query():
|
||||
q = "find the average score of all the sudents"
|
||||
|
||||
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
|
||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = db.query(query).pl()
|
||||
|
||||
assert result.shape[0] == 1
|
||||
|
|
Loading…
Reference in a new issue