This commit is contained in:
		
							parent
							
								
									181bc92884
								
							
						
					
					
						commit
						98a713b3c7
					
				|  | @ -19,7 +19,8 @@ You'll need the following environment variables in a .env file: | ||||||
| 
 | 
 | ||||||
| * `GCS_ACCESS` | * `GCS_ACCESS` | ||||||
| * `GCS_SECRET` | * `GCS_SECRET` | ||||||
| * `ANYSCALE_API_KEY` | * `LLM_API_KEY` | ||||||
|  | * `LLM_BASE_URL` | ||||||
| * `GCS_BUCKETNAME` | * `GCS_BUCKETNAME` | ||||||
| 
 | 
 | ||||||
| And to get the application up and running... | And to get the application up and running... | ||||||
|  |  | ||||||
							
								
								
									
										140
									
								
								notebooks/newtasks.ipynb
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								notebooks/newtasks.ipynb
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,140 @@ | ||||||
|  | { | ||||||
|  |  "cells": [ | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 1, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "import openai\n", | ||||||
|  |     "import json" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 2, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [ | ||||||
|  |     "\n", | ||||||
|  |     "client = openai.OpenAI(\n", | ||||||
|  |     "    base_url = \"https://api.fireworks.ai/inference/v1\",\n", | ||||||
|  |     "    api_key = \"vQdRZPGX7Mvd9XEAIP8VAe5w1comMroY765vMfHW9rqbS48I\"\n", | ||||||
|  |     ")\n", | ||||||
|  |     "\n", | ||||||
|  |     "messages = [\n", | ||||||
|  |     "    {\"role\": \"system\", \"content\": f\"You are a helpful assistant with access to functions.\" \n", | ||||||
|  |     "     \t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\"Use them if required.\"},\n", | ||||||
|  |     "    {\"role\": \"user\", \"content\": \"What are Nike's net income in 2022?\"}\n", | ||||||
|  |     "]\n", | ||||||
|  |     "\n", | ||||||
|  |     "tools = [\n", | ||||||
|  |     "    {\n", | ||||||
|  |     "        \"type\": \"function\",\n", | ||||||
|  |     "        \"function\": {\n", | ||||||
|  |     "            # name of the function \n", | ||||||
|  |     "            \"name\": \"get_financial_data\",\n", | ||||||
|  |     "            # a good, detailed description for what the function is supposed to do\n", | ||||||
|  |     "            \"description\": \"Get financial data for a company given the metric and year.\",\n", | ||||||
|  |     "            # a well defined json schema: https://json-schema.org/learn/getting-started-step-by-step#define\n", | ||||||
|  |     "            \"parameters\": {\n", | ||||||
|  |     "                # for OpenAI compatibility, we always declare a top level object for the parameters of the function\n", | ||||||
|  |     "                \"type\": \"object\",\n", | ||||||
|  |     "                # the properties for the object would be any arguments you want to provide to the function\n", | ||||||
|  |     "                \"properties\": {\n", | ||||||
|  |     "                    \"metric\": {\n", | ||||||
|  |     "                        # JSON Schema supports string, number, integer, object, array, boolean and null\n", | ||||||
|  |     "                        # for more information, please check out https://json-schema.org/understanding-json-schema/reference/type\n", | ||||||
|  |     "                        \"type\": \"string\",\n", | ||||||
|  |     "                        # You can restrict the space of possible values in an JSON Schema\n", | ||||||
|  |     "                        # you can check out https://json-schema.org/understanding-json-schema/reference/enum for more examples on how enum works\n", | ||||||
|  |     "                        \"enum\": [\"net_income\", \"revenue\", \"ebdita\"],\n", | ||||||
|  |     "                    },\n", | ||||||
|  |     "                    \"financial_year\": {\n", | ||||||
|  |     "                        \"type\": \"integer\", \n", | ||||||
|  |     "                        # If the model does not understand how it is supposed to fill the field, a good description goes a long way \n", | ||||||
|  |     "                        \"description\": \"Year for which we want to get financial data.\"\n", | ||||||
|  |     "                    },\n", | ||||||
|  |     "                    \"company\": {\n", | ||||||
|  |     "                        \"type\": \"string\",\n", | ||||||
|  |     "                        \"description\": \"Name of the company for which we want to get financial data.\"\n", | ||||||
|  |     "                    }\n", | ||||||
|  |     "                },\n", | ||||||
|  |     "                # You can specify which of the properties from above are required\n", | ||||||
|  |     "                # for more info on `required` field, please check https://json-schema.org/understanding-json-schema/reference/object#required\n", | ||||||
|  |     "                \"required\": [\"metric\", \"financial_year\", \"company\"],\n", | ||||||
|  |     "            },\n", | ||||||
|  |     "        },\n", | ||||||
|  |     "    }\n", | ||||||
|  |     "]\n" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": 3, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [ | ||||||
|  |     { | ||||||
|  |      "name": "stdout", | ||||||
|  |      "output_type": "stream", | ||||||
|  |      "text": [ | ||||||
|  |       "{\n", | ||||||
|  |       "    \"content\": null,\n", | ||||||
|  |       "    \"role\": \"assistant\",\n", | ||||||
|  |       "    \"function_call\": null,\n", | ||||||
|  |       "    \"tool_calls\": [\n", | ||||||
|  |       "        {\n", | ||||||
|  |       "            \"id\": \"call_W4IUqQRFF9vYINQ74tfBwmqr\",\n", | ||||||
|  |       "            \"function\": {\n", | ||||||
|  |       "                \"arguments\": \"{\\\"metric\\\": \\\"net_income\\\", \\\"financial_year\\\": 2022, \\\"company\\\": \\\"Nike\\\"}\",\n", | ||||||
|  |       "                \"name\": \"get_financial_data\"\n", | ||||||
|  |       "            },\n", | ||||||
|  |       "            \"type\": \"function\",\n", | ||||||
|  |       "            \"index\": 0\n", | ||||||
|  |       "        }\n", | ||||||
|  |       "    ]\n", | ||||||
|  |       "}\n" | ||||||
|  |      ] | ||||||
|  |     } | ||||||
|  |    ], | ||||||
|  |    "source": [ | ||||||
|  |     "\n", | ||||||
|  |     "chat_completion = client.chat.completions.create(\n", | ||||||
|  |     "    model=\"accounts/fireworks/models/firefunction-v2\",\n", | ||||||
|  |     "    messages=messages,\n", | ||||||
|  |     "    tools=tools,\n", | ||||||
|  |     "    temperature=0.1\n", | ||||||
|  |     ")\n", | ||||||
|  |     "print(chat_completion.choices[0].message.model_dump_json(indent=4))\n" | ||||||
|  |    ] | ||||||
|  |   }, | ||||||
|  |   { | ||||||
|  |    "cell_type": "code", | ||||||
|  |    "execution_count": null, | ||||||
|  |    "metadata": {}, | ||||||
|  |    "outputs": [], | ||||||
|  |    "source": [] | ||||||
|  |   } | ||||||
|  |  ], | ||||||
|  |  "metadata": { | ||||||
|  |   "kernelspec": { | ||||||
|  |    "display_name": ".venv", | ||||||
|  |    "language": "python", | ||||||
|  |    "name": "python3" | ||||||
|  |   }, | ||||||
|  |   "language_info": { | ||||||
|  |    "codemirror_mode": { | ||||||
|  |     "name": "ipython", | ||||||
|  |     "version": 3 | ||||||
|  |    }, | ||||||
|  |    "file_extension": ".py", | ||||||
|  |    "mimetype": "text/x-python", | ||||||
|  |    "name": "python", | ||||||
|  |    "nbconvert_exporter": "python", | ||||||
|  |    "pygments_lexer": "ipython3", | ||||||
|  |    "version": "3.12.4" | ||||||
|  |   } | ||||||
|  |  }, | ||||||
|  |  "nbformat": 4, | ||||||
|  |  "nbformat_minor": 2 | ||||||
|  | } | ||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							|  | @ -1,5 +1,6 @@ | ||||||
| langchain | langchain | ||||||
| langchain-community | langchain-community | ||||||
|  | langchain-fireworks | ||||||
| openai | openai | ||||||
| fastapi | fastapi | ||||||
| pydantic-settings | pydantic-settings | ||||||
|  |  | ||||||
|  | @ -1,13 +1,15 @@ | ||||||
| # This file was autogenerated by uv via the following command: | # This file was autogenerated by uv via the following command: | ||||||
| #    uv pip compile requirements.in | #    uv pip compile requirements.in | ||||||
| aiobotocore==2.13.0 | aiobotocore==2.13.1 | ||||||
|     # via s3fs |     # via s3fs | ||||||
| aiofiles==23.2.1 | aiofiles==24.1.0 | ||||||
|  |     # via -r requirements.in | ||||||
| aiohttp==3.9.5 | aiohttp==3.9.5 | ||||||
|     # via |     # via | ||||||
|     #   aiobotocore |     #   aiobotocore | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|  |     #   langchain-fireworks | ||||||
|     #   s3fs |     #   s3fs | ||||||
| aioitertools==0.11.0 | aioitertools==0.11.0 | ||||||
|     # via aiobotocore |     # via aiobotocore | ||||||
|  | @ -24,9 +26,10 @@ anyio==4.4.0 | ||||||
| attrs==23.2.0 | attrs==23.2.0 | ||||||
|     # via aiohttp |     # via aiohttp | ||||||
| authlib==1.3.1 | authlib==1.3.1 | ||||||
| botocore==1.34.106 |     # via -r requirements.in | ||||||
|  | botocore==1.34.131 | ||||||
|     # via aiobotocore |     # via aiobotocore | ||||||
| certifi==2024.6.2 | certifi==2024.7.4 | ||||||
|     # via |     # via | ||||||
|     #   httpcore |     #   httpcore | ||||||
|     #   httpx |     #   httpx | ||||||
|  | @ -50,16 +53,24 @@ distro==1.9.0 | ||||||
| dnspython==2.6.1 | dnspython==2.6.1 | ||||||
|     # via email-validator |     # via email-validator | ||||||
| duckdb==1.0.0 | duckdb==1.0.0 | ||||||
| email-validator==2.1.1 |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   duckdb-engine | ||||||
|  | duckdb-engine==0.13.0 | ||||||
|  |     # via -r requirements.in | ||||||
|  | email-validator==2.2.0 | ||||||
|     # via fastapi |     # via fastapi | ||||||
| fastapi==0.111.0 | fastapi==0.111.0 | ||||||
|  |     # via -r requirements.in | ||||||
| fastapi-cli==0.0.4 | fastapi-cli==0.0.4 | ||||||
|     # via fastapi |     # via fastapi | ||||||
|  | fireworks-ai==0.14.0 | ||||||
|  |     # via langchain-fireworks | ||||||
| frozenlist==1.4.1 | frozenlist==1.4.1 | ||||||
|     # via |     # via | ||||||
|     #   aiohttp |     #   aiohttp | ||||||
|     #   aiosignal |     #   aiosignal | ||||||
| fsspec==2024.6.0 | fsspec==2024.6.1 | ||||||
|     # via s3fs |     # via s3fs | ||||||
| h11==0.14.0 | h11==0.14.0 | ||||||
|     # via |     # via | ||||||
|  | @ -72,7 +83,10 @@ httptools==0.6.1 | ||||||
| httpx==0.27.0 | httpx==0.27.0 | ||||||
|     # via |     # via | ||||||
|     #   fastapi |     #   fastapi | ||||||
|  |     #   fireworks-ai | ||||||
|     #   openai |     #   openai | ||||||
|  | httpx-sse==0.4.0 | ||||||
|  |     # via fireworks-ai | ||||||
| idna==3.7 | idna==3.7 | ||||||
|     # via |     # via | ||||||
|     #   anyio |     #   anyio | ||||||
|  | @ -81,6 +95,7 @@ idna==3.7 | ||||||
|     #   requests |     #   requests | ||||||
|     #   yarl |     #   yarl | ||||||
| itsdangerous==2.2.0 | itsdangerous==2.2.0 | ||||||
|  |     # via -r requirements.in | ||||||
| jinja2==3.1.4 | jinja2==3.1.4 | ||||||
|     # via fastapi |     # via fastapi | ||||||
| jmespath==1.0.1 | jmespath==1.0.1 | ||||||
|  | @ -89,17 +104,23 @@ jsonpatch==1.33 | ||||||
|     # via langchain-core |     # via langchain-core | ||||||
| jsonpointer==3.0.0 | jsonpointer==3.0.0 | ||||||
|     # via jsonpatch |     # via jsonpatch | ||||||
| langchain==0.2.3 | langchain==0.2.7 | ||||||
|     # via langchain-community |     # via | ||||||
| langchain-community==0.2.4 |     #   -r requirements.in | ||||||
| langchain-core==0.2.5 |     #   langchain-community | ||||||
|  | langchain-community==0.2.7 | ||||||
|  |     # via -r requirements.in | ||||||
|  | langchain-core==0.2.17 | ||||||
|     # via |     # via | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|  |     #   langchain-fireworks | ||||||
|     #   langchain-text-splitters |     #   langchain-text-splitters | ||||||
| langchain-text-splitters==0.2.1 | langchain-fireworks==0.1.5 | ||||||
|  |     # via -r requirements.in | ||||||
|  | langchain-text-splitters==0.2.2 | ||||||
|     # via langchain |     # via langchain | ||||||
| langsmith==0.1.76 | langsmith==0.1.85 | ||||||
|     # via |     # via | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|  | @ -123,33 +144,44 @@ numpy==1.26.4 | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|     #   pyarrow |     #   pyarrow | ||||||
| openai==1.33.0 | openai==1.35.13 | ||||||
| orjson==3.10.4 |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   langchain-fireworks | ||||||
|  | orjson==3.10.6 | ||||||
|     # via |     # via | ||||||
|     #   fastapi |     #   fastapi | ||||||
|     #   langsmith |     #   langsmith | ||||||
| packaging==23.2 | packaging==24.1 | ||||||
|     # via |     # via | ||||||
|  |     #   duckdb-engine | ||||||
|     #   langchain-core |     #   langchain-core | ||||||
|     #   marshmallow |     #   marshmallow | ||||||
| polars==0.20.31 | pillow==10.4.0 | ||||||
|  |     # via fireworks-ai | ||||||
|  | polars==1.1.0 | ||||||
|  |     # via -r requirements.in | ||||||
| pyarrow==16.1.0 | pyarrow==16.1.0 | ||||||
|  |     # via -r requirements.in | ||||||
| pycparser==2.22 | pycparser==2.22 | ||||||
|     # via cffi |     # via cffi | ||||||
| pydantic==2.7.3 | pydantic==2.8.2 | ||||||
|     # via |     # via | ||||||
|     #   fastapi |     #   fastapi | ||||||
|  |     #   fireworks-ai | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-core |     #   langchain-core | ||||||
|     #   langsmith |     #   langsmith | ||||||
|     #   openai |     #   openai | ||||||
|     #   pydantic-settings |     #   pydantic-settings | ||||||
| pydantic-core==2.18.4 | pydantic-core==2.20.1 | ||||||
|     # via pydantic |     # via pydantic | ||||||
| pydantic-settings==2.3.2 | pydantic-settings==2.3.4 | ||||||
|  |     # via -r requirements.in | ||||||
| pygments==2.18.0 | pygments==2.18.0 | ||||||
|     # via rich |     # via rich | ||||||
| pyjwt==2.8.0 | pyjwt==2.8.0 | ||||||
|  |     # via -r requirements.in | ||||||
| python-dateutil==2.9.0.post0 | python-dateutil==2.9.0.post0 | ||||||
|     # via botocore |     # via botocore | ||||||
| python-dotenv==1.0.1 | python-dotenv==1.0.1 | ||||||
|  | @ -157,7 +189,9 @@ python-dotenv==1.0.1 | ||||||
|     #   pydantic-settings |     #   pydantic-settings | ||||||
|     #   uvicorn |     #   uvicorn | ||||||
| python-multipart==0.0.9 | python-multipart==0.0.9 | ||||||
|     # via fastapi |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   fastapi | ||||||
| pyyaml==6.0.1 | pyyaml==6.0.1 | ||||||
|     # via |     # via | ||||||
|     #   langchain |     #   langchain | ||||||
|  | @ -168,10 +202,12 @@ requests==2.32.3 | ||||||
|     # via |     # via | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|  |     #   langchain-fireworks | ||||||
|     #   langsmith |     #   langsmith | ||||||
| rich==13.7.1 | rich==13.7.1 | ||||||
|     # via typer |     # via typer | ||||||
| s3fs==2024.6.0 | s3fs==2024.6.1 | ||||||
|  |     # via -r requirements.in | ||||||
| shellingham==1.5.4 | shellingham==1.5.4 | ||||||
|     # via typer |     # via typer | ||||||
| six==1.16.0 | six==1.16.0 | ||||||
|  | @ -181,13 +217,15 @@ sniffio==1.3.1 | ||||||
|     #   anyio |     #   anyio | ||||||
|     #   httpx |     #   httpx | ||||||
|     #   openai |     #   openai | ||||||
| sqlalchemy==2.0.30 | sqlalchemy==2.0.31 | ||||||
|     # via |     # via | ||||||
|  |     #   -r requirements.in | ||||||
|  |     #   duckdb-engine | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
| starlette==0.37.2 | starlette==0.37.2 | ||||||
|     # via fastapi |     # via fastapi | ||||||
| tenacity==8.3.0 | tenacity==8.5.0 | ||||||
|     # via |     # via | ||||||
|     #   langchain |     #   langchain | ||||||
|     #   langchain-community |     #   langchain-community | ||||||
|  | @ -209,7 +247,7 @@ typing-inspect==0.9.0 | ||||||
|     # via dataclasses-json |     # via dataclasses-json | ||||||
| ujson==5.10.0 | ujson==5.10.0 | ||||||
|     # via fastapi |     # via fastapi | ||||||
| urllib3==2.2.1 | urllib3==2.2.2 | ||||||
|     # via |     # via | ||||||
|     #   botocore |     #   botocore | ||||||
|     #   requests |     #   requests | ||||||
|  |  | ||||||
|  | @ -3,7 +3,8 @@ from pydantic_settings import BaseSettings, SettingsConfigDict | ||||||
| 
 | 
 | ||||||
