diff --git a/notebooks/newtasks.ipynb b/notebooks/newtasks.ipynb index 901dc61..5ca9728 100644 --- a/notebooks/newtasks.ipynb +++ b/notebooks/newtasks.ipynb @@ -19,7 +19,7 @@ "\n", "client = openai.OpenAI(\n", " base_url = \"https://api.fireworks.ai/inference/v1\",\n", - " api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n", + " api_key = \"YOUR_API_KEY\"\n", ")\n", "\n", "messages = [\n", diff --git a/test/test_query.py b/test/test_query.py index 347c596..07d42b3 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -2,6 +2,7 @@ from pathlib import Path import hellocomputer import pytest +import polars as pl from hellocomputer.config import settings from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block @@ -48,7 +49,7 @@ async def test_simple_data_query(): @pytest.mark.asyncio @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") async def test_data_query(): - q = "find the average score of all the sudents" + q = "Find the average score of all the sudents" llm = Chat( api_key=settings.llm_api_key, @@ -60,6 +61,7 @@ async def test_data_query(): chat = await llm.sql_eval(db.query_prompt(q)) query = extract_code_block(chat.last_response_content()) - result = db.query(query).pl() + result: pl.DataFrame = db.query(query).pl() assert result.shape[0] == 1 + assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5