commit 98a713b3c7 (parent 181bc92884)
@@ -19,7 +19,8 @@ You'll need the following environment variables in a .env file:
 * `GCS_ACCESS`
 * `GCS_SECRET`
-* `ANYSCALE_API_KEY`
+* `LLM_API_KEY`
+* `LLM_BASE_URL`
 * `GCS_BUCKETNAME`

 And to get the application up and running...
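For reference, a local `.env` matching the renamed variables could look roughly like this; the values are placeholders (only the variable names come from this change), and the URL shown is simply the Fireworks endpoint used in the notebook below:

```
GCS_ACCESS=<gcs-access-key>
GCS_SECRET=<gcs-secret-key>
LLM_API_KEY=<provider-api-key>
LLM_BASE_URL=https://api.fireworks.ai/inference/v1
GCS_BUCKETNAME=<bucket-name>
```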
notebooks/newtasks.ipynb (new file, 140 lines)

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "client = openai.OpenAI(\n",
    "    base_url = \"https://api.fireworks.ai/inference/v1\",\n",
    "    api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n",
    ")\n",
    "\n",
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions.\" \n",
    " \t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\"Use them if required.\"},\n",
    "    {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n",
    "]\n",
    "\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            # name of the function \n",
    "            \"name\": \"get_financial_data\",\n",
    "            # a good, detailed description for what the function is supposed to do\n",
    "            \"description\": \"Get financial data for a company given the metric and year.\",\n",
    "            # a well defined json schema: https://json-schema.org/learn/getting-started-step-by-step#define\n",
    "            \"parameters\": {\n",
    "                # for OpenAI compatibility, we always declare a top level object for the parameters of the function\n",
    "                \"type\": \"object\",\n",
    "                # the properties for the object would be any arguments you want to provide to the function\n",
    "                \"properties\": {\n",
    "                    \"metric\": {\n",
    "                        # JSON Schema supports string, number, integer, object, array, boolean and null\n",
    "                        # for more information, please check out https://json-schema.org/understanding-json-schema/reference/type\n",
    "                        \"type\": \"string\",\n",
    "                        # You can restrict the space of possible values in an JSON Schema\n",
    "                        # you can check out https://json-schema.org/understanding-json-schema/reference/enum for more examples on how enum works\n",
    "                        \"enum\": [\"net_income\", \"revenue\", \"ebdita\"],\n",
    "                    },\n",
    "                    \"financial_year\": {\n",
    "                        \"type\": \"integer\", \n",
    "                        # If the model does not understand how it is supposed to fill the field, a good description goes a long way \n",
    "                        \"description\": \"Year for which we want to get financial data.\"\n",
    "                    },\n",
    "                    \"company\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"Name of the company for which we want to get financial data.\"\n",
    "                    }\n",
    "                },\n",
    "                # You can specify which of the properties from above are required\n",
    "                # for more info on `required` field, please check https://json-schema.org/understanding-json-schema/reference/object#required\n",
    "                \"required\": [\"metric\", \"financial_year\", \"company\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"content\": null,\n",
      "    \"role\": \"assistant\",\n",
      "    \"function_call\": null,\n",
      "    \"tool_calls\": [\n",
      "        {\n",
      "            \"id\": \"call_W4IUqQRFF9vYINQ74tfBwmqr\",\n",
      "            \"function\": {\n",
      "                \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n",
      "                \"name\": \"get_financial_data\"\n",
      "            },\n",
      "            \"type\": \"function\",\n",
      "            \"index\": 0\n",
      "        }\n",
      "    ]\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "chat_completion = client.chat.completions.create(\n",
    "    model=\"accounts/fireworks/models/firefunction-v2\",\n",
    "    messages=messages,\n",
    "    tools=tools,\n",
    "    temperature=0.1\n",
    ")\n",
    "print(chat_completion.choices[0].message.model_dump_json(indent=4))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
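The notebook stops after printing the model's `tool_calls`. A minimal sketch of the usual follow-up, assuming the notebook cells above have run: parse the call arguments, run a local `get_financial_data` (hypothetical here; the notebook only declares its schema), and send the result back as a `tool` message.

```python
# Continues from the notebook: client, messages, tools, chat_completion and json are defined above.

def get_financial_data(metric: str, financial_year: int, company: str) -> dict:
    """Hypothetical local implementation of the declared tool; replace with a real lookup."""
    return {"company": company, "metric": metric, "financial_year": financial_year, "value": "placeholder"}

tool_call = chat_completion.choices[0].message.tool_calls[0]
args = json.loads(tool_call.function.arguments)
result = get_financial_data(**args)

# Echo the assistant's tool call, then attach the tool result so the model can answer in prose.
messages.append(
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": tool_call.id,
                "type": "function",
                "function": {"name": tool_call.function.name, "arguments": tool_call.function.arguments},
            }
        ],
    }
)
messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": json.dumps(result)})

followup = client.chat.completions.create(
    model="accounts/fireworks/models/firefunction-v2",
    messages=messages,
    tools=tools,
    temperature=0.1,
)
print(followup.choices[0].message.content)
```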
File diff suppressed because one or more lines are too long

@@ -1,5 +1,6 @@
 langchain
 langchain-community
+langchain-fireworks
 openai
 fastapi
 pydantic-settings
@@ -1,13 +1,15 @@
 # This file was autogenerated by uv via the following command:
 #    uv pip compile requirements.in
-aiobotocore==2.13.0
+aiobotocore==2.13.1
     # via s3fs
-aiofiles==23.2.1
+aiofiles==24.1.0
+    # via -r requirements.in
 aiohttp==3.9.5
     # via
     #   aiobotocore
     #   langchain
     #   langchain-community
+    #   langchain-fireworks
     #   s3fs
 aioitertools==0.11.0
     # via aiobotocore
@@ -24,9 +26,10 @@ anyio==4.4.0
 attrs==23.2.0
     # via aiohttp
 authlib==1.3.1
-botocore==1.34.106
+    # via -r requirements.in
+botocore==1.34.131
     # via aiobotocore
-certifi==2024.6.2
+certifi==2024.7.4
     # via
     #   httpcore
     #   httpx
@@ -50,16 +53,24 @@ distro==1.9.0
 dnspython==2.6.1
     # via email-validator
 duckdb==1.0.0
-email-validator==2.1.1
+    # via
+    #   -r requirements.in
+    #   duckdb-engine
+duckdb-engine==0.13.0
+    # via -r requirements.in
+email-validator==2.2.0
     # via fastapi
 fastapi==0.111.0
+    # via -r requirements.in
 fastapi-cli==0.0.4
     # via fastapi
+fireworks-ai==0.14.0
+    # via langchain-fireworks
 frozenlist==1.4.1
     # via
     #   aiohttp
     #   aiosignal
-fsspec==2024.6.0
+fsspec==2024.6.1
     # via s3fs
 h11==0.14.0
     # via
@@ -72,7 +83,10 @@ httptools==0.6.1
 httpx==0.27.0
     # via
     #   fastapi
+    #   fireworks-ai
     #   openai
+httpx-sse==0.4.0
+    # via fireworks-ai
 idna==3.7
     # via
     #   anyio
@@ -81,6 +95,7 @@ idna==3.7
     #   requests
     #   yarl
 itsdangerous==2.2.0
+    # via -r requirements.in
 jinja2==3.1.4
     # via fastapi
 jmespath==1.0.1
@@ -89,17 +104,23 @@ jsonpatch==1.33
     # via langchain-core
 jsonpointer==3.0.0
     # via jsonpatch
-langchain==0.2.3
-    # via langchain-community
-langchain-community==0.2.4
-langchain-core==0.2.5
+langchain==0.2.7
+    # via
+    #   -r requirements.in
+    #   langchain-community
+langchain-community==0.2.7
+    # via -r requirements.in
+langchain-core==0.2.17
     # via
     #   langchain
     #   langchain-community
+    #   langchain-fireworks
     #   langchain-text-splitters
-langchain-text-splitters==0.2.1
+langchain-fireworks==0.1.5
+    # via -r requirements.in
+langchain-text-splitters==0.2.2
     # via langchain
-langsmith==0.1.76
+langsmith==0.1.85
     # via
     #   langchain
     #   langchain-community
@@ -123,33 +144,44 @@ numpy==1.26.4
     #   langchain
     #   langchain-community
     #   pyarrow
-openai==1.33.0
-orjson==3.10.4
+openai==1.35.13
+    # via
+    #   -r requirements.in
+    #   langchain-fireworks
+orjson==3.10.6
     # via
     #   fastapi
     #   langsmith
-packaging==23.2
+packaging==24.1
     # via
+    #   duckdb-engine
     #   langchain-core
     #   marshmallow
-polars==0.20.31
+pillow==10.4.0
+    # via fireworks-ai
+polars==1.1.0
+    # via -r requirements.in
 pyarrow==16.1.0
+    # via -r requirements.in
 pycparser==2.22
     # via cffi
-pydantic==2.7.3
+pydantic==2.8.2
     # via
     #   fastapi
+    #   fireworks-ai
     #   langchain
     #   langchain-core
     #   langsmith
     #   openai
     #   pydantic-settings
-pydantic-core==2.18.4
+pydantic-core==2.20.1
     # via pydantic
-pydantic-settings==2.3.2
+pydantic-settings==2.3.4
+    # via -r requirements.in
 pygments==2.18.0
     # via rich
 pyjwt==2.8.0
+    # via -r requirements.in
 python-dateutil==2.9.0.post0
     # via botocore
 python-dotenv==1.0.1
@@ -157,7 +189,9 @@ python-dotenv==1.0.1
     #   pydantic-settings
     #   uvicorn
 python-multipart==0.0.9
-    # via fastapi
+    # via
+    #   -r requirements.in
+    #   fastapi
 pyyaml==6.0.1
     # via
     #   langchain
@@ -168,10 +202,12 @@ requests==2.32.3
     # via
     #   langchain
     #   langchain-community
+    #   langchain-fireworks
     #   langsmith
 rich==13.7.1
     # via typer
-s3fs==2024.6.0
+s3fs==2024.6.1
+    # via -r requirements.in
 shellingham==1.5.4
     # via typer
 six==1.16.0
@@ -181,13 +217,15 @@ sniffio==1.3.1
     #   anyio
     #   httpx
     #   openai
-sqlalchemy==2.0.30
+sqlalchemy==2.0.31
     # via
+    #   -r requirements.in
+    #   duckdb-engine
     #   langchain
     #   langchain-community
 starlette==0.37.2
     # via fastapi
-tenacity==8.3.0
+tenacity==8.5.0
     # via
     #   langchain
     #   langchain-community
@@ -209,7 +247,7 @@ typing-inspect==0.9.0
     # via dataclasses-json
 ujson==5.10.0
     # via fastapi
-urllib3==2.2.1
+urllib3==2.2.2
     # via
     #   botocore
     #   requests
@@ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict

 class Settings(BaseSettings):
     base_url: str = "http://localhost:8000"
-    anyscale_api_key: str = "Awesome API"
+    llm_api_key: str = "Awesome API"
+    llm_base_url: str = "Awessome Endpoint"
     gcs_access: str = "access"
     gcs_secret: str = "secret"
     gcs_bucketname: str = "bucket"
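A hedged sketch of how these fields are expected to be filled: pydantic-settings maps upper-case environment variables (or a `.env` file, suggested by the `SettingsConfigDict` import in the hunk header) onto the lower-case attributes, and the defaults double as "not configured" sentinels that the tests further down compare against.

```python
import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    # Assumption: the real class points at a .env file; that line is outside this hunk.
    model_config = SettingsConfigDict(env_file=".env")

    llm_api_key: str = "Awesome API"
    llm_base_url: str = "Awessome Endpoint"


os.environ["LLM_API_KEY"] = "fw-example-key"
print(Settings().llm_api_key)   # "fw-example-key" -- the env var overrides the default
print(Settings().llm_base_url)  # stays at the sentinel unless LLM_BASE_URL (or .env) sets it
```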
@@ -22,7 +22,6 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key)
 async def homepage(request: Request):
     user = request.session.get("user")
     if user:
-        print(json.dumps(user))
         return RedirectResponse("/app")

     with open(static_path / "login.html") as f:
@@ -1,14 +1,34 @@
 from enum import StrEnum

-from langchain_community.chat_models import ChatAnyscale
-from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_fireworks import Fireworks
+from langchain_core.prompts import PromptTemplate


 class AvailableModels(StrEnum):
-    llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
-    llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct"
+    llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
     # Function calling model
-    mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+    mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct"
+    mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct"
+    firefunction_2 = "accounts/fireworks/models/firefunction-v2"


+general_prompt = """
+You're a helpful assistant. Perform the following tasks:
+
+----
+{query}
+----
+"""
+
+sql_prompt = """
+You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
+
+----
+{query}
+----
+
+Return only the sql statement without any additional text.
+"""
+
+
 class Chat:
@@ -25,33 +45,36 @@ class Chat:

     def __init__(
         self,
-        model: AvailableModels = AvailableModels.llama3_8b,
+        model: AvailableModels = AvailableModels.mixtral_8x7b,
         api_key: str = "",
         temperature: float = 0.5,
     ):
-        self.model_name = model
+        self.model = model
         self.api_key = self.raise_no_key(api_key)
         self.messages = []
        self.responses = []

-        self.model: ChatAnyscale = ChatAnyscale(
-            model_name=model, temperature=temperature, anyscale_api_key=self.api_key
+        self.model: Fireworks = Fireworks(
+            model=model, temperature=temperature, api_key=self.api_key
         )

-    async def eval(self, system: str, human: str):
-        self.messages.append(
-            [
-                SystemMessage(content=system),
-                HumanMessage(content=human),
-            ]
-        )
+    async def eval(self, task):
+        prompt = PromptTemplate.from_template(general_prompt)

-        response = await self.model.ainvoke(self.messages[-1])
+        response = await self.model.ainvoke(prompt.format(query=task))
+        self.responses.append(response)
+        return self
+
+    async def sql_eval(self, question):
+        prompt = PromptTemplate.from_template(sql_prompt)
+
+        response = await self.model.ainvoke(prompt.format(query=question))
         self.responses.append(response)
         return self

     def last_response_content(self):
-        return self.responses[-1].content
+        last_response = self.responses[-1]
+        return last_response

     def last_response_metadata(self):
         return self.responses[-1].response_metadata
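A minimal usage sketch of the reworked `Chat` class (the import path and prompt text are assumptions; the constructor arguments and method names come from the diff). Note that `langchain_fireworks.Fireworks` is a completion-style model returning a plain string, which is consistent with `last_response_content()` now returning the stored response directly:

```python
import asyncio

from chat import Chat  # assumed module path; not shown in the diff


async def main() -> None:
    chat = Chat(api_key="<fireworks-api-key>", temperature=0.5)

    # General-purpose task, wrapped in general_prompt.
    chat = await chat.eval("Summarise what DuckDB is in one sentence.")
    print(chat.last_response_content())

    # SQL generation, wrapped in sql_prompt (duckdb dialect).
    chat = await chat.sql_eval("find the average score of all students in the scores table")
    print(chat.last_response_content())


asyncio.run(main())
```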
@@ -13,7 +13,7 @@ router = APIRouter()

 @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
 async def query(sid: str = "", q: str = "") -> str:
-    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
+    llm = Chat(api_key=settings.llm_api_key, temperature=0.5)
     db = SessionDB(
         StorageEngines.gcs,
         gcs_access=settings.gcs_access,
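And a hedged example of exercising the `/query` route touched above, assuming the FastAPI app is running locally on its default port; the route path, query parameters, and plain-text response come from the diff, everything else is an assumption:

```python
import httpx

# assumes the application is served at localhost:8000
response = httpx.get(
    "http://localhost:8000/query",
    params={"sid": "test", "q": "find the average score of all students"},
)
print(response.text)  # PlainTextResponse body returned by the route
```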
@@ -20,45 +20,45 @@ SID = "test"


 @pytest.mark.asyncio
-@pytest.mark.skipif(
-    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
-)
+@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
 async def test_chat_simple():
-    chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
-    chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'")
-    assert chat.last_response_content() == "Hello!"
+    chat = Chat(api_key=settings.llm_api_key, temperature=0)
+    chat = await chat.eval("Say literlly 'Hello'")
+    assert "Hello" in chat.last_response_content()


 @pytest.mark.asyncio
-@pytest.mark.skipif(
-    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
-)
+@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
 async def test_simple_data_query():
     query = "write a query that finds the average score of all students in the current database"

-    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
+    chat = Chat(
+        api_key=settings.llm_api_key,
+        temperature=0.5,
+    )
     db = SessionDB(
         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
     ).load_xls(TEST_XLS_PATH)

-    chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
+    chat = await chat.sql_eval(db.query_prompt(query))
     query = extract_code_block(chat.last_response_content())
     assert query.startswith("SELECT")


 @pytest.mark.asyncio
-@pytest.mark.skipif(
-    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
-)
+@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
 async def test_data_query():
     q = "find the average score of all the sudents"

-    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
+    llm = Chat(
+        api_key=settings.llm_api_key,
+        temperature=0.5,
+    )
     db = SessionDB(
         storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
     ).load_folder()

-    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
+    chat = await llm.sql_eval(db.query_prompt(q))
     query = extract_code_block(chat.last_response_content())
     result = db.query(query).pl()