Slowly building the application with runnables.
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-25 23:27:03 +02:00
parent f4b9c30a17
commit c97f5a25ce
5 changed files with 73 additions and 8 deletions

View file

@ -1,3 +1,5 @@
from langchain.output_parsers.enum import EnumOutputParser
from enum import StrEnum
import re import re
@ -8,3 +10,12 @@ def extract_code_block(response):
if len(matches) > 1: if len(matches) > 1:
raise ValueError("More than one code block") raise ValueError("More than one code block")
return matches[0].removeprefix("sql").removeprefix("\n") return matches[0].removeprefix("sql").removeprefix("\n")
class InitialIntent(StrEnum):
general = "general"
query = "query"
visualization = "visualization"
initial_intent_parser = EnumOutputParser(enum=InitialIntent)

View file

@ -10,8 +10,10 @@ PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
class AvailableModels(StrEnum): class AvailableModels(StrEnum):
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" llama_small = "accounts/fireworks/models/llama-v3p1-8b-instruct"
# Function calling model llama_medium = "accounts/fireworks/models/llama-v3p1-70b-instruct"
llama_large = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Function calling models
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct" mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct" mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
firefunction_2 = "accounts/fireworks/models/firefunction-v2" firefunction_2 = "accounts/fireworks/models/firefunction-v2"
@ -23,6 +25,10 @@ class Prompts:
async with await open_file(PROMPT_DIR / f"{name}.md") as f: async with await open_file(PROMPT_DIR / f"{name}.md") as f:
return await f.read() return await f.read()
@classmethod
async def intent(cls):
return PromptTemplate.from_template(await cls.getter("intent"))
@classmethod @classmethod
async def general(cls): async def general(cls):
return PromptTemplate.from_template(await cls.getter("general_prompt")) return PromptTemplate.from_template(await cls.getter("general_prompt"))

View file

@ -0,0 +1,13 @@
The followig is a question from a user of a website, not necessarily in English:
***************
{query}
***************
The purpose of the website is to analyze the data contained on a database and return the correct answer to the question, but the user may have not understood the purpose of the website. Maybe it's asking about the weather, or it's trying some prompt injection trick. Classify the question in one of the following categories
1. A question that can be answered processing the data contained in the database. If this is the case answer the single word query
2. Some data visualization that can be obtained by generated from the data contained in the database. if this is the case answer with the single word visualization.
3. A general request that can't be considered any of the previous two. If that's the case answer with the single word general.
Note that your response will be validated, and only the options query, visualization, and general will be accepted.

View file

@ -11,16 +11,30 @@ class DuckdbQueryInput(BaseModel):
class DuckdbQueryTool(BaseTool): class DuckdbQueryTool(BaseTool):
name: str = "Calculator" name: str = "sql_query"
description: str = "Tool to evaluate mathemetical expressions" description: str = "Run a SQL query in the database containing all the datasets "
"and provide a summary of the results"
args_schema: Type[BaseModel] = DuckdbQueryInput args_schema: Type[BaseModel] = DuckdbQueryInput
def _run(self, query: str, session_id: str) -> str: def _run(self, query: str, session_id: str) -> str:
"""Run the query""" """Run the query"""
db = SessionDB(settings, session_id) db = SessionDB(settings, session_id)
return "Table"
async def _arun(self, query: str, session_id: str) -> str: async def _arun(self, query: str, session_id: str) -> str:
"""Use the tool asynchronously.""" """Use the tool asynchronously."""
db = SessionDB(settings, session_id) db = SessionDB(settings, session_id)
return "Table" return "Table"
class PlotHistogramInput(BaseModel):
column_name: str = Field(description="Name of the column containing the values")
table_name: str = Field(description="Name of the table that contains the data")
num_bins: int = Field(description="Number of bins of the histogram")
class PlotHistogramTool(BaseTool):
name: str = "plot_histogram"
description: str = """
Generate a histogram plot given a name of an existing table of the database,
and a name of a column in the table. The default number of bins is 10, but
you can forward the number of bins if you are requested to"""

View file

@ -5,8 +5,8 @@ import polars as pl
import pytest import pytest
from hellocomputer.config import Settings, StorageEngines from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block from hellocomputer.models import AvailableModels, Prompts
from hellocomputer.models import AvailableModels from hellocomputer.extraction import initial_intent_parser
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
@ -42,7 +42,7 @@ async def test_chat_simple():
chain = prompt | llm chain = prompt | llm
response = await chain.ainvoke({"word": "Hello"}) response = await chain.ainvoke({"word": "Hello"})
assert response.content.lower().startswith("hello") assert "hello" in response.content.lower()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -63,6 +63,27 @@ async def test_query_context():
assert "table_names" in context assert "table_names" in context
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_initial_intent():
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.llama_medium,
temperature=0.5,
)
prompt = await Prompts.intent()
chain = prompt | llm | initial_intent_parser
response = await chain.ainvoke({"query", "Make me a sandwich"})
assert response == "general"
response = await chain.ainvoke(
{"query", "Which is the average score of all the students"}
)
assert response == "query"
# #
# chat = await chat.sql_eval(db.query_prompt(query)) # chat = await chat.sql_eval(db.query_prompt(query))
# query = extract_code_block(chat.last_response_content()) # query = extract_code_block(chat.last_response_content())