hellocomputer/test/test_query.py
Guillem Borrell f4b9c30a17
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Kind of refactored everything
2024-07-25 00:10:09 +02:00

92 lines
2.7 KiB
Python

from pathlib import Path
import hellocomputer
import polars as pl
import pytest
from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import AvailableModels
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
# The repository's ``test`` directory, resolved relative to the installed
# ``hellocomputer`` package as ``Path(hellocomputer.__file__).parents[2]``.
# NOTE(review): presumably parents[2] is the repository root — this depends
# on the package layout; confirm if the package is ever moved.
_TEST_DIR = Path(hellocomputer.__file__).parents[2] / "test"

# Local-storage settings; session data is written under ``test/output``.
settings = Settings(
    storage_engine=StorageEngines.local,
    path=_TEST_DIR / "output",
)

# Spreadsheet fixture loaded into the session database by the query tests.
TEST_XLS_PATH = _TEST_DIR / "data" / "TestExcelHelloComputer.xlsx"

# Session id shared by every test in this module.
SID = "test"
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
    """Smoke-test the LLM endpoint with a one-word echo prompt.

    Pipes a ``ChatPromptTemplate`` into ``ChatOpenAI`` and checks that the
    model's reply begins with the requested word. Skipped while
    ``settings.llm_api_key`` is still the placeholder ``"Awesome API"``.
    """
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.mixtral_8x7b,
        temperature=0.5,
    )
    # Same template text as before (a newline between the two sentences),
    # written as implicit concatenation so it survives re-indentation.
    prompt = ChatPromptTemplate.from_template(
        "Say literally {word}, a single word. Don't be verbose,\n"
        "I'll be disappointed if you say more than a single word"
    )
    chain = prompt | llm
    response = await chain.ainvoke({"word": "Hello"})

    # Models frequently pad the reply with whitespace/newlines or wrap the
    # word in quotes/backticks; strip that formatting so the test checks the
    # content rather than failing intermittently on presentation.
    reply = response.content.strip().strip("\"'`").lower()
    assert reply.startswith("hello"), f"unexpected reply: {response.content!r}"
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_query_context():
    """Check that the SQL toolkit exposes table metadata for the workbook.

    Loads the spreadsheet fixture into a session database, wraps its
    ``llmsql`` handle in a ``SQLDatabaseToolkit`` and asserts that the
    toolkit context contains both ``table_info`` and ``table_names``.
    """
    sql_db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql

    # NOTE(review): no LLM call is made in this test; the model object is only
    # needed to build the toolkit, so the API-key skip may be stricter than
    # necessary — confirm whether ChatOpenAI requires a real key here.
    chat_model = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.mixtral_8x7b,
        temperature=0.5,
    )
    toolkit_context = SQLDatabaseToolkit(db=sql_db, llm=chat_model).get_context()

    # Checked in the same order as the individual asserts they replace.
    for expected_key in ("table_info", "table_names"):
        assert expected_key in toolkit_context
#
# chat = await chat.sql_eval(db.query_prompt(query))
# query = extract_code_block(chat.last_response_content())
# assert query.startswith("SELECT")
#
#
# @pytest.mark.asyncio
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
# async def test_data_query():
# q = "Find the average score of all the sudents"
#
# llm = Chat(
# api_key=settings.llm_api_key,
# temperature=0.5,
# )
# db = SessionDB(
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
# ).load_folder()
#
# chat = await llm.sql_eval(db.query_prompt(q))
# query = extract_code_block(chat.last_response_content())
# result: pl.DataFrame = db.query(query).pl()
#
# assert result.shape[0] == 1
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
#