This commit is contained in:
parent
181bc92884
commit
98a713b3c7
|
@ -19,7 +19,8 @@ You'll need the following environment variables in a .env file:
|
|||
|
||||
* `GCS_ACCESS`
|
||||
* `GCS_SECRET`
|
||||
* `ANYSCALE_API_KEY`
|
||||
* `LLM_API_KEY`
|
||||
* `LLM_BASE_URL`
|
||||
* `GCS_BUCKETNAME`
|
||||
|
||||
And to get the application up and running...
|
||||
|
|
140
notebooks/newtasks.ipynb
Normal file
140
notebooks/newtasks.ipynb
Normal file
|
@ -0,0 +1,140 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"client = openai.OpenAI(\n",
|
||||
" base_url = \"https://api.fireworks.ai/inference/v1\",\n",
|
||||
" api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions.\" \n",
|
||||
" \t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\"Use them if required.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"function\": {\n",
|
||||
" # name of the function \n",
|
||||
" \"name\": \"get_financial_data\",\n",
|
||||
" # a good, detailed description for what the function is supposed to do\n",
|
||||
" \"description\": \"Get financial data for a company given the metric and year.\",\n",
|
||||
" # a well defined json schema: https://json-schema.org/learn/getting-started-step-by-step#define\n",
|
||||
" \"parameters\": {\n",
|
||||
" # for OpenAI compatibility, we always declare a top level object for the parameters of the function\n",
|
||||
" \"type\": \"object\",\n",
|
||||
" # the properties for the object would be any arguments you want to provide to the function\n",
|
||||
" \"properties\": {\n",
|
||||
" \"metric\": {\n",
|
||||
" # JSON Schema supports string, number, integer, object, array, boolean and null\n",
|
||||
" # for more information, please check out https://json-schema.org/understanding-json-schema/reference/type\n",
|
||||
" \"type\": \"string\",\n",
|
||||
"                            # You can restrict the space of possible values in a JSON Schema\n",
|
||||
" # you can check out https://json-schema.org/understanding-json-schema/reference/enum for more examples on how enum works\n",
|
||||
" \"enum\": [\"net_income\", \"revenue\", \"ebdita\"],\n",
|
||||
" },\n",
|
||||
" \"financial_year\": {\n",
|
||||
" \"type\": \"integer\", \n",
|
||||
" # If the model does not understand how it is supposed to fill the field, a good description goes a long way \n",
|
||||
" \"description\": \"Year for which we want to get financial data.\"\n",
|
||||
" },\n",
|
||||
" \"company\": {\n",
|
||||
" \"type\": \"string\",\n",
|
||||
" \"description\": \"Name of the company for which we want to get financial data.\"\n",
|
||||
" }\n",
|
||||
" },\n",
|
||||
" # You can specify which of the properties from above are required\n",
|
||||
" # for more info on `required` field, please check https://json-schema.org/understanding-json-schema/reference/object#required\n",
|
||||
" \"required\": [\"metric\", \"financial_year\", \"company\"],\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" }\n",
|
||||
"]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"{\n",
|
||||
" \"content\": null,\n",
|
||||
" \"role\": \"assistant\",\n",
|
||||
" \"function_call\": null,\n",
|
||||
" \"tool_calls\": [\n",
|
||||
" {\n",
|
||||
" \"id\": \"call_W4IUqQRFF9vYINQ74tfBwmqr\",\n",
|
||||
" \"function\": {\n",
|
||||
" \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n",
|
||||
" \"name\": \"get_financial_data\"\n",
|
||||
" },\n",
|
||||
" \"type\": \"function\",\n",
|
||||
" \"index\": 0\n",
|
||||
" }\n",
|
||||
" ]\n",
|
||||
"}\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"chat_completion = client.chat.completions.create(\n",
|
||||
" model=\"accounts/fireworks/models/firefunction-v2\",\n",
|
||||
" messages=messages,\n",
|
||||
" tools=tools,\n",
|
||||
" temperature=0.1\n",
|
||||
")\n",
|
||||
"print(chat_completion.choices[0].message.model_dump_json(indent=4))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
File diff suppressed because one or more lines are too long
|
@ -1,5 +1,6 @@
|
|||
langchain
|
||||
langchain-community
|
||||
langchain-fireworks
|
||||
openai
|
||||
fastapi
|
||||
pydantic-settings
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.in
|
||||
aiobotocore==2.13.0
|
||||
aiobotocore==2.13.1
|
||||
# via s3fs
|
||||
aiofiles==23.2.1
|
||||
aiofiles==24.1.0
|
||||
# via -r requirements.in
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# aiobotocore
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-fireworks
|
||||
# s3fs
|
||||
aioitertools==0.11.0
|
||||
# via aiobotocore
|
||||
|
@ -24,9 +26,10 @@ anyio==4.4.0
|
|||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
authlib==1.3.1
|
||||
botocore==1.34.106
|
||||
# via -r requirements.in
|
||||
botocore==1.34.131
|
||||
# via aiobotocore
|
||||
certifi==2024.6.2
|
||||
certifi==2024.7.4
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
|
@ -50,16 +53,24 @@ distro==1.9.0
|
|||
dnspython==2.6.1
|
||||
# via email-validator
|
||||
duckdb==1.0.0
|
||||
email-validator==2.1.1
|
||||
# via
|
||||
# -r requirements.in
|
||||
# duckdb-engine
|
||||
duckdb-engine==0.13.0
|
||||
# via -r requirements.in
|
||||
email-validator==2.2.0
|
||||
# via fastapi
|
||||
fastapi==0.111.0
|
||||
# via -r requirements.in
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
fireworks-ai==0.14.0
|
||||
# via langchain-fireworks
|
||||
frozenlist==1.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.6.0
|
||||
fsspec==2024.6.1
|
||||
# via s3fs
|
||||
h11==0.14.0
|
||||
# via
|
||||
|
@ -72,7 +83,10 @@ httptools==0.6.1
|
|||
httpx==0.27.0
|
||||
# via
|
||||
# fastapi
|
||||
# fireworks-ai
|
||||
# openai
|
||||
httpx-sse==0.4.0
|
||||
# via fireworks-ai
|
||||
idna==3.7
|
||||
# via
|
||||
# anyio
|
||||
|
@ -81,6 +95,7 @@ idna==3.7
|
|||
# requests
|
||||
# yarl
|
||||
itsdangerous==2.2.0
|
||||
# via -r requirements.in
|
||||
jinja2==3.1.4
|
||||
# via fastapi
|
||||
jmespath==1.0.1
|
||||
|
@ -89,17 +104,23 @@ jsonpatch==1.33
|
|||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.3
|
||||
# via langchain-community
|
||||
langchain-community==0.2.4
|
||||
langchain-core==0.2.5
|
||||
langchain==0.2.7
|
||||
# via
|
||||
# -r requirements.in
|
||||
# langchain-community
|
||||
langchain-community==0.2.7
|
||||
# via -r requirements.in
|
||||
langchain-core==0.2.17
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-fireworks
|
||||
# langchain-text-splitters
|
||||
langchain-text-splitters==0.2.1
|
||||
langchain-fireworks==0.1.5
|
||||
# via -r requirements.in
|
||||
langchain-text-splitters==0.2.2
|
||||
# via langchain
|
||||
langsmith==0.1.76
|
||||
langsmith==0.1.85
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
|
@ -123,33 +144,44 @@ numpy==1.26.4
|
|||
# langchain
|
||||
# langchain-community
|
||||
# pyarrow
|
||||
openai==1.33.0
|
||||
orjson==3.10.4
|
||||
openai==1.35.13
|
||||
# via
|
||||
# -r requirements.in
|
||||
# langchain-fireworks
|
||||
orjson==3.10.6
|
||||
# via
|
||||
# fastapi
|
||||
# langsmith
|
||||
packaging==23.2
|
||||
packaging==24.1
|
||||
# via
|
||||
# duckdb-engine
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
polars==0.20.31
|
||||
pillow==10.4.0
|
||||
# via fireworks-ai
|
||||
polars==1.1.0
|
||||
# via -r requirements.in
|
||||
pyarrow==16.1.0
|
||||
# via -r requirements.in
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.7.3
|
||||
pydantic==2.8.2
|
||||
# via
|
||||
# fastapi
|
||||
# fireworks-ai
|
||||
# langchain
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
# pydantic-settings
|
||||
pydantic-core==2.18.4
|
||||
pydantic-core==2.20.1
|
||||
# via pydantic
|
||||
pydantic-settings==2.3.2
|
||||
pydantic-settings==2.3.4
|
||||
# via -r requirements.in
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyjwt==2.8.0
|
||||
# via -r requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via botocore
|
||||
python-dotenv==1.0.1
|
||||
|
@ -157,7 +189,9 @@ python-dotenv==1.0.1
|
|||
# pydantic-settings
|
||||
# uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
# via
|
||||
# -r requirements.in
|
||||
# fastapi
|
||||
pyyaml==6.0.1
|
||||
# via
|
||||
# langchain
|
||||
|
@ -168,10 +202,12 @@ requests==2.32.3
|
|||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-fireworks
|
||||
# langsmith
|
||||
rich==13.7.1
|
||||
# via typer
|
||||
s3fs==2024.6.0
|
||||
s3fs==2024.6.1
|
||||
# via -r requirements.in
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
|
@ -181,13 +217,15 @@ sniffio==1.3.1
|
|||
# anyio
|
||||
# httpx
|
||||
# openai
|
||||
sqlalchemy==2.0.30
|
||||
sqlalchemy==2.0.31
|
||||
# via
|
||||
# -r requirements.in
|
||||
# duckdb-engine
|
||||
# langchain
|
||||
# langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
tenacity==8.3.0
|
||||
tenacity==8.5.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
|
@ -209,7 +247,7 @@ typing-inspect==0.9.0
|
|||
# via dataclasses-json
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
urllib3==2.2.1
|
||||
urllib3==2.2.2
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
|
|
|
@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||
|
||||
class Settings(BaseSettings):
|
||||
base_url: str = "http://localhost:8000"
|
||||
anyscale_api_key: str = "Awesome API"
|
||||
llm_api_key: str = "Awesome API"
|
||||
llm_base_url: str = "Awessome Endpoint"
|
||||
gcs_access: str = "access"
|
||||
gcs_secret: str = "secret"
|
||||
gcs_bucketname: str = "bucket"
|
||||
|
|
|
@ -22,7 +22,6 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key)
|
|||
async def homepage(request: Request):
|
||||
user = request.session.get("user")
|
||||
if user:
|
||||
print(json.dumps(user))
|
||||
return RedirectResponse("/app")
|
||||
|
||||
with open(static_path / "login.html") as f:
|
||||
|
|
|
@ -1,14 +1,34 @@
|
|||
from enum import StrEnum
|
||||
|
||||
from langchain_community.chat_models import ChatAnyscale
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_fireworks import Fireworks
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
|
||||
|
||||
class AvailableModels(StrEnum):
|
||||
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct"
|
||||
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
|
||||
# Function calling model
|
||||
mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
|
||||
mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
|
||||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||
|
||||
|
||||
general_prompt = """
|
||||
You're a helpful assistant. Perform the following tasks:
|
||||
|
||||
----
|
||||
{query}
|
||||
----
|
||||
"""
|
||||
|
||||
sql_prompt = """
|
||||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
||||
|
||||
----
|
||||
{query}
|
||||
----
|
||||
|
||||
Return only the sql statement without any additional text.
|
||||
"""
|
||||
|
||||
|
||||
class Chat:
|
||||
|
@ -25,33 +45,36 @@ class Chat:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model: AvailableModels = AvailableModels.llama3_8b,
|
||||
model: AvailableModels = AvailableModels.mixtral_8x7b,
|
||||
api_key: str = "",
|
||||
temperature: float = 0.5,
|
||||
):
|
||||
self.model_name = model
|
||||
self.model = model
|
||||
self.api_key = self.raise_no_key(api_key)
|
||||
self.messages = []
|
||||
self.responses = []
|
||||
|
||||
self.model: ChatAnyscale = ChatAnyscale(
|
||||
model_name=model, temperature=temperature, anyscale_api_key=self.api_key
|
||||
self.model: Fireworks = Fireworks(
|
||||
model=model, temperature=temperature, api_key=self.api_key
|
||||
)
|
||||
|
||||
async def eval(self, system: str, human: str):
|
||||
self.messages.append(
|
||||
[
|
||||
SystemMessage(content=system),
|
||||
HumanMessage(content=human),
|
||||
]
|
||||
)
|
||||
async def eval(self, task):
|
||||
prompt = PromptTemplate.from_template(general_prompt)
|
||||
|
||||
response = await self.model.ainvoke(self.messages[-1])
|
||||
response = await self.model.ainvoke(prompt.format(query=task))
|
||||
self.responses.append(response)
|
||||
return self
|
||||
|
||||
async def sql_eval(self, question):
|
||||
prompt = PromptTemplate.from_template(sql_prompt)
|
||||
|
||||
response = await self.model.ainvoke(prompt.format(query=question))
|
||||
self.responses.append(response)
|
||||
return self
|
||||
|
||||
def last_response_content(self):
|
||||
return self.responses[-1].content
|
||||
last_response = self.responses[-1]
|
||||
return last_response
|
||||
|
||||
def last_response_metadata(self):
|
||||
return self.responses[-1].response_metadata
|
||||
|
|
|
@ -13,7 +13,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||
async def query(sid: str = "", q: str = "") -> str:
|
||||
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
|
|
|
@ -20,45 +20,45 @@ SID = "test"
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
|
||||
)
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_chat_simple():
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
|
||||
chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'")
|
||||
assert chat.last_response_content() == "Hello!"
|
||||
chat = Chat(api_key=settings.llm_api_key, temperature=0)
|
||||
chat = await chat.eval("Say literlly 'Hello'")
|
||||
assert "Hello" in chat.last_response_content()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
|
||||
)
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_simple_data_query():
|
||||
query = "write a query that finds the average score of all students in the current database"
|
||||
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
chat = Chat(
|
||||
api_key=settings.llm_api_key,
|
||||
temperature=0.5,
|
||||
)
|
||||
db = SessionDB(
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
||||
).load_xls(TEST_XLS_PATH)
|
||||
|
||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||
chat = await chat.sql_eval(db.query_prompt(query))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
assert query.startswith("SELECT")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(
|
||||
settings.anyscale_api_key == "Awesome API", reason="API Key not set"
|
||||
)
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_data_query():
|
||||
q = "find the average score of all the sudents"
|
||||
|
||||
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
llm = Chat(
|
||||
api_key=settings.llm_api_key,
|
||||
temperature=0.5,
|
||||
)
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
|
||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||
chat = await llm.sql_eval(db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = db.query(query).pl()
|
||||
|
||||
|
|
Loading…
Reference in a new issue