Ported to fireworks
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-13 10:52:45 +02:00
parent 181bc92884
commit 98a713b3c7
10 changed files with 347 additions and 88 deletions

View file

@ -19,7 +19,8 @@ You'll need the following environment variables in a .env file:
* `GCS_ACCESS`
* `GCS_SECRET`
* `ANYSCALE_API_KEY`
* `LLM_API_KEY`
* `LLM_BASE_URL`
* `GCS_BUCKETNAME`
And to get the application up and running...

140
notebooks/newtasks.ipynb Normal file
View file

@ -0,0 +1,140 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"\n",
"client = openai.OpenAI(\n",
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
" api_key = \"<YOUR_LLM_API_KEY>\"\n",
")\n",
"\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions.\" \n",
" \t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\"Use them if required.\"},\n",
" {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n",
"]\n",
"\n",
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" # name of the function \n",
" \"name\": \"get_financial_data\",\n",
" # a good, detailed description for what the function is supposed to do\n",
" \"description\": \"Get financial data for a company given the metric and year.\",\n",
" # a well defined json schema: https://json-schema.org/learn/getting-started-step-by-step#define\n",
" \"parameters\": {\n",
" # for OpenAI compatibility, we always declare a top level object for the parameters of the function\n",
" \"type\": \"object\",\n",
" # the properties for the object would be any arguments you want to provide to the function\n",
" \"properties\": {\n",
" \"metric\": {\n",
" # JSON Schema supports string, number, integer, object, array, boolean and null\n",
" # for more information, please check out https://json-schema.org/understanding-json-schema/reference/type\n",
" \"type\": \"string\",\n",
" # You can restrict the space of possible values in an JSON Schema\n",
" # you can check out https://json-schema.org/understanding-json-schema/reference/enum for more examples on how enum works\n",
" \"enum\": [\"net_income\", \"revenue\", \"ebdita\"],\n",
" },\n",
" \"financial_year\": {\n",
" \"type\": \"integer\", \n",
" # If the model does not understand how it is supposed to fill the field, a good description goes a long way \n",
" \"description\": \"Year for which we want to get financial data.\"\n",
" },\n",
" \"company\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Name of the company for which we want to get financial data.\"\n",
" }\n",
" },\n",
" # You can specify which of the properties from above are required\n",
" # for more info on `required` field, please check https://json-schema.org/understanding-json-schema/reference/object#required\n",
" \"required\": [\"metric\", \"financial_year\", \"company\"],\n",
" },\n",
" },\n",
" }\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"content\": null,\n",
" \"role\": \"assistant\",\n",
" \"function_call\": null,\n",
" \"tool_calls\": [\n",
" {\n",
" \"id\": \"call_W4IUqQRFF9vYINQ74tfBwmqr\",\n",
" \"function\": {\n",
" \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n",
" \"name\": \"get_financial_data\"\n",
" },\n",
" \"type\": \"function\",\n",
" \"index\": 0\n",
" }\n",
" ]\n",
"}\n"
]
}
],
"source": [
"\n",
"chat_completion = client.chat.completions.create(\n",
" model=\"accounts/fireworks/models/firefunction-v2\",\n",
" messages=messages,\n",
" tools=tools,\n",
" temperature=0.1\n",
")\n",
"print(chat_completion.choices[0].message.model_dump_json(indent=4))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

View file

@ -1,5 +1,6 @@
langchain
langchain-community
langchain-fireworks
openai
fastapi
pydantic-settings

View file

@ -1,13 +1,15 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in
aiobotocore==2.13.0
aiobotocore==2.13.1
# via s3fs
aiofiles==23.2.1
aiofiles==24.1.0
# via -r requirements.in
aiohttp==3.9.5
# via
# aiobotocore
# langchain
# langchain-community
# langchain-fireworks
# s3fs
aioitertools==0.11.0
# via aiobotocore
@ -24,9 +26,10 @@ anyio==4.4.0
attrs==23.2.0
# via aiohttp
authlib==1.3.1
botocore==1.34.106
# via -r requirements.in
botocore==1.34.131
# via aiobotocore
certifi==2024.6.2
certifi==2024.7.4
# via
# httpcore
# httpx
@ -50,16 +53,24 @@ distro==1.9.0
dnspython==2.6.1
# via email-validator
duckdb==1.0.0
email-validator==2.1.1
# via
# -r requirements.in
# duckdb-engine
duckdb-engine==0.13.0
# via -r requirements.in
email-validator==2.2.0
# via fastapi
fastapi==0.111.0
# via -r requirements.in
fastapi-cli==0.0.4
# via fastapi
fireworks-ai==0.14.0
# via langchain-fireworks
frozenlist==1.4.1
# via
# aiohttp
# aiosignal
fsspec==2024.6.0
fsspec==2024.6.1
# via s3fs
h11==0.14.0
# via
@ -72,7 +83,10 @@ httptools==0.6.1
httpx==0.27.0
# via
# fastapi
# fireworks-ai
# openai
httpx-sse==0.4.0
# via fireworks-ai
idna==3.7
# via
# anyio
@ -81,6 +95,7 @@ idna==3.7
# requests
# yarl
itsdangerous==2.2.0
# via -r requirements.in
jinja2==3.1.4
# via fastapi
jmespath==1.0.1
@ -89,17 +104,23 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.3
# via langchain-community
langchain-community==0.2.4
langchain-core==0.2.5
langchain==0.2.7
# via
# -r requirements.in
# langchain-community
langchain-community==0.2.7
# via -r requirements.in
langchain-core==0.2.17
# via
# langchain
# langchain-community
# langchain-fireworks
# langchain-text-splitters
langchain-text-splitters==0.2.1
langchain-fireworks==0.1.5
# via -r requirements.in
langchain-text-splitters==0.2.2
# via langchain
langsmith==0.1.76
langsmith==0.1.85
# via
# langchain
# langchain-community
@ -123,33 +144,44 @@ numpy==1.26.4
# langchain
# langchain-community
# pyarrow
openai==1.33.0
orjson==3.10.4
openai==1.35.13
# via
# -r requirements.in
# langchain-fireworks
orjson==3.10.6
# via
# fastapi
# langsmith
packaging==23.2
packaging==24.1
# via
# duckdb-engine
# langchain-core
# marshmallow
polars==0.20.31
pillow==10.4.0
# via fireworks-ai
polars==1.1.0
# via -r requirements.in
pyarrow==16.1.0
# via -r requirements.in
pycparser==2.22
# via cffi
pydantic==2.7.3
pydantic==2.8.2
# via
# fastapi
# fireworks-ai
# langchain
# langchain-core
# langsmith
# openai
# pydantic-settings
pydantic-core==2.18.4
pydantic-core==2.20.1
# via pydantic
pydantic-settings==2.3.2
pydantic-settings==2.3.4
# via -r requirements.in
pygments==2.18.0
# via rich
pyjwt==2.8.0
# via -r requirements.in
python-dateutil==2.9.0.post0
# via botocore
python-dotenv==1.0.1
@ -157,7 +189,9 @@ python-dotenv==1.0.1
# pydantic-settings
# uvicorn
python-multipart==0.0.9
# via fastapi
# via
# -r requirements.in
# fastapi
pyyaml==6.0.1
# via
# langchain
@ -168,10 +202,12 @@ requests==2.32.3
# via
# langchain
# langchain-community
# langchain-fireworks
# langsmith
rich==13.7.1
# via typer
s3fs==2024.6.0
s3fs==2024.6.1
# via -r requirements.in
shellingham==1.5.4
# via typer
six==1.16.0
@ -181,13 +217,15 @@ sniffio==1.3.1
# anyio
# httpx
# openai
sqlalchemy==2.0.30
sqlalchemy==2.0.31
# via
# -r requirements.in
# duckdb-engine
# langchain
# langchain-community
starlette==0.37.2
# via fastapi
tenacity==8.3.0
tenacity==8.5.0
# via
# langchain
# langchain-community
@ -209,7 +247,7 @@ typing-inspect==0.9.0
# via dataclasses-json
ujson==5.10.0
# via fastapi
urllib3==2.2.1
urllib3==2.2.2
# via
# botocore
# requests

View file

@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
base_url: str = "http://localhost:8000"
anyscale_api_key: str = "Awesome API"
llm_api_key: str = "Awesome API"
llm_base_url: str = "Awesome Endpoint"
gcs_access: str = "access"
gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"

View file

@ -22,7 +22,6 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key)
async def homepage(request: Request):
user = request.session.get("user")
if user:
print(json.dumps(user))
return RedirectResponse("/app")
with open(static_path / "login.html") as f:

