2024-05-25 09:16:45 +02:00
|
|
|
import os
|
2024-05-28 22:23:11 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
2024-05-25 09:16:45 +02:00
|
|
|
import hellocomputer
|
2024-05-28 22:23:11 +02:00
|
|
|
import pytest
|
|
|
|
from hellocomputer.analytics import DDB
|
2024-05-23 23:31:00 +02:00
|
|
|
from hellocomputer.config import settings
|
2024-05-25 09:16:45 +02:00
|
|
|
from hellocomputer.extraction import extract_code_block
|
2024-05-28 22:23:11 +02:00
|
|
|
from hellocomputer.models import Chat
|
2024-05-25 09:16:45 +02:00
|
|
|
|
2024-05-28 22:23:11 +02:00
|
|
|
# Location of the Excel fixture shared by the data-query tests. It is
# resolved relative to the installed ``hellocomputer`` package:
# parents[2] of the package's ``__file__`` (presumably the repository root in
# a src/ layout -- verify), then test/data/TestExcelHelloComputer.xlsx.
TEST_XLS_PATH = Path(hellocomputer.__file__).parents[2].joinpath(
    "test", "data", "TestExcelHelloComputer.xlsx"
)
|
2024-05-23 23:31:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_chat_simple():
    """Ask the model to say 'Hello' and check that it answers with that greeting.

    Skipped while ``settings.anyscale_api_key`` still holds the placeholder
    value ``"Awesome API"``, i.e. when no real API key is configured.
    """
    chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
    chat = await chat.eval("You're a helpful assistant", "Say literally 'Hello'")
    # LLM replies are not byte-stable: the model may echo the quotes, append
    # "!" or "." or add whitespace. Strip those decorations so the test checks
    # the greeting itself instead of one model-specific rendering of it.
    answer = chat.last_response_content().strip().strip("'\"!.")
    assert answer == "Hello"
|
2024-05-25 09:16:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_simple_data_query():
    """Check that the model answers a schema-grounded request with a SELECT.

    Loads ``TEST_XLS_PATH`` into a ``DDB`` instance and builds a prompt from:
    the natural-language request, ``db.schema()``, ``db.load_description()``
    and an instruction to return only SQL. The code block extracted from the
    reply must then be a ``SELECT`` statement.

    Skipped while ``settings.anyscale_api_key`` still holds the placeholder
    value ``"Awesome API"``.
    """
    # Natural-language request. Kept separate from the extracted SQL below so
    # one name does not stand for two different things.
    request = "write a query that finds the average score of all students in the current database"

    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    db = DDB().load_xls(TEST_XLS_PATH)

    prompt = os.linesep.join(
        [
            request,
            db.schema(),
            db.load_description(),
            "Return just the SQL statement",
        ]
    )

    chat = await chat.eval("You're an expert sql developer", prompt)
    sql = extract_code_block(chat.last_response_content())

    # SQL keywords are case-insensitive, and fenced code blocks often start
    # with whitespace or a newline. Normalize both before checking the keyword
    # so valid answers such as "\nselect avg(score) ..." are not rejected.
    assert sql.lstrip().upper().startswith("SELECT")
|