| class Settings(BaseSettings): | class Settings(BaseSettings): | ||||||
|     base_url: str = "http://localhost:8000" |     base_url: str = "http://localhost:8000" | ||||||
|     anyscale_api_key: str = "Awesome API" |     llm_api_key: str = "Awesome API" | ||||||
|  |     llm_base_url: str = "Awessome Endpoint" | ||||||
|     gcs_access: str = "access" |     gcs_access: str = "access" | ||||||
|     gcs_secret: str = "secret" |     gcs_secret: str = "secret" | ||||||
|     gcs_bucketname: str = "bucket" |     gcs_bucketname: str = "bucket" | ||||||
|  |  | ||||||
|  | @ -22,7 +22,6 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key) | ||||||
| async def homepage(request: Request): | async def homepage(request: Request): | ||||||
|     user = request.session.get("user") |     user = request.session.get("user") | ||||||
|     if user: |     if user: | ||||||
|         print(json.dumps(user)) |  | ||||||
|         return RedirectResponse("/app") |         return RedirectResponse("/app") | ||||||
| 
 | 
 | ||||||
|     with open(static_path / "login.html") as f: |     with open(static_path / "login.html") as f: | ||||||
|  |  | ||||||
|  | @ -1,14 +1,34 @@ | ||||||
| from enum import StrEnum | from enum import StrEnum | ||||||
| 
 | 
 | ||||||
