Better tests and docs
This commit is contained in:
parent
d642bb0a39
commit
e8c7600ed7
|
@ -6,9 +6,11 @@ pydantic-settings
|
||||||
s3fs
|
s3fs
|
||||||
aiofiles
|
aiofiles
|
||||||
duckdb
|
duckdb
|
||||||
|
duckdb-engine
|
||||||
polars
|
polars
|
||||||
pyarrow
|
pyarrow
|
||||||
pyjwt[crypto]
|
pyjwt[crypto]
|
||||||
python-multipart
|
python-multipart
|
||||||
authlib
|
authlib
|
||||||
itsdangerous
|
itsdangerous
|
||||||
|
sqlalchemy
|
|
@ -1,8 +1,7 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from sqlalchemy import create_engine, text
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import duckdb
|
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
class StorageEngines(StrEnum):
|
||||||
local = "Local"
|
local = "Local"
|
||||||
|
@ -19,11 +18,13 @@ class DDB:
|
||||||
bucket: str | None = None,
|
bucket: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.db = duckdb.connect()
|
self.engine = create_engine(
|
||||||
self.db.install_extension("spatial")
|
"duckdb:///:memory:",
|
||||||
self.db.install_extension("httpfs")
|
connect_args={
|
||||||
self.db.load_extension("spatial")
|
"preload_extensions": ["https", "spatial"],
|
||||||
self.db.load_extension("httpfs")
|
"config": {"memory_limit": "300mb"},
|
||||||
|
},
|
||||||
|
)
|
||||||
self.sheets = tuple()
|
self.sheets = tuple()
|
||||||
self.loaded = False
|
self.loaded = False
|
||||||
|
|
||||||
|
@ -35,12 +36,18 @@ class DDB:
|
||||||
bucket is not None,
|
bucket is not None,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
self.db.sql(f"""
|
with self.engine.connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
CREATE SECRET (
|
CREATE SECRET (
|
||||||
TYPE GCS,
|
TYPE GCS,
|
||||||
KEY_ID '{gcs_access}',
|
KEY_ID '{gcs_access}',
|
||||||
SECRET '{gcs_secret}')
|
SECRET '{gcs_secret}')
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.path_prefix = f"gcs://{bucket}"
|
self.path_prefix = f"gcs://{bucket}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -55,3 +62,7 @@ class DDB:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"With local storage you need to provide the path keyword argument"
|
"With local storage you need to provide the path keyword argument"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db(self):
|
||||||
|
return self.engine.raw_connection()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
|
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.sessions import SessionDB
|
from hellocomputer.sessions import SessionDB
|
||||||
|
@ -13,7 +14,7 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
async def query(sid: str = "", q: str = "") -> str:
|
async def query(sid: str = "", q: str = "") -> str:
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = SessionDB(
|
db = SessionDB(
|
||||||
StorageEngines.gcs,
|
StorageEngines.gcs,
|
||||||
gcs_access=settings.gcs_access,
|
gcs_access=settings.gcs_access,
|
||||||
|
@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str:
|
||||||
sid=sid,
|
sid=sid,
|
||||||
).load_folder()
|
).load_folder()
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
|
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
result = str(db.query(query))
|
result = str(db.query(query))
|
||||||
print(result)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -149,7 +149,7 @@ class SessionDB(DDB):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schema(self):
|
def schema(self) -> str:
|
||||||
return os.linesep.join(
|
return os.linesep.join(
|
||||||
[
|
[
|
||||||
"The schema of the database is the following:",
|
"The schema of the database is the following:",
|
||||||
|
|
|
@ -37,7 +37,7 @@
|
||||||
<p class="techie-font">
|
<p class="techie-font">
|
||||||
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
|
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
|
||||||
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
|
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
|
||||||
and more efficiently than Excel.
|
than Excel.
|
||||||
</p>
|
</p>
|
||||||
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
|
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -25,9 +25,11 @@
|
||||||
</p>
|
</p>
|
||||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i
|
<a href="#" class="list-group-item list-group-item-action bg-light"><i
|
||||||
class="bi bi-question-circle"></i> How to</a>
|
class="bi bi-question-circle"></i> How to</a>
|
||||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i>
|
<a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
|
||||||
|
class="bi bi-file-ruled"></i>
|
||||||
File templates</a>
|
File templates</a>
|
||||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i>
|
<a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
|
||||||
|
class="bi bi-info-circle"></i>
|
||||||
About</a>
|
About</a>
|
||||||
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
|
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
|
||||||
Config</a>
|
Config</a>
|
||||||
|
|
|
@ -71,9 +71,7 @@ function addAIManualMessage(m) {
|
||||||
chatMessages.prepend(newMessage); // Add new message at the top
|
chatMessages.prepend(newMessage); // Add new message at the top
|
||||||
}
|
}
|
||||||
|
|
||||||
function addUserMessage() {
|
function addUserMessageBlock(messageContent) {
|
||||||
const messageContent = textarea.value.trim();
|
|
||||||
if (messageContent) {
|
|
||||||
const newMessage = document.createElement('div');
|
const newMessage = document.createElement('div');
|
||||||
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
|
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
|
||||||
newMessage.textContent = messageContent;
|
newMessage.textContent = messageContent;
|
||||||
|
@ -81,8 +79,20 @@ function addUserMessage() {
|
||||||
textarea.value = ''; // Clear the textarea
|
textarea.value = ''; // Clear the textarea
|
||||||
textarea.style.height = 'auto'; // Reset the textarea height
|
textarea.style.height = 'auto'; // Reset the textarea height
|
||||||
textarea.style.overflowY = 'hidden';
|
textarea.style.overflowY = 'hidden';
|
||||||
|
};
|
||||||
|
|
||||||
|
function addUserMessage() {
|
||||||
|
const messageContent = textarea.value.trim();
|
||||||
|
if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') {
|
||||||
|
textarea.value = '';
|
||||||
|
addAIManualMessage('Please upload a data file or select a session first!');
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (messageContent) {
|
||||||
|
addUserMessageBlock(messageContent);
|
||||||
addAIMessage(messageContent);
|
addAIMessage(messageContent);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
sendButton.addEventListener('click', addUserMessage);
|
sendButton.addEventListener('click', addUserMessage);
|
||||||
|
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
try {
|
try {
|
||||||
const session_response = await fetch('/new_session');
|
const session_response = await fetch('/new_session');
|
||||||
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
|
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
|
||||||
|
sessionStorage.setItem("helloComputerSessionLoaded", false);
|
||||||
|
|
||||||
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
|
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
|
||||||
|
|
||||||
|
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
|
||||||
const data = await response.text();
|
const data = await response.text();
|
||||||
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
|
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
|
||||||
|
sessionStorage.setItem("helloComputerSessionLoaded", true);
|
||||||
|
|
||||||
addAIManualMessage('File uploaded and processed!');
|
addAIManualMessage('File uploaded and processed!');
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
Binary file not shown.
|
@ -99,10 +99,10 @@ class OwnershipDB(DDB):
|
||||||
FROM
|
FROM
|
||||||
'{self.path_prefix}/*.csv'
|
'{self.path_prefix}/*.csv'
|
||||||
WHERE
|
WHERE
|
||||||
email = '{user_email}
|
email = '{user_email}'
|
||||||
ORDER BY
|
ORDER BY
|
||||||
timestamp ASC
|
timestamp ASC
|
||||||
LIMIT 10'
|
LIMIT 10
|
||||||
""")
|
""")
|
||||||
.pl()
|
.pl()
|
||||||
.to_series()
|
.to_series()
|
||||||
|
|
|
@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import Chat
|
||||||
from hellocomputer.sessions import SessionDB
|
from hellocomputer.sessions import SessionDB
|
||||||
|
|
||||||
|
TEST_STORAGE = StorageEngines.local
|
||||||
|
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
TEST_XLS_PATH = (
|
TEST_XLS_PATH = (
|
||||||
Path(hellocomputer.__file__).parents[2]
|
Path(hellocomputer.__file__).parents[2]
|
||||||
/ "test"
|
/ "test"
|
||||||
/ "data"
|
/ "data"
|
||||||
/ "TestExcelHelloComputer.xlsx"
|
/ "TestExcelHelloComputer.xlsx"
|
||||||
)
|
)
|
||||||
|
SID = "test"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -35,9 +38,28 @@ async def test_simple_data_query():
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = SessionDB(
|
db = SessionDB(
|
||||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
||||||
).load_xls(TEST_XLS_PATH)
|
).load_xls(TEST_XLS_PATH)
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
assert query.startswith("SELECT")
|
assert query.startswith("SELECT")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
|
||||||
|
)
|
||||||
|
async def test_data_query():
|
||||||
|
q = "find the average score of all the sudents"
|
||||||
|
|
||||||
|
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
|
db = SessionDB(
|
||||||
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
|
).load_folder()
|
||||||
|
|
||||||
|
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||||
|
query = extract_code_block(chat.last_response_content())
|
||||||
|
result = db.query(query).pl()
|
||||||
|
|
||||||
|
assert result.shape[0] == 1
|
||||||
|
|
Loading…
Reference in a new issue