From c97f5a25ce498a72b1dd013f47ad8339587f962a Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Thu, 25 Jul 2024 23:27:03 +0200 Subject: [PATCH] Slowly building the application with runnables. --- src/hellocomputer/extraction.py | 11 +++++++++++ src/hellocomputer/models.py | 10 ++++++++-- src/hellocomputer/prompts/intent.md | 13 +++++++++++++ src/hellocomputer/tools.py | 20 +++++++++++++++++--- test/test_query.py | 27 ++++++++++++++++++++++++--- 5 files changed, 73 insertions(+), 8 deletions(-) create mode 100644 src/hellocomputer/prompts/intent.md diff --git a/src/hellocomputer/extraction.py b/src/hellocomputer/extraction.py index d860272..8152a0f 100644 --- a/src/hellocomputer/extraction.py +++ b/src/hellocomputer/extraction.py @@ -1,3 +1,5 @@ +from langchain.output_parsers.enum import EnumOutputParser +from enum import StrEnum import re @@ -8,3 +10,12 @@ def extract_code_block(response): if len(matches) > 1: raise ValueError("More than one code block") return matches[0].removeprefix("sql").removeprefix("\n") + + +class InitialIntent(StrEnum): + general = "general" + query = "query" + visualization = "visualization" + + +initial_intent_parser = EnumOutputParser(enum=InitialIntent) diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index 09c3982..9551c6a 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -10,8 +10,10 @@ PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts" class AvailableModels(StrEnum): - llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" - # Function calling model + llama_small = "accounts/fireworks/models/llama-v3p1-8b-instruct" + llama_medium = "accounts/fireworks/models/llama-v3p1-70b-instruct" + llama_large = "accounts/fireworks/models/llama-v3p1-405b-instruct" + # Function calling models mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct" mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct" firefunction_2 = "accounts/fireworks/models/firefunction-v2" @@ -23,6 +25,10 @@ class Prompts: async with await open_file(PROMPT_DIR / f"{name}.md") as f: return await f.read() + @classmethod + async def intent(cls): + return PromptTemplate.from_template(await cls.getter("intent")) + @classmethod async def general(cls): return PromptTemplate.from_template(await cls.getter("general_prompt")) diff --git a/src/hellocomputer/prompts/intent.md b/src/hellocomputer/prompts/intent.md new file mode 100644 index 0000000..6f6e94c --- /dev/null +++ b/src/hellocomputer/prompts/intent.md @@ -0,0 +1,13 @@ +The followig is a question from a user of a website, not necessarily in English: + +*************** +{query} +*************** + +The purpose of the website is to analyze the data contained on a database and return the correct answer to the question, but the user may have not understood the purpose of the website. Maybe it's asking about the weather, or it's trying some prompt injection trick. Classify the question in one of the following categories + +1. A question that can be answered processing the data contained in the database. If this is the case answer the single word query +2. Some data visualization that can be obtained by generated from the data contained in the database. if this is the case answer with the single word visualization. +3. A general request that can't be considered any of the previous two. If that's the case answer with the single word general. + +Note that your response will be validated, and only the options query, visualization, and general will be accepted. diff --git a/src/hellocomputer/tools.py b/src/hellocomputer/tools.py index fad16e2..cf8a4a2 100644 --- a/src/hellocomputer/tools.py +++ b/src/hellocomputer/tools.py @@ -11,16 +11,30 @@ class DuckdbQueryInput(BaseModel): class DuckdbQueryTool(BaseTool): - name: str = "Calculator" - description: str = "Tool to evaluate mathemetical expressions" + name: str = "sql_query" + description: str = "Run a SQL query in the database containing all the datasets " + "and provide a summary of the results" args_schema: Type[BaseModel] = DuckdbQueryInput def _run(self, query: str, session_id: str) -> str: """Run the query""" db = SessionDB(settings, session_id) - return "Table" async def _arun(self, query: str, session_id: str) -> str: """Use the tool asynchronously.""" db = SessionDB(settings, session_id) return "Table" + + +class PlotHistogramInput(BaseModel): + column_name: str = Field(description="Name of the column containing the values") + table_name: str = Field(description="Name of the table that contains the data") + num_bins: int = Field(description="Number of bins of the histogram") + + +class PlotHistogramTool(BaseTool): + name: str = "plot_histogram" + description: str = """ + Generate a histogram plot given a name of an existing table of the database, + and a name of a column in the table. The default number of bins is 10, but + you can forward the number of bins if you are requested to""" diff --git a/test/test_query.py b/test/test_query.py index c4c188f..ae600a4 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -5,8 +5,8 @@ import polars as pl import pytest from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.sessions import SessionDB -from hellocomputer.extraction import extract_code_block -from hellocomputer.models import AvailableModels +from hellocomputer.models import AvailableModels, Prompts +from hellocomputer.extraction import initial_intent_parser from langchain_core.prompts import ChatPromptTemplate from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_openai import ChatOpenAI @@ -42,7 +42,7 @@ async def test_chat_simple(): chain = prompt | llm response = await chain.ainvoke({"word": "Hello"}) - assert response.content.lower().startswith("hello") + assert "hello" in response.content.lower() @pytest.mark.asyncio @@ -63,6 +63,27 @@ async def test_query_context(): assert "table_names" in context +@pytest.mark.asyncio +@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") +async def test_initial_intent(): + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.llama_medium, + temperature=0.5, + ) + prompt = await Prompts.intent() + chain = prompt | llm | initial_intent_parser + + response = await chain.ainvoke({"query", "Make me a sandwich"}) + assert response == "general" + + response = await chain.ainvoke( + {"query", "Which is the average score of all the students"} + ) + assert response == "query" + + # # chat = await chat.sql_eval(db.query_prompt(query)) # query = extract_code_block(chat.last_response_content())