This commit is contained in:
parent
c97f5a25ce
commit
9d08a189b1
|
@ -1,12 +1,4 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from anyio import open_file
|
|
||||||
from langchain_core.prompts import PromptTemplate
|
|
||||||
|
|
||||||
import hellocomputer
|
|
||||||
|
|
||||||
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
|
||||||
|
|
||||||
|
|
||||||
class AvailableModels(StrEnum):
|
class AvailableModels(StrEnum):
|
||||||
|
@ -17,22 +9,3 @@ class AvailableModels(StrEnum):
|
||||||
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||||
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
||||||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||||
|
|
||||||
|
|
||||||
class Prompts:
|
|
||||||
@classmethod
|
|
||||||
async def getter(cls, name):
|
|
||||||
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
|
||||||
return await f.read()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def intent(cls):
|
|
||||||
return PromptTemplate.from_template(await cls.getter("intent"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def general(cls):
|
|
||||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def sql(cls):
|
|
||||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
|
||||||
|
|
26
src/hellocomputer/prompts.py
Normal file
26
src/hellocomputer/prompts.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from anyio import open_file
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import hellocomputer
|
||||||
|
|
||||||
|
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
||||||
|
|
||||||
|
|
||||||
|
class Prompts:
|
||||||
|
@classmethod
|
||||||
|
async def getter(cls, name):
|
||||||
|
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def intent(cls):
|
||||||
|
return PromptTemplate.from_template(await cls.getter("intent"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def general(cls):
|
||||||
|
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def sql(cls):
|
||||||
|
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.models import Prompts
|
from hellocomputer.prompts import Prompts
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,8 @@ import polars as pl
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.config import Settings, StorageEngines
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
from hellocomputer.models import AvailableModels, Prompts
|
from hellocomputer.models import AvailableModels
|
||||||
|
from hellocomputer.prompts import Prompts
|
||||||
from hellocomputer.extraction import initial_intent_parser
|
from hellocomputer.extraction import initial_intent_parser
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||||
|
@ -69,8 +70,8 @@ async def test_initial_intent():
|
||||||
llm = ChatOpenAI(
|
llm = ChatOpenAI(
|
||||||
base_url=settings.llm_base_url,
|
base_url=settings.llm_base_url,
|
||||||
api_key=settings.llm_api_key,
|
api_key=settings.llm_api_key,
|
||||||
model=AvailableModels.llama_medium,
|
model=AvailableModels.llama_small,
|
||||||
temperature=0.5,
|
temperature=0,
|
||||||
)
|
)
|
||||||
prompt = await Prompts.intent()
|
prompt = await Prompts.intent()
|
||||||
chain = prompt | llm | initial_intent_parser
|
chain = prompt | llm | initial_intent_parser
|
||||||
|
|
Loading…
Reference in a new issue