View file

@ -1,14 +1,34 @@
from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_fireworks import Fireworks
from langchain_core.prompts import PromptTemplate
class AvailableModels(StrEnum):
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct"
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
# Function calling model
mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1"
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
general_prompt = """
You're a helpful assistant. Perform the following tasks:
----
{query}
----
"""
sql_prompt = """
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
----
{query}
----
Return only the sql statement without any additional text.
"""
class Chat:
@ -25,33 +45,36 @@ class Chat:
def __init__(
self,
model: AvailableModels = AvailableModels.llama3_8b,
model: AvailableModels = AvailableModels.mixtral_8x7b,
api_key: str = "",
temperature: float = 0.5,
):
self.model_name = model
self.model = model
self.api_key = self.raise_no_key(api_key)
self.messages = []
self.responses = []
self.model: ChatAnyscale = ChatAnyscale(
model_name=model, temperature=temperature, anyscale_api_key=self.api_key
self.model: Fireworks = Fireworks(
model=model, temperature=temperature, api_key=self.api_key
)
async def eval(self, system: str, human: str):
self.messages.append(
[
SystemMessage(content=system),
HumanMessage(content=human),
]
)
async def eval(self, task):
prompt = PromptTemplate.from_template(general_prompt)
response = await self.model.ainvoke(self.messages[-1])
response = await self.model.ainvoke(prompt.format(query=task))
self.responses.append(response)
return self
async def sql_eval(self, question):
prompt = PromptTemplate.from_template(sql_prompt)
response = await self.model.ainvoke(prompt.format(query=question))
self.responses.append(response)
return self
def last_response_content(self):
return self.responses[-1].content
last_response = self.responses[-1]
return last_response
def last_response_metadata(self):
return self.responses[-1].response_metadata

View file

@ -13,7 +13,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
db = SessionDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,

View file

@ -20,45 +20,45 @@ SID = "test"
@pytest.mark.asyncio
@pytest.mark.skipif(
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'")
assert chat.last_response_content() == "Hello!"
chat = Chat(api_key=settings.llm_api_key, temperature=0)
chat = await chat.eval("Say literally 'Hello'")
assert "Hello" in chat.last_response_content()
@pytest.mark.asyncio
@pytest.mark.skipif(
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
chat = Chat(
api_key=settings.llm_api_key,
temperature=0.5,
)
db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
chat = await chat.sql_eval(db.query_prompt(query))
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")
@pytest.mark.asyncio
@pytest.mark.skipif(
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_data_query():
q = "find the average score of all the students"
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
llm = Chat(
api_key=settings.llm_api_key,
temperature=0.5,
)
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
chat = await llm.sql_eval(db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result = db.query(query).pl()