Compare commits

...

3 commits

Author SHA1 Message Date
Guillem Borrell 181bc92884 Successfully implemented function calling for anyscale
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
2024-06-17 22:18:18 +02:00
Guillem Borrell e43f81df39 Some refactoring 2024-06-16 09:31:33 +02:00
Guillem Borrell e8c7600ed7 Better tests and docs 2024-06-16 08:56:45 +02:00
16 changed files with 226 additions and 39 deletions

137
notebooks/tasks.ipynb Normal file
View file

@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"from hellocomputer.config import settings\n",
"from langchain_core.utils.function_calling import convert_to_openai_function\n",
"import openai\n",
"from operator import itemgetter"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def add(a: int, b: int) -> int:\n",
" \"\"\"Adds a and b.\"\"\"\n",
" return a + b\n",
"\n",
"\n",
"@tool\n",
"def multiply(a: int, b: int) -> int:\n",
" \"\"\"Multiplies a and b.\"\"\"\n",
" return a * b\n",
"\n",
"\n",
"tools = [convert_to_openai_function(t) for t in [add, multiply]]\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"tools_fmt = [\n",
" {\"type\": \"function\",\n",
" \"function\": tools[0]},\n",
" {\"type\": \"function\",\n",
" \"function\": tools[1]}\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"messages = [\n",
" {\"role\": \"system\", \"content\": \"You are helpful assistant.\"},\n",
" {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n",
"]\n",
"\n",
"client = openai.OpenAI(\n",
" base_url = \"https://api.endpoints.anyscale.com/v1\",\n",
" api_key = settings.anyscale_api_key\n",
")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
" messages=messages,\n",
" tools=tools_fmt,\n",
" tool_choice=\"auto\", # auto is default, but we'll be explicit\n",
")\n",
"\n",
"get_args = itemgetter(\"arguments\")\n"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"call = response.choices[0].message.tool_calls[0].function"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"add.func(2,3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -6,9 +6,12 @@ pydantic-settings
s3fs s3fs
aiofiles aiofiles
duckdb duckdb
duckdb-engine
polars polars
pyarrow pyarrow
pyjwt[crypto] pyjwt[crypto]
python-multipart python-multipart
authlib authlib
itsdangerous itsdangerous
sqlalchemy
openai

View file

@ -1,7 +1,7 @@
from enum import StrEnum from enum import StrEnum
from pathlib import Path from pathlib import Path
import duckdb from sqlalchemy import create_engine, text
class StorageEngines(StrEnum): class StorageEngines(StrEnum):
@ -19,11 +19,13 @@ class DDB:
bucket: str | None = None, bucket: str | None = None,
**kwargs, **kwargs,
): ):
self.db = duckdb.connect() self.engine = create_engine(
self.db.install_extension("spatial") "duckdb:///:memory:",
self.db.install_extension("httpfs") connect_args={
self.db.load_extension("spatial") "preload_extensions": ["https", "spatial"],
self.db.load_extension("httpfs") "config": {"memory_limit": "300mb"},
},
)
self.sheets = tuple() self.sheets = tuple()
self.loaded = False self.loaded = False
@ -35,12 +37,18 @@ class DDB:
bucket is not None, bucket is not None,
) )
): ):
self.db.sql(f""" with self.engine.connect() as conn:
conn.execute(
text(
f"""
CREATE SECRET ( CREATE SECRET (
TYPE GCS, TYPE GCS,
KEY_ID '{gcs_access}', KEY_ID '{gcs_access}',
SECRET '{gcs_secret}') SECRET '{gcs_secret}')
""") """
)
)
self.path_prefix = f"gcs://{bucket}" self.path_prefix = f"gcs://{bucket}"
else: else:
raise ValueError( raise ValueError(
@ -55,3 +63,7 @@ class DDB:
raise ValueError( raise ValueError(
"With local storage you need to provide the path keyword argument" "With local storage you need to provide the path keyword argument"
) )
@property
def db(self):
    """Hand out a raw connection taken from the SQLAlchemy engine.

    The property body runs on every attribute access, so each read of
    ``self.db`` issues a fresh ``self.engine.raw_connection()`` call.
    NOTE(review): presumably kept so callers written against the previous
    ``self.db`` attribute keep working -- confirm against call sites.
    """
    raw_conn = self.engine.raw_connection()
    return raw_conn

View file

@ -6,7 +6,7 @@ from typing_extensions import Self
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from .db import DDB from . import DDB
class SessionDB(DDB): class SessionDB(DDB):
@ -149,7 +149,7 @@ class SessionDB(DDB):
) )
@property @property
def schema(self): def schema(self) -> str:
return os.linesep.join( return os.linesep.join(
[ [
"The schema of the database is the following:", "The schema of the database is the following:",

View file

@ -8,7 +8,7 @@ from uuid import UUID, uuid4
import duckdb import duckdb
import polars as pl import polars as pl
from .db import DDB, StorageEngines from . import DDB, StorageEngines
class UserDB(DDB): class UserDB(DDB):
@ -99,10 +99,10 @@ class OwnershipDB(DDB):
FROM FROM
'{self.path_prefix}/*.csv' '{self.path_prefix}/*.csv'
WHERE WHERE
email = '{user_email} email = '{user_email}'
ORDER BY ORDER BY
timestamp ASC timestamp ASC
LIMIT 10' LIMIT 10
""") """)
.pl() .pl()
.to_series() .to_series()

View file

@ -2,8 +2,8 @@ from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.sessions import SessionDB
from ..config import settings from ..config import settings
from ..models import Chat from ..models import Chat
@ -13,7 +13,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB( db = SessionDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
@ -22,9 +22,8 @@ async def query(sid: str = "", q: str = "") -> str:
sid=sid, sid=sid,
).load_folder() ).load_folder()
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
result = str(db.query(query)) result = str(db.query(query))
print(result)
return result return result

View file

@ -5,7 +5,7 @@ from starlette.requests import Request
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.users import UserDB from hellocomputer.db.users import UserDB
router = APIRouter() router = APIRouter()

View file

@ -7,8 +7,8 @@ from starlette.requests import Request
from ..config import settings from ..config import settings
from ..db import StorageEngines from ..db import StorageEngines
from ..sessions import SessionDB from ..db.sessions import SessionDB
from ..users import OwnershipDB from ..db.users import OwnershipDB
router = APIRouter() router = APIRouter()

View file

@ -1,12 +1,12 @@
from typing import List
from uuid import uuid4 from uuid import uuid4
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from starlette.requests import Request from starlette.requests import Request
from typing import List
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.users import OwnershipDB from hellocomputer.db.users import OwnershipDB
from ..config import settings from ..config import settings

View file

@ -37,7 +37,7 @@
<p class="techie-font"> <p class="techie-font">
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
not be as powerful as Excel, but it has an efficient query backend that can process your data faster not be as powerful as Excel, but it has an efficient query backend that can process your data faster
and more efficiently than Excel. than Excel.
</p> </p>
<a href="/"><button class="btn btn-secondary w-100">Back</button></a> <a href="/"><button class="btn btn-secondary w-100">Back</button></a>
</div> </div>

View file

@ -25,9 +25,11 @@
</p> </p>
<a href="#" class="list-group-item list-group-item-action bg-light"><i <a href="#" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-question-circle"></i> How to</a> class="bi bi-question-circle"></i> How to</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i> <a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-file-ruled"></i>
File templates</a> File templates</a>
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i> <a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
class="bi bi-info-circle"></i>
About</a> About</a>
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i> <a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
Config</a> Config</a>

View file

@ -71,9 +71,7 @@ function addAIManualMessage(m) {
chatMessages.prepend(newMessage); // Add new message at the top chatMessages.prepend(newMessage); // Add new message at the top
} }
function addUserMessage() { function addUserMessageBlock(messageContent) {
const messageContent = textarea.value.trim();
if (messageContent) {
const newMessage = document.createElement('div'); const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded');
newMessage.textContent = messageContent; newMessage.textContent = messageContent;
@ -81,8 +79,20 @@ function addUserMessage() {
textarea.value = ''; // Clear the textarea textarea.value = ''; // Clear the textarea
textarea.style.height = 'auto'; // Reset the textarea height textarea.style.height = 'auto'; // Reset the textarea height
textarea.style.overflowY = 'hidden'; textarea.style.overflowY = 'hidden';
};
function addUserMessage() {
const messageContent = textarea.value.trim();
if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') {
textarea.value = '';
addAIManualMessage('Please upload a data file or select a session first!');
}
else {
if (messageContent) {
addUserMessageBlock(messageContent);
addAIMessage(messageContent); addAIMessage(messageContent);
} }
}
}; };
sendButton.addEventListener('click', addUserMessage); sendButton.addEventListener('click', addUserMessage);
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
try { try {
const session_response = await fetch('/new_session'); const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
sessionStorage.setItem("helloComputerSessionLoaded", false);
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
const data = await response.text(); const data = await response.text();
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
sessionStorage.setItem("helloComputerSessionLoaded", true);
addAIManualMessage('File uploaded and processed!'); addAIManualMessage('File uploaded and processed!');
} catch (error) { } catch (error) {

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.sessions import SessionDB from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = ( TEST_XLS_PATH = (

View file

@ -6,14 +6,17 @@ from hellocomputer.config import settings
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
from hellocomputer.sessions import SessionDB from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = ( TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2] Path(hellocomputer.__file__).parents[2]
/ "test" / "test"
/ "data" / "data"
/ "TestExcelHelloComputer.xlsx" / "TestExcelHelloComputer.xlsx"
) )
SID = "test"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -35,9 +38,28 @@ async def test_simple_data_query():
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = SessionDB( db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH) ).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT") assert query.startswith("SELECT")
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
    """Run a natural-language question end to end against the session data.

    The question is turned into a prompt by the session database, sent to
    the LLM, the code block is extracted from the reply and executed, and
    the result is expected to contain exactly one row.

    Skipped when the configured Anyscale API key equals "Awesome API",
    which the skip reason treats as "API Key not set".
    """
    # Correctly spelled question: the text is sent to the model verbatim.
    q = "find the average score of all the students"

    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # Reuse the module-level SID, as the sibling test in this file does,
    # instead of repeating the literal session id.
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()

    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
    query = extract_code_block(chat.last_response_content())
    # NOTE(review): .pl() presumably yields a polars DataFrame -- confirm.
    result = db.query(query).pl()

    # A single average over all students should come back as one row.
    assert result.shape[0] == 1

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.users import OwnershipDB, UserDB from hellocomputer.db.users import OwnershipDB, UserDB
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"