hellocomputer/test/test_query.py

66 lines
2.1 KiB
Python
Raw Normal View History

2024-05-28 22:23:11 +02:00
from pathlib import Path
2024-05-25 09:16:45 +02:00
import hellocomputer
2024-05-28 22:23:11 +02:00
import pytest
2024-05-23 23:31:00 +02:00
from hellocomputer.config import settings
2024-06-11 17:20:32 +02:00
from hellocomputer.db import StorageEngines
2024-05-25 09:16:45 +02:00
from hellocomputer.extraction import extract_code_block
2024-05-28 22:23:11 +02:00
from hellocomputer.models import Chat
2024-06-12 11:04:28 +02:00
from hellocomputer.sessions import SessionDB
2024-05-25 09:16:45 +02:00
2024-06-16 08:56:45 +02:00
# Storage backend used by the tests in this module.
TEST_STORAGE = StorageEngines.local

# ``test`` directory located at ``parents[2]`` of the package's
# ``__init__.py``; presumably the repository root -- confirm against the
# project layout.
_TEST_ROOT = Path(hellocomputer.__file__).parents[2] / "test"

# Folder handed to ``SessionDB.load_folder()`` in test_data_query.
TEST_OUTPUT_FOLDER = _TEST_ROOT / "output"

# Excel workbook fixture loaded by test_simple_data_query.
TEST_XLS_PATH = _TEST_ROOT / "data" / "TestExcelHelloComputer.xlsx"

# Session id passed to ``SessionDB``.
SID = "test"


@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_chat_simple():
    """Check that the chat model follows a trivial instruction.

    Skipped while ``settings.anyscale_api_key`` still holds the placeholder
    value ``"Awesome API"``. ``temperature=0`` is used to reduce variability
    in the model's reply.
    """
    chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
    chat = await chat.eval("You're a helpful assistant", "Say literally 'Hello'")
    # The model may wrap the word in quotes or append punctuation
    # (e.g. "Hello!"), so compare only the bare greeting.
    reply = chat.last_response_content().strip().strip("!.'\"").strip()
    assert reply == "Hello"


@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_simple_data_query():
    """Check that the model turns a natural-language question into SQL.

    Loads the Excel fixture into a ``SessionDB``, sends the prompt built by
    ``db.query_prompt`` to the model, and checks that the extracted code
    block is a ``SELECT`` statement. The SQL is not executed here.
    """
    question = "write a query that finds the average score of all students in the current database"
    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_XLS_PATH.parent, sid=SID
    ).load_xls(TEST_XLS_PATH)

    chat = await chat.eval("You're an expert sql developer", db.query_prompt(question))

    sql = extract_code_block(chat.last_response_content())
    # An extracted block may keep surrounding whitespace/newlines, and models
    # sometimes write SQL keywords in lowercase; both are still SELECTs.
    assert sql.strip().upper().startswith("SELECT")


@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
    """Check that model-generated SQL runs against the session database.

    Loads ``TEST_OUTPUT_FOLDER`` via ``SessionDB.load_folder()``, asks the
    model for a query answering an aggregate question, executes the
    extracted SQL with ``db.query`` and checks that the ``.pl()`` result
    has exactly one row.
    """
    question = "find the average score of all the students"
    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # NOTE(review): this relies on TEST_OUTPUT_FOLDER already holding
    # loadable data (presumably produced elsewhere, e.g. by another test or
    # fixture) -- confirm test ordering / setup.
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()
    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(question))
    sql = extract_code_block(chat.last_response_content())
    result = db.query(sql).pl()
    # An average over all students should collapse to a single row.
    assert result.shape[0] == 1