| from langchain_community.chat_models import ChatAnyscale | from langchain_fireworks import Fireworks | ||||||
| from langchain_core.messages import HumanMessage, SystemMessage | from langchain_core.prompts import PromptTemplate | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class AvailableModels(StrEnum): | class AvailableModels(StrEnum): | ||||||
|     llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct" |     llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" | ||||||
|     llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct" |  | ||||||
|     # Function calling model |     # Function calling model | ||||||
|     mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1" |     mixtral_8x7b = "accounts/fireworks/models/mixtral-8x7b-instruct" | ||||||
|  |     mixtral_8x22b = "accounts/fireworks/models/mixtral-8x22b-instruct" | ||||||
|  |     firefunction_2 = "accounts/fireworks/models/firefunction-v2" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | general_prompt = """ | ||||||
|  | You're a helpful assistant. Perform the following tasks: | ||||||
|  | 
 | ||||||
|  | ---- | ||||||
|  | {query} | ||||||
|  | ---- | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | sql_prompt = """ | ||||||
|  | You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: | ||||||
|  | 
 | ||||||
|  | ---- | ||||||
|  | {query} | ||||||
|  | ---- | ||||||
|  | 
 | ||||||
|  | Return only the sql statement without any additional text. | ||||||
|  | """ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Chat: | class Chat: | ||||||
|  | @ -25,33 +45,36 @@ class Chat: | ||||||
| 
 | 
 | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         model: AvailableModels = AvailableModels.llama3_8b, |         model: AvailableModels = AvailableModels.mixtral_8x7b, | ||||||
|         api_key: str = "", |         api_key: str = "", | ||||||
|         temperature: float = 0.5, |         temperature: float = 0.5, | ||||||
|     ): |     ): | ||||||
|         self.model_name = model |         self.model = model | ||||||
|         self.api_key = self.raise_no_key(api_key) |         self.api_key = self.raise_no_key(api_key) | ||||||
|         self.messages = [] |         self.messages = [] | ||||||
|         self.responses = [] |         self.responses = [] | ||||||
| 
 | 
 | ||||||
|         self.model: ChatAnyscale = ChatAnyscale( |         self.model: Fireworks = Fireworks( | ||||||
|             model_name=model, temperature=temperature, anyscale_api_key=self.api_key |             model=model, temperature=temperature, api_key=self.api_key | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     async def eval(self, system: str, human: str): |     async def eval(self, task): | ||||||
|         self.messages.append( |         prompt = PromptTemplate.from_template(general_prompt) | ||||||
|             [ |  | ||||||
|                 SystemMessage(content=system), |  | ||||||
|                 HumanMessage(content=human), |  | ||||||
|             ] |  | ||||||
|         ) |  | ||||||
| 
 | 
 | ||||||
|         response = await self.model.ainvoke(self.messages[-1]) |         response = await self.model.ainvoke(prompt.format(query=task)) | ||||||
|  |         self.responses.append(response) | ||||||
|  |         return self | ||||||
|  | 
 | ||||||
|  |     async def sql_eval(self, question): | ||||||
|  |         prompt = PromptTemplate.from_template(sql_prompt) | ||||||
|  | 
 | ||||||
|  |         response = await self.model.ainvoke(prompt.format(query=question)) | ||||||
|         self.responses.append(response) |         self.responses.append(response) | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|     def last_response_content(self): |     def last_response_content(self): | ||||||
|         return self.responses[-1].content |         last_response = self.responses[-1] | ||||||
|  |         return last_response | ||||||
| 
 | 
 | ||||||
|     def last_response_metadata(self): |     def last_response_metadata(self): | ||||||
|         return self.responses[-1].response_metadata |         return self.responses[-1].response_metadata | ||||||
|  |  | ||||||
|  | @ -13,7 +13,7 @@ router = APIRouter() | ||||||
| 
 | 
 | ||||||
| @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | ||||||
| async def query(sid: str = "", q: str = "") -> str: | async def query(sid: str = "", q: str = "") -> str: | ||||||
|     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) |     llm = Chat(api_key=settings.llm_api_key, temperature=0.5) | ||||||
|     db = SessionDB( |     db = SessionDB( | ||||||
|         StorageEngines.gcs, |         StorageEngines.gcs, | ||||||
|         gcs_access=settings.gcs_access, |         gcs_access=settings.gcs_access, | ||||||
|  |  | ||||||
|  | @ -20,45 +20,45 @@ SID = "test" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") | ||||||
|     settings.anyscale_api_key == "Awesome API", reason="API Key not set" |  | ||||||
| ) |  | ||||||
| async def test_chat_simple(): | async def test_chat_simple(): | ||||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0) |     chat = Chat(api_key=settings.llm_api_key, temperature=0) | ||||||
|     chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'") |     chat = await chat.eval("Say literlly 'Hello'") | ||||||
|     assert chat.last_response_content() == "Hello!" |     assert "Hello" in chat.last_response_content() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") | ||||||
|     settings.anyscale_api_key == "Awesome API", reason="API Key not set" |  | ||||||
| ) |  | ||||||
| async def test_simple_data_query(): | async def test_simple_data_query(): | ||||||
|     query = "write a query that finds the average score of all students in the current database" |     query = "write a query that finds the average score of all students in the current database" | ||||||
| 
 | 
 | ||||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) |     chat = Chat( | ||||||
|  |         api_key=settings.llm_api_key, | ||||||
|  |         temperature=0.5, | ||||||
|  |     ) | ||||||
|     db = SessionDB( |     db = SessionDB( | ||||||
|         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID |         storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID | ||||||
|     ).load_xls(TEST_XLS_PATH) |     ).load_xls(TEST_XLS_PATH) | ||||||
| 
 | 
 | ||||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) |     chat = await chat.sql_eval(db.query_prompt(query)) | ||||||
|     query = extract_code_block(chat.last_response_content()) |     query = extract_code_block(chat.last_response_content()) | ||||||
|     assert query.startswith("SELECT") |     assert query.startswith("SELECT") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| @pytest.mark.skipif( | @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") | ||||||
|     settings.anyscale_api_key == "Awesome API", reason="API Key not set" |  | ||||||
| ) |  | ||||||
| async def test_data_query(): | async def test_data_query(): | ||||||
|     q = "find the average score of all the sudents" |     q = "find the average score of all the sudents" | ||||||
| 
 | 
 | ||||||
|     llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) |     llm = Chat( | ||||||
|  |         api_key=settings.llm_api_key, | ||||||
|  |         temperature=0.5, | ||||||
|  |     ) | ||||||
|     db = SessionDB( |     db = SessionDB( | ||||||
|         storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" |         storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" | ||||||
|     ).load_folder() |     ).load_folder() | ||||||
| 
 | 
 | ||||||
|     chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) |     chat = await llm.sql_eval(db.query_prompt(q)) | ||||||
|     query = extract_code_block(chat.last_response_content()) |     query = extract_code_block(chat.last_response_content()) | ||||||
|     result = db.query(query).pl() |     result = db.query(query).pl() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in a new issue