hellocomputer/test/test_query.py

92 lines
2.7 KiB
Python
Raw Normal View History

2024-05-28 22:23:11 +02:00
from pathlib import Path
2024-05-25 09:16:45 +02:00
import hellocomputer
2024-07-13 11:53:37 +02:00
import polars as pl
import pytest
2024-07-25 00:10:09 +02:00
from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB
2024-05-25 09:16:45 +02:00
from hellocomputer.extraction import extract_code_block
2024-07-25 00:10:09 +02:00
from hellocomputer.models import AvailableModels
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
# Base ``test`` directory, located via the installed ``hellocomputer`` package
# (``parents[2]`` of its ``__file__``; presumably the repository root — confirm
# if the package layout changes).
_TEST_DIR = Path(hellocomputer.__file__).parents[2] / "test"

# Local-filesystem storage under test/output; used by every test in this module.
settings = Settings(
    storage_engine=StorageEngines.local,
    path=_TEST_DIR / "output",
)

# Spreadsheet fixture that the query tests load into a SessionDB.
TEST_XLS_PATH = _TEST_DIR / "data" / "TestExcelHelloComputer.xlsx"

# Session identifier handed to SessionDB.
SID = "test"
2024-05-23 23:31:00 +02:00
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
    """Round-trip a trivial prompt through the configured chat model.

    Asks the model to repeat a single word and checks that its reply begins
    with that word, ignoring case. Skipped while ``settings.llm_api_key``
    still holds the placeholder value ``"Awesome API"``.
    """
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.mixtral_8x7b,
        temperature=0.5,
    )
    prompt = ChatPromptTemplate.from_template(
        """Say literally {word}, a single word. Don't be verbose,
I'll be disappointed if you say more than a single word"""
    )
    chain = prompt | llm
    response = await chain.ainvoke({"word": "Hello"})

    # A chat model may pad a one-word answer with whitespace or wrap it in
    # quotes; normalise that first so the check does not flake on formatting.
    answer = response.content.strip().strip("\"'").lower()
    assert answer.startswith("hello"), f"unexpected reply: {response.content!r}"
2024-06-16 08:56:45 +02:00
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_query_context():
    """Check that a SQL toolkit over the test workbook reports table metadata.

    The spreadsheet at ``TEST_XLS_PATH`` is loaded into a ``SessionDB`` and its
    ``llmsql`` handle is wrapped in a ``SQLDatabaseToolkit``. The toolkit's
    context must expose both ``table_info`` and ``table_names``.
    """
    # NOTE(review): nothing in this body is awaited, so the asyncio marker looks
    # unnecessary — confirm before converting this to a plain synchronous test.
    session = SessionDB(settings, sid=SID)
    sql_db = session.load_xls(TEST_XLS_PATH).llmsql

    llm_options = {
        "base_url": settings.llm_base_url,
        "api_key": settings.llm_api_key,
        "model": AvailableModels.mixtral_8x7b,
        "temperature": 0.5,
    }
    toolkit = SQLDatabaseToolkit(db=sql_db, llm=ChatOpenAI(**llm_options))

    context = toolkit.get_context()
    for expected_key in ("table_info", "table_names"):
        assert expected_key in context
2024-06-16 08:56:45 +02:00
2024-07-25 00:10:09 +02:00
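#
# NOTE(review): the commented-out code below is legacy and will not run if simply
# uncommented. It refers to names not defined in this file (`Chat`,
# `TEST_STORAGE`, `TEST_OUTPUT_FOLDER`), and the first fragment uses `chat`
# and `query` before they are assigned.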
# chat = await chat.sql_eval(db.query_prompt(query))
# query = extract_code_block(chat.last_response_content())
# assert query.startswith("SELECT")
#
#
# @pytest.mark.asyncio
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
# async def test_data_query():
# q = "Find the average score of all the sudents"
#
# llm = Chat(
# api_key=settings.llm_api_key,
# temperature=0.5,
# )
# db = SessionDB(
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
# ).load_folder()
#
# chat = await llm.sql_eval(db.query_prompt(q))
# query = extract_code_block(chat.last_response_content())
# result: pl.DataFrame = db.query(query).pl()
#
# assert result.shape[0] == 1
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
#