This commit is contained in:
parent
c97f5a25ce
commit
9d08a189b1
|
@ -1,12 +1,4 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from anyio import open_file
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
import hellocomputer
|
||||
|
||||
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
||||
|
||||
|
||||
class AvailableModels(StrEnum):
|
||||
|
@ -17,22 +9,3 @@ class AvailableModels(StrEnum):
|
|||
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
||||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||
|
||||
|
||||
class Prompts:
|
||||
@classmethod
|
||||
async def getter(cls, name):
|
||||
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||
return await f.read()
|
||||
|
||||
@classmethod
|
||||
async def intent(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("intent"))
|
||||
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||
|
||||
@classmethod
|
||||
async def sql(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
||||
|
|
26
src/hellocomputer/prompts.py
Normal file
26
src/hellocomputer/prompts.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from anyio import open_file
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
|
||||
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
||||
|
||||
|
||||
class Prompts:
|
||||
@classmethod
|
||||
async def getter(cls, name):
|
||||
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||
return await f.read()
|
||||
|
||||
@classmethod
|
||||
async def intent(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("intent"))
|
||||
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||
|
||||
@classmethod
|
||||
async def sql(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from hellocomputer.models import Prompts
|
||||
from hellocomputer.prompts import Prompts
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,8 @@ import polars as pl
|
|||
import pytest
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.models import AvailableModels, Prompts
|
||||
from hellocomputer.models import AvailableModels
|
||||
from hellocomputer.prompts import Prompts
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
|
@ -69,8 +70,8 @@ async def test_initial_intent():
|
|||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_medium,
|
||||
temperature=0.5,
|
||||
model=AvailableModels.llama_small,
|
||||
temperature=0,
|
||||
)
|
||||
prompt = await Prompts.intent()
|
||||
chain = prompt | llm | initial_intent_parser
|
||||
|
|
Loading…
Reference in a new issue