diff --git a/requirements-test.in b/requirements-test.in new file mode 100644 index 0000000..0bcf21c --- /dev/null +++ b/requirements-test.in @@ -0,0 +1,2 @@ +pytest +pytest-asyncio \ No newline at end of file diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index fb32d59..b98e41d 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -25,12 +25,12 @@ class Chat: api_key: str = "", temperature: float = 0.5, ): - self.model = model + self.model_name = model self.api_key = self.raise_no_key(api_key) self.messages = [] self.responses = [] - model: ChatAnyscale = ChatAnyscale( + self.model: ChatAnyscale = ChatAnyscale( model_name=model, temperature=temperature, anyscale_api_key=self.api_key ) @@ -42,7 +42,8 @@ class Chat: ] ) - self.responses.append(await self.model.ainvoke(self.messages[-1])) + response = await self.model.ainvoke(self.messages[-1]) + self.responses.append(response) return self def last_response_content(self): diff --git a/test/test_data.py b/test/test_data.py index 6f228d8..b3a351e 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -20,7 +20,7 @@ def test_dump(): def test_load(): db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) - results = db.query("select * from answers").fetchall() - assert db.sheets == ("answers",) + + results = db.query("select * from answers").fetchall() assert len(results) == 2 diff --git a/test/test_query.py b/test/test_query.py new file mode 100644 index 0000000..0107db0 --- /dev/null +++ b/test/test_query.py @@ -0,0 +1,13 @@ +import pytest +from hellocomputer.config import settings +from hellocomputer.models import Chat + + +@pytest.mark.asyncio +@pytest.mark.skipif( + settings.anyscale_api_key == "Awesome API", reason="API Key not set" +) +async def test_chat_simple(): + chat = Chat(api_key=settings.anyscale_api_key, temperature=0) + chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'") + assert chat.last_response_content() == "Hello!"