This commit is contained in:
parent
98a713b3c7
commit
495c22e0de
|
@ -19,7 +19,7 @@
|
||||||
"\n",
|
"\n",
|
||||||
"client = openai.OpenAI(\n",
|
"client = openai.OpenAI(\n",
|
||||||
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
|
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
|
||||||
" api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n",
|
" api_key = \"YOUR_API_KEY\"\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"messages = [\n",
|
"messages = [\n",
|
||||||
|
|
|
@ -2,6 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import pytest
|
import pytest
|
||||||
|
import polars as pl
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
@ -48,7 +49,7 @@ async def test_simple_data_query():
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||||
async def test_data_query():
|
async def test_data_query():
|
||||||
q = "find the average score of all the sudents"
|
q = "Find the average score of all the sudents"
|
||||||
|
|
||||||
llm = Chat(
|
llm = Chat(
|
||||||
api_key=settings.llm_api_key,
|
api_key=settings.llm_api_key,
|
||||||
|
@ -60,6 +61,7 @@ async def test_data_query():
|
||||||
|
|
||||||
chat = await llm.sql_eval(db.query_prompt(q))
|
chat = await llm.sql_eval(db.query_prompt(q))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
result = db.query(query).pl()
|
result: pl.DataFrame = db.query(query).pl()
|
||||||
|
|
||||||
assert result.shape[0] == 1
|
assert result.shape[0] == 1
|
||||||
|
assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
||||||
|
|
Loading…
Reference in a new issue