Split models and prompts
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

Guillem Borrell 2024-07-25 23:42:19 +02:00
parent c97f5a25ce
commit 9d08a189b1
4 changed files with 31 additions and 31 deletions


@@ -1,12 +1,4 @@
 from enum import StrEnum
-from pathlib import Path
-
-from anyio import open_file
-from langchain_core.prompts import PromptTemplate
-
-import hellocomputer
-
-PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
 
 
 class AvailableModels(StrEnum):
@@ -17,22 +9,3 @@ class AvailableModels(StrEnum):
     mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
     mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
     firefunction_2 = "accounts/fireworks/models/firefunction-v2"
-
-
-class Prompts:
-    @classmethod
-    async def getter(cls, name):
-        async with await open_file(PROMPT_DIR / f"{name}.md") as f:
-            return await f.read()
-
-    @classmethod
-    async def intent(cls):
-        return PromptTemplate.from_template(await cls.getter("intent"))
-
-    @classmethod
-    async def general(cls):
-        return PromptTemplate.from_template(await cls.getter("general_prompt"))
-
-    @classmethod
-    async def sql(cls):
-        return PromptTemplate.from_template(await cls.getter("sql_prompt"))

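Note: because AvailableModels subclasses StrEnum, its members behave as plain strings and can be passed directly wherever a model identifier is expected. A minimal illustrative sketch (not part of the commit):

from hellocomputer.models import AvailableModels

# StrEnum members compare equal to, and format as, their string values.
assert AvailableModels.mixtral_8x7b == "accounts/fireworks/models/mixtral-8x7b-instruct"
print(f"Using model: {AvailableModels.firefunction_2}")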

@@ -0,0 +1,26 @@
+from anyio import open_file
+from langchain_core.prompts import PromptTemplate
+from pathlib import Path
+
+import hellocomputer
+
+PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
+
+
+class Prompts:
+    @classmethod
+    async def getter(cls, name):
+        async with await open_file(PROMPT_DIR / f"{name}.md") as f:
+            return await f.read()
+
+    @classmethod
+    async def intent(cls):
+        return PromptTemplate.from_template(await cls.getter("intent"))
+
+    @classmethod
+    async def general(cls):
+        return PromptTemplate.from_template(await cls.getter("general_prompt"))
+
+    @classmethod
+    async def sql(cls):
+        return PromptTemplate.from_template(await cls.getter("sql_prompt"))

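A minimal usage sketch of the new hellocomputer.prompts module (illustrative, not part of the commit; it assumes the packaged prompts/ directory contains the referenced markdown files):

import asyncio

from langchain_core.prompts import PromptTemplate

from hellocomputer.prompts import Prompts


async def main():
    # Each classmethod reads <name>.md from PROMPT_DIR asynchronously
    # and wraps its contents in a LangChain PromptTemplate.
    intent_prompt = await Prompts.intent()
    assert isinstance(intent_prompt, PromptTemplate)


asyncio.run(main())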

@@ -1,5 +1,5 @@
 import pytest
-from hellocomputer.models import Prompts
+from hellocomputer.prompts import Prompts
 from langchain.prompts import PromptTemplate


@@ -5,7 +5,8 @@ import polars as pl
 import pytest
 from hellocomputer.config import Settings, StorageEngines
 from hellocomputer.db.sessions import SessionDB
-from hellocomputer.models import AvailableModels, Prompts
+from hellocomputer.models import AvailableModels
+from hellocomputer.prompts import Prompts
 from hellocomputer.extraction import initial_intent_parser
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_community.agent_toolkits import SQLDatabaseToolkit
@@ -69,8 +70,8 @@ async def test_initial_intent():
     llm = ChatOpenAI(
         base_url=settings.llm_base_url,
         api_key=settings.llm_api_key,
-        model=AvailableModels.llama_medium,
-        temperature=0.5,
+        model=AvailableModels.llama_small,
+        temperature=0,
     )
     prompt = await Prompts.intent()
     chain = prompt | llm | initial_intent_parser
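
Taken together, the split means model selection and prompt loading now come from separate modules. A hedged sketch of how they are consumed together, mirroring the updated test (build_intent_chain and the langchain_openai import path are illustrative assumptions, not shown in the diff):

from langchain_openai import ChatOpenAI  # assumed import path for ChatOpenAI

from hellocomputer.extraction import initial_intent_parser
from hellocomputer.models import AvailableModels
from hellocomputer.prompts import Prompts


async def build_intent_chain(settings):
    # settings is assumed to expose llm_base_url and llm_api_key,
    # as the Settings object in the test does.
    llm = ChatOpenAI(
        base_url=settings.llm_base_url,
        api_key=settings.llm_api_key,
        model=AvailableModels.llama_small,  # was llama_medium before this commit
        temperature=0,  # was 0.5 before this commit
    )
    prompt = await Prompts.intent()
    return prompt | llm | initial_intent_parser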