from pathlib import Path import hellocomputer import polars as pl import pytest from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.sessions import SessionDB from hellocomputer.extraction import extract_code_block from hellocomputer.models import AvailableModels from langchain_core.prompts import ChatPromptTemplate from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_openai import ChatOpenAI settings = Settings( storage_engine=StorageEngines.local, path=Path(hellocomputer.__file__).parents[2] / "test" / "output", ) TEST_XLS_PATH = ( Path(hellocomputer.__file__).parents[2] / "test" / "data" / "TestExcelHelloComputer.xlsx" ) SID = "test" @pytest.mark.asyncio @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") async def test_chat_simple(): llm = ChatOpenAI( base_url=settings.llm_base_url, api_key=settings.llm_api_key, model=AvailableModels.mixtral_8x7b, temperature=0.5, ) prompt = ChatPromptTemplate.from_template( """Say literally {word}, a single word. Don't be verbose, I'll be disappointed if you say more than a single word""" ) chain = prompt | llm response = await chain.ainvoke({"word": "Hello"}) assert response.content.lower().startswith("hello") @pytest.mark.asyncio @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") async def test_query_context(): db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql llm = ChatOpenAI( base_url=settings.llm_base_url, api_key=settings.llm_api_key, model=AvailableModels.mixtral_8x7b, temperature=0.5, ) toolkit = SQLDatabaseToolkit(db=db, llm=llm) context = toolkit.get_context() assert "table_info" in context assert "table_names" in context # # chat = await chat.sql_eval(db.query_prompt(query)) # query = extract_code_block(chat.last_response_content()) # assert query.startswith("SELECT") # # # @pytest.mark.asyncio # @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") # async def test_data_query(): # q = "Find the average score of all the sudents" # # llm = Chat( # api_key=settings.llm_api_key, # temperature=0.5, # ) # db = SessionDB( # storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" # ).load_folder() # # chat = await llm.sql_eval(db.query_prompt(q)) # query = extract_code_block(chat.last_response_content()) # result: pl.DataFrame = db.query(query).pl() # # assert result.shape[0] == 1 # assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5 #