This commit is contained in:
parent
98a713b3c7
commit
495c22e0de
|
@ -19,7 +19,7 @@
|
|||
"\n",
|
||||
"client = openai.OpenAI(\n",
|
||||
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
|
||||
" api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n",
|
||||
" api_key = \"YOUR_API_KEY\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
|
|
|
@ -2,6 +2,7 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
import pytest
|
||||
import polars as pl
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
@ -48,7 +49,7 @@ async def test_simple_data_query():
|
|||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_data_query():
|
||||
q = "find the average score of all the sudents"
|
||||
q = "Find the average score of all the sudents"
|
||||
|
||||
llm = Chat(
|
||||
api_key=settings.llm_api_key,
|
||||
|
@ -60,6 +61,7 @@ async def test_data_query():
|
|||
|
||||
chat = await llm.sql_eval(db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = db.query(query).pl()
|
||||
result: pl.DataFrame = db.query(query).pl()
|
||||
|
||||
assert result.shape[0] == 1
|
||||
assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
||||
|
|
Loading…
Reference in a new issue