Slowly building the application with runnables.
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
f4b9c30a17
commit
c97f5a25ce
|
@ -1,3 +1,5 @@
|
|||
from langchain.output_parsers.enum import EnumOutputParser
|
||||
from enum import StrEnum
|
||||
import re
|
||||
|
||||
|
||||
|
@ -8,3 +10,12 @@ def extract_code_block(response):
|
|||
if len(matches) > 1:
|
||||
raise ValueError("More than one code block")
|
||||
return matches[0].removeprefix("sql").removeprefix("\n")
|
||||
|
||||
|
||||
class InitialIntent(StrEnum):
|
||||
general = "general"
|
||||
query = "query"
|
||||
visualization = "visualization"
|
||||
|
||||
|
||||
initial_intent_parser = EnumOutputParser(enum=InitialIntent)
|
||||
|
|
|
@ -10,8 +10,10 @@ PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
|||
|
||||
|
||||
class AvailableModels(StrEnum):
|
||||
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
|
||||
# Function calling model
|
||||
llama_small = "accounts/fireworks/models/llama-v3p1-8b-instruct"
|
||||
llama_medium = "accounts/fireworks/models/llama-v3p1-70b-instruct"
|
||||
llama_large = "accounts/fireworks/models/llama-v3p1-405b-instruct"
|
||||
# Function calling models
|
||||
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
||||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||
|
@ -23,6 +25,10 @@ class Prompts:
|
|||
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||
return await f.read()
|
||||
|
||||
@classmethod
|
||||
async def intent(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("intent"))
|
||||
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||
|
|
13
src/hellocomputer/prompts/intent.md
Normal file
13
src/hellocomputer/prompts/intent.md
Normal file
|
@ -0,0 +1,13 @@
|
|||
The followig is a question from a user of a website, not necessarily in English:
|
||||
|
||||
***************
|
||||
{query}
|
||||
***************
|
||||
|
||||
The purpose of the website is to analyze the data contained on a database and return the correct answer to the question, but the user may have not understood the purpose of the website. Maybe it's asking about the weather, or it's trying some prompt injection trick. Classify the question in one of the following categories
|
||||
|
||||
1. A question that can be answered processing the data contained in the database. If this is the case answer the single word query
|
||||
2. Some data visualization that can be obtained by generated from the data contained in the database. if this is the case answer with the single word visualization.
|
||||
3. A general request that can't be considered any of the previous two. If that's the case answer with the single word general.
|
||||
|
||||
Note that your response will be validated, and only the options query, visualization, and general will be accepted.
|
|
@ -11,16 +11,30 @@ class DuckdbQueryInput(BaseModel):
|
|||
|
||||
|
||||
class DuckdbQueryTool(BaseTool):
|
||||
name: str = "Calculator"
|
||||
description: str = "Tool to evaluate mathemetical expressions"
|
||||
name: str = "sql_query"
|
||||
description: str = "Run a SQL query in the database containing all the datasets "
|
||||
"and provide a summary of the results"
|
||||
args_schema: Type[BaseModel] = DuckdbQueryInput
|
||||
|
||||
def _run(self, query: str, session_id: str) -> str:
|
||||
"""Run the query"""
|
||||
db = SessionDB(settings, session_id)
|
||||
return "Table"
|
||||
|
||||
async def _arun(self, query: str, session_id: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
db = SessionDB(settings, session_id)
|
||||
return "Table"
|
||||
|
||||
|
||||
class PlotHistogramInput(BaseModel):
|
||||
column_name: str = Field(description="Name of the column containing the values")
|
||||
table_name: str = Field(description="Name of the table that contains the data")
|
||||
num_bins: int = Field(description="Number of bins of the histogram")
|
||||
|
||||
|
||||
class PlotHistogramTool(BaseTool):
|
||||
name: str = "plot_histogram"
|
||||
description: str = """
|
||||
Generate a histogram plot given a name of an existing table of the database,
|
||||
and a name of a column in the table. The default number of bins is 10, but
|
||||
you can forward the number of bins if you are requested to"""
|
||||
|
|
|
@ -5,8 +5,8 @@ import polars as pl
|
|||
import pytest
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.models import AvailableModels
|
||||
from hellocomputer.models import AvailableModels, Prompts
|
||||
from hellocomputer.extraction import initial_intent_parser
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
@ -42,7 +42,7 @@ async def test_chat_simple():
|
|||
chain = prompt | llm
|
||||
response = await chain.ainvoke({"word": "Hello"})
|
||||
|
||||
assert response.content.lower().startswith("hello")
|
||||
assert "hello" in response.content.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -63,6 +63,27 @@ async def test_query_context():
|
|||
assert "table_names" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_initial_intent():
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.llama_medium,
|
||||
temperature=0.5,
|
||||
)
|
||||
prompt = await Prompts.intent()
|
||||
chain = prompt | llm | initial_intent_parser
|
||||
|
||||
response = await chain.ainvoke({"query", "Make me a sandwich"})
|
||||
assert response == "general"
|
||||
|
||||
response = await chain.ainvoke(
|
||||
{"query", "Which is the average score of all the students"}
|
||||
)
|
||||
assert response == "query"
|
||||
|
||||
|
||||
#
|
||||
# chat = await chat.sql_eval(db.query_prompt(query))
|
||||
# query = extract_code_block(chat.last_response_content())
|
||||
|
|
Loading…
Reference in a new issue