2024-05-28 22:23:11 +02:00
|
|
|
from pathlib import Path
|
|
|
|
|
2024-05-25 09:16:45 +02:00
|
|
|
import hellocomputer
|
2024-07-13 11:53:37 +02:00
|
|
|
import polars as pl
|
2024-07-13 22:46:34 +02:00
|
|
|
import pytest
|
2024-07-25 00:10:09 +02:00
|
|
|
from hellocomputer.config import Settings, StorageEngines
|
2024-07-13 22:46:34 +02:00
|
|
|
from hellocomputer.db.sessions import SessionDB
|
2024-05-25 09:16:45 +02:00
|
|
|
from hellocomputer.extraction import extract_code_block
|
2024-07-25 00:10:09 +02:00
|
|
|
from hellocomputer.models import AvailableModels
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
|
|
|
|
# Directory two levels above the installed package directory (presumably the
# repository root, which holds the ``test`` folder) — TODO confirm layout.
_REPO_ROOT = Path(hellocomputer.__file__).parents[2]
_TEST_DIR = _REPO_ROOT / "test"

# Settings shared by every test in this module: local storage rooted at
# ``test/output``.
settings = Settings(
    path=_TEST_DIR / "output",
    storage_engine=StorageEngines.local,
)

# Excel workbook fixture that the tests load into a session database.
TEST_XLS_PATH = _TEST_DIR / "data" / "TestExcelHelloComputer.xlsx"

# Session id used when creating the test SessionDB.
SID = "test"
|
2024-05-23 23:31:00 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
    """Check that the configured chat model echoes back a single requested word.

    Builds a one-step ``prompt | llm`` chain, asks the model to say "Hello",
    and asserts the reply starts with that word (case-insensitive).
    Skipped while the API key is still the placeholder "Awesome API".
    """
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.mixtral_8x7b,
        temperature=0.5,
    )
    prompt = ChatPromptTemplate.from_template(
        """Say literally {word}, a single word. Don't be verbose,
        I'll be disappointed if you say more than a single word"""
    )
    chain = prompt | llm
    response = await chain.ainvoke({"word": "Hello"})

    # Model replies frequently carry leading whitespace/newlines or wrap the
    # word in quotes/backticks even when they obey the prompt; normalise those
    # away so the check targets the word itself rather than its formatting.
    answer = response.content.strip().strip("\"'`").lower()
    assert answer.startswith("hello"), f"unexpected reply: {response.content!r}"
|
2024-06-16 08:56:45 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
def test_query_context():
    """Check that the SQL toolkit exposes table context for the loaded workbook.

    Loads the Excel fixture into the "test" session, wraps the session's
    ``llmsql`` database (passed to ``SQLDatabaseToolkit`` as ``db``) in a
    toolkit, and asserts the toolkit context contains both the table info and
    the table names. Nothing here is awaited, so this is a plain synchronous
    test rather than an asyncio one. Skipped while the API key is still the
    placeholder "Awesome API".
    """
    db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql

    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.mixtral_8x7b,
        temperature=0.5,
    )

    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    context = toolkit.get_context()

    assert "table_info" in context
    assert "table_names" in context
|
|
|
|
|
2024-06-16 08:56:45 +02:00
|
|
|
|
2024-07-25 00:10:09 +02:00
|
|
|
#
|
|
|
|
# chat = await chat.sql_eval(db.query_prompt(query))
|
|
|
|
# query = extract_code_block(chat.last_response_content())
|
|
|
|
# assert query.startswith("SELECT")
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# @pytest.mark.asyncio
|
|
|
|
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
|
|
|
# async def test_data_query():
|
|
|
|
# q = "Find the average score of all the sudents"
|
|
|
|
#
|
|
|
|
# llm = Chat(
|
|
|
|
# api_key=settings.llm_api_key,
|
|
|
|
# temperature=0.5,
|
|
|
|
# )
|
|
|
|
# db = SessionDB(
|
|
|
|
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
|
|
|
# ).load_folder()
|
|
|
|
#
|
|
|
|
# chat = await llm.sql_eval(db.query_prompt(q))
|
|
|
|
# query = extract_code_block(chat.last_response_content())
|
|
|
|
# result: pl.DataFrame = db.query(query).pl()
|
|
|
|
#
|
|
|
|
# assert result.shape[0] == 1
|
|
|
|
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
|
|
|
#
|