diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index 9551c6a..0c2923f 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -1,12 +1,4 @@ from enum import StrEnum -from pathlib import Path - -from anyio import open_file -from langchain_core.prompts import PromptTemplate - -import hellocomputer - -PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts" class AvailableModels(StrEnum): @@ -17,22 +9,3 @@ class AvailableModels(StrEnum): mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct" mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct" firefunction_2 = "accounts/fireworks/models/firefunction-v2" - - -class Prompts: - @classmethod - async def getter(cls, name): - async with await open_file(PROMPT_DIR / f"{name}.md") as f: - return await f.read() - - @classmethod - async def intent(cls): - return PromptTemplate.from_template(await cls.getter("intent")) - - @classmethod - async def general(cls): - return PromptTemplate.from_template(await cls.getter("general_prompt")) - - @classmethod - async def sql(cls): - return PromptTemplate.from_template(await cls.getter("sql_prompt")) diff --git a/src/hellocomputer/prompts.py b/src/hellocomputer/prompts.py new file mode 100644 index 0000000..a0fd80c --- /dev/null +++ b/src/hellocomputer/prompts.py @@ -0,0 +1,26 @@ +from anyio import open_file +from langchain_core.prompts import PromptTemplate +from pathlib import Path + +import hellocomputer + +PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts" + + +class Prompts: + @classmethod + async def getter(cls, name): + async with await open_file(PROMPT_DIR / f"{name}.md") as f: + return await f.read() + + @classmethod + async def intent(cls): + return PromptTemplate.from_template(await cls.getter("intent")) + + @classmethod + async def general(cls): + return PromptTemplate.from_template(await cls.getter("general_prompt")) + + @classmethod + async def sql(cls): + return PromptTemplate.from_template(await cls.getter("sql_prompt")) diff --git a/test/test_prompts.py b/test/test_prompts.py index 62d275d..fe74ecb 100644 --- a/test/test_prompts.py +++ b/test/test_prompts.py @@ -1,5 +1,5 @@ import pytest -from hellocomputer.models import Prompts +from hellocomputer.prompts import Prompts from langchain.prompts import PromptTemplate diff --git a/test/test_query.py b/test/test_query.py index ae600a4..0e59151 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -5,7 +5,8 @@ import polars as pl import pytest from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.sessions import SessionDB -from hellocomputer.models import AvailableModels, Prompts +from hellocomputer.models import AvailableModels +from hellocomputer.prompts import Prompts from hellocomputer.extraction import initial_intent_parser from langchain_core.prompts import ChatPromptTemplate from langchain_community.agent_toolkits import SQLDatabaseToolkit @@ -69,8 +70,8 @@ async def test_initial_intent(): llm = ChatOpenAI( base_url=settings.llm_base_url, api_key=settings.llm_api_key, - model=AvailableModels.llama_medium, - temperature=0.5, + model=AvailableModels.llama_small, + temperature=0, ) prompt = await Prompts.intent() chain = prompt | llm | initial_intent_parser