Compare commits

...

3 commits

Author SHA1 Message Date
Guillem Borrell 181bc92884 Successfully implemented function calling for anyscale
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
2024-06-17 22:18:18 +02:00
Guillem Borrell e43f81df39 Some refactoring 2024-06-16 09:31:33 +02:00
Guillem Borrell e8c7600ed7 Better tests and docs 2024-06-16 08:56:45 +02:00
16 changed files with 226 additions and 39 deletions

137
notebooks/tasks.ipynb Normal file
View file

@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from hellocomputer.config import settings\n",
"from langchain_core.utils.function_calling import convert_to_openai_function\n",
"import openai\n",
"from operator import itemgetter"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Adds a and b.\"\"\"\n",
" return a + b\n",
"\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiplies a and b.\"\"\"\n",
" return a * b\n",
"\n",
"\n",
"tools = [convert_to_openai_function(t) for t in [add, multiply]]\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"tools_fmt = [\n",
" {\"type\": \"function\",\n",
" \"function\": tools[0]},\n",
" {\"type\": \"function\",\n",
" \"function\": tools[1]}\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n",
"]\n",
"\n",
"client = openai.OpenAI(\n",
" base_url = \"https://api.endpoints.anyscale.com/v1\",\n",
" api_key = settings.anyscale_api_key\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
" messages=messages,\n",
" tools=tools_fmt,\n",
" tool_choice=\"auto\", # auto is default, but we'll be explicit\n",
")\n",
"\n",
"get_args = itemgetter(\"arguments\")\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"call = response.choices[0].message.tool_calls[0].function"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"add.func(2,3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -6,9 +6,12 @@ pydantic-settings
s3fs
aiofiles
duckdb
duckdb-engine
polars
pyarrow
pyjwt[crypto]
python-multipart
authlib
itsdangerous
sqlalchemy
openai

View file

@ -1,7 +1,7 @@
from enum import StrEnum
from pathlib import Path
import duckdb
from sqlalchemy import create_engine, text
class StorageEngines(StrEnum):
@ -19,11 +19,13 @@ class DDB:
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.engine = create_engine(
"duckdb:///:memory:",
connect_args={
"preload_extensions": ["https", "spatial"],
"config": {"memory_limit": "300mb"},
},
)
self.sheets = tuple()
self.loaded = False
@ -35,12 +37,18 @@ class DDB:
bucket is not None,
)
):
self.db.sql(f"""
with self.engine.connect() as conn:
conn.execute(
text(
f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
"""
)
)
self.path_prefix = f"gcs://{bucket}"
else:
raise ValueError(
@ -55,3 +63,7 @@ class DDB:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
@property
def db(self):
    """Return a raw connection obtained from this object's SQLAlchemy engine.

    Every access to the property calls ``self.engine.raw_connection()``
    again; nothing is cached on the instance.
    """
    raw_conn = self.engine.raw_connection()
    return raw_conn

View file

@ -6,7 +6,7 @@ from typing_extensions import Self
from hellocomputer.db import StorageEngines
from .db import DDB
from . import DDB
class SessionDB(DDB):
@ -149,7 +149,7 @@ class SessionDB(DDB):
)
@property
def schema(self):
def schema(self) -> str:
return os.linesep.join(
[
"The schema of the database is the following:",

View file

@ -8,7 +8,7 @@ from uuid import UUID, uuid4
import duckdb
import polars as pl
from .db import DDB, StorageEngines
from . import DDB, StorageEngines
class UserDB(DDB):
@ -99,10 +99,10 @@ class OwnershipDB(DDB):
FROM
'{self.path_prefix}/*.csv'
WHERE
email = '{user_email}
email = '{user_email}'
ORDER BY
timestamp ASC
LIMIT 10'
LIMIT 10
""")
.pl()
.to_series()

View file

@ -2,8 +2,8 @@ from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block
from hellocomputer.sessions import SessionDB
from ..config import settings
from ..models import Chat
@ -13,7 +13,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
@ -22,9 +22,8 @@ async def query(sid: str = "", q: str = "") -> str:
sid=sid,
).load_folder()
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result = str(db.query(query))
print(result)
return result

View file

@ -5,7 +5,7 @@ from starlette.requests import Request
from hellocomputer.config import settings
from hellocomputer.db import StorageEngines
from hellocomputer.users import UserDB
from hellocomputer.db.users import UserDB
router = APIRouter()

View file

@ -7,8 +7,8 @@ from starlette.requests import Request
from ..config import settings
from ..db import StorageEngines
from ..sessions import SessionDB
from ..users import OwnershipDB
from ..db.sessions import SessionDB
from ..db.users import OwnershipDB
router = APIRouter()

View file

@ -1,12 +1,12 @@
from typing import List
from uuid import uuid4
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from starlette.requests import Request
from typing import List
from hellocomputer.db import StorageEngines
from hellocomputer.users import OwnershipDB
from hellocomputer.db.users import OwnershipDB
from ..config import settings

View file

@ -37,7 +37,7 @@
<p class="techie-font">
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
and more efficiently than Excel.
than Excel.
</p>
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
</div>

View file

@ -25,9 +25,11 @@
</p>
<a href="#" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-question-circle"></i> How to</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i>
<a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-file-ruled"></i>
File templates</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i>
<a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-info-circle"></i>
About</a>
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
Config</a>

View file

@ -71,9 +71,7 @@ function addAIManualMessage(m) {
chatMessages.prepend(newMessage); // Add new message at the top
}
function addUserMessage() {
const messageContent = textarea.value.trim();
if (messageContent) {
function addUserMessageBlock(messageContent) {
const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
newMessage.textContent = messageContent;
@ -81,8 +79,20 @@ function addUserMessage() {
textarea.value = ''; // Clear the textarea
textarea.style.height = 'auto'; // Reset the textarea height
textarea.style.overflowY = 'hidden';
};
// Handle a "send" request from the chat box.
// If the session is flagged as not loaded, the textarea is cleared and the
// user is asked to upload a file or pick a session; otherwise a non-empty
// message is shown in the chat and forwarded to the assistant.
function addUserMessage() {
    const messageContent = textarea.value.trim();
    const sessionLoaded = sessionStorage.getItem("helloComputerSessionLoaded");

    // sessionStorage stores strings, so the flag is compared against 'false'.
    if (sessionLoaded === 'false') {
        textarea.value = '';
        addAIManualMessage('Please upload a data file or select a session first!');
        return;
    }

    // Nothing to send when the trimmed input is empty.
    if (!messageContent) {
        return;
    }

    addUserMessageBlock(messageContent);
    addAIMessage(messageContent);
}
sendButton.addEventListener('click', addUserMessage);
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
try {
const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
sessionStorage.setItem("helloComputerSessionLoaded", false);
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
const data = await response.text();
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
sessionStorage.setItem("helloComputerSessionLoaded", true);
addAIManualMessage('File uploaded and processed!');
} catch (error) {

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.sessions import SessionDB
from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (

View file

@ -6,14 +6,17 @@ from hellocomputer.config import settings
from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
from hellocomputer.sessions import SessionDB
from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
SID = "test"
@pytest.mark.asyncio
@ -35,9 +38,28 @@ async def test_simple_data_query():
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
    """Ask the LLM for a query over the session folder and run the result.

    Skipped while the configured Anyscale API key still equals the
    "Awesome API" placeholder, since the test needs a real key for the chat
    model.
    """
    q = "find the average score of all the students"

    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # Use the module-level SID, as the other query test in this module does,
    # instead of repeating the literal session id.
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()

    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
    query = extract_code_block(chat.last_response_content())
    result = db.query(query).pl()

    # A single aggregate over all students is expected to yield one row.
    assert result.shape[0] == 1

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.users import OwnershipDB, UserDB
from hellocomputer.db.users import OwnershipDB, UserDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"