2024-05-28 22:23:11 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
2024-05-25 09:16:45 +02:00
|
|
|
import hellocomputer
|
2024-07-13 11:53:37 +02:00
|
|
|
import polars as pl
|
2024-07-13 22:46:34 +02:00
|
|
|
import pytest
|
2024-05-23 23:31:00 +02:00
|
|
|
from hellocomputer.config import settings
|
2024-06-11 17:20:32 +02:00
|
|
|
from hellocomputer.db import StorageEngines
|
2024-07-13 22:46:34 +02:00
|
|
|
from hellocomputer.db.sessions import SessionDB
|
2024-05-25 09:16:45 +02:00
|
|
|
from hellocomputer.extraction import extract_code_block
|
2024-05-28 22:23:11 +02:00
|
|
|
from hellocomputer.models import Chat
|
2024-05-25 09:16:45 +02:00
|
|
|
|
2024-06-16 08:56:45 +02:00
|
|
|
# Storage backend used by the session databases in these tests.
TEST_STORAGE = StorageEngines.local

# Repository-level "test" directory, resolved from the location of the
# installed ``hellocomputer`` package.
_TEST_DIR = Path(hellocomputer.__file__).parents[2] / "test"

# Output folder under the test directory; loaded as a session folder by the
# data-query test.
TEST_OUTPUT_FOLDER = _TEST_DIR / "output"

# Sample Excel workbook loaded into a SessionDB by the query tests.
TEST_XLS_PATH = _TEST_DIR / "data" / "TestExcelHelloComputer.xlsx"

# Session identifier shared by the tests.
SID = "test"
|
2024-05-23 23:31:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
    """The LLM should echo back a word it is explicitly asked to say.

    Skipped while ``settings.llm_api_key`` still holds the placeholder value
    ``"Awesome API"``.
    """
    # temperature=0 presumably minimizes sampling randomness so the literal
    # echo is as reproducible as the backend allows.
    llm = Chat(api_key=settings.llm_api_key, temperature=0)
    chat = await llm.eval("Say literally 'Hello'")
    assert "Hello" in chat.last_response_content()
|
2024-05-25 09:16:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_simple_data_query():
    """The LLM should answer a data question with a SELECT statement.

    Loads the sample workbook into a session database, builds a prompt from
    it with ``SessionDB.query_prompt``, and checks that the code block
    extracted from the reply is a SELECT query. The generated SQL is not
    executed here.
    """
    question = (
        "write a query that finds the average score of all students "
        "in the current database"
    )
    llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_XLS_PATH.parent, sid=SID
    ).load_xls(TEST_XLS_PATH)

    chat = await llm.sql_eval(db.query_prompt(question))
    sql = extract_code_block(chat.last_response_content())

    # SQL keywords are case-insensitive and the model may emit leading
    # whitespace, so normalize before checking the statement type.
    assert sql.lstrip().upper().startswith("SELECT"), sql
|
2024-06-16 08:56:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_data_query():
    """LLM-generated SQL, run against stored session data, yields the expected average.

    The session is loaded from ``TEST_OUTPUT_FOLDER`` via ``load_folder()``.
    NOTE(review): this presumably relies on that folder having been populated
    beforehand (e.g. by another test); confirm the required ordering.
    """
    question = "Find the average score of all the students"
    llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()

    chat = await llm.sql_eval(db.query_prompt(question))
    query = extract_code_block(chat.last_response_content())
    result: pl.DataFrame = db.query(query).pl()

    assert result.shape[0] == 1
    # Read the aggregate by position: the model may alias the column (e.g.
    # ``AS average_score``) instead of keeping the default ``avg(Score)``
    # name. The average is a float, so compare with a tolerance.
    assert result.to_series(0)[0] == pytest.approx(0.5)
|