Compare commits
3 commits
d642bb0a39
...
181bc92884
Author | SHA1 | Date | |
---|---|---|---|
181bc92884 | |||
e43f81df39 | |||
e8c7600ed7 |
137
notebooks/tasks.ipynb
Normal file
137
notebooks/tasks.ipynb
Normal file
|
@ -0,0 +1,137 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from hellocomputer.config import settings\n",
|
||||
"from langchain_core.utils.function_calling import convert_to_openai_function\n",
|
||||
"import openai\n",
|
||||
"from operator import itemgetter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.tools import tool\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def add(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Adds a and b.\"\"\"\n",
|
||||
" return a + b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"@tool\n",
|
||||
"def multiply(a: int, b: int) -> int:\n",
|
||||
" \"\"\"Multiplies a and b.\"\"\"\n",
|
||||
" return a * b\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tools = [convert_to_openai_function(t) for t in [add, multiply]]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tools_fmt = [\n",
|
||||
" {\"type\": \"function\",\n",
|
||||
" \"function\": tools[0]},\n",
|
||||
" {\"type\": \"function\",\n",
|
||||
" \"function\": tools[1]}\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages = [\n",
|
||||
" {\"role\": \"system\", \"content\": \"You are helpful assistant.\"},\n",
|
||||
" {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"client = openai.OpenAI(\n",
|
||||
" base_url = \"https://api.endpoints.anyscale.com/v1\",\n",
|
||||
" api_key = settings.anyscale_api_key\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
|
||||
" messages=messages,\n",
|
||||
" tools=tools_fmt,\n",
|
||||
" tool_choice=\"auto\", # auto is default, but we'll be explicit\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"get_args = itemgetter(\"arguments\")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"call = response.choices[0].message.tool_calls[0].function"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"5"
|
||||
]
|
||||
},
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"add.func(2,3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -6,9 +6,12 @@ pydantic-settings
|
|||
s3fs
|
||||
aiofiles
|
||||
duckdb
|
||||
duckdb-engine
|
||||
polars
|
||||
pyarrow
|
||||
pyjwt[crypto]
|
||||
python-multipart
|
||||
authlib
|
||||
itsdangerous
|
||||
itsdangerous
|
||||
sqlalchemy
|
||||
openai
|
|
@ -1,7 +1,7 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
|
@ -19,11 +19,13 @@ class DDB:
|
|||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.db = duckdb.connect()
|
||||
self.db.install_extension("spatial")
|
||||
self.db.install_extension("httpfs")
|
||||
self.db.load_extension("spatial")
|
||||
self.db.load_extension("httpfs")
|
||||
self.engine = create_engine(
|
||||
"duckdb:///:memory:",
|
||||
connect_args={
|
||||
"preload_extensions": ["https", "spatial"],
|
||||
"config": {"memory_limit": "300mb"},
|
||||
},
|
||||
)
|
||||
self.sheets = tuple()
|
||||
self.loaded = False
|
||||
|
||||
|
@ -35,12 +37,18 @@ class DDB:
|
|||
bucket is not None,
|
||||
)
|
||||
):
|
||||
self.db.sql(f"""
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
f"""
|
||||
CREATE SECRET (
|
||||
TYPE GCS,
|
||||
KEY_ID '{gcs_access}',
|
||||
SECRET '{gcs_secret}')
|
||||
""")
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
self.path_prefix = f"gcs://{bucket}"
|
||||
else:
|
||||
raise ValueError(
|
||||
|
@ -55,3 +63,7 @@ class DDB:
|
|||
raise ValueError(
|
||||
"With local storage you need to provide the path keyword argument"
|
||||
)
|
||||
|
||||
@property
def db(self):
    """Return a raw connection obtained from ``self.engine``.

    Every access calls ``raw_connection()`` on the engine; nothing is
    cached on the instance by this property.
    """
    raw_conn = self.engine.raw_connection()
    return raw_conn
|
|
@ -6,7 +6,7 @@ from typing_extensions import Self
|
|||
|
||||
from hellocomputer.db import StorageEngines
|
||||
|
||||
from .db import DDB
|
||||
from . import DDB
|
||||
|
||||
|
||||
class SessionDB(DDB):
|
||||
|
@ -149,7 +149,7 @@ class SessionDB(DDB):
|
|||
)
|
||||
|
||||
@property
|
||||
def schema(self):
|
||||
def schema(self) -> str:
|
||||
return os.linesep.join(
|
||||
[
|
||||
"The schema of the database is the following:",
|
|
@ -8,7 +8,7 @@ from uuid import UUID, uuid4
|
|||
import duckdb
|
||||
import polars as pl
|
||||
|
||||
from .db import DDB, StorageEngines
|
||||
from . import DDB, StorageEngines
|
||||
|
||||
|
||||
class UserDB(DDB):
|
||||
|
@ -99,10 +99,10 @@ class OwnershipDB(DDB):
|
|||
FROM
|
||||
'{self.path_prefix}/*.csv'
|
||||
WHERE
|
||||
email = '{user_email}
|
||||
email = '{user_email}'
|
||||
ORDER BY
|
||||
timestamp ASC
|
||||
LIMIT 10'
|
||||
LIMIT 10
|
||||
""")
|
||||
.pl()
|
||||
.to_series()
|
|
@ -2,8 +2,8 @@ from fastapi import APIRouter
|
|||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.sessions import SessionDB
|
||||
|
||||
from ..config import settings
|
||||
from ..models import Chat
|
||||
|
@ -13,7 +13,7 @@ router = APIRouter()
|
|||
|
||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||
async def query(sid: str = "", q: str = "") -> str:
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
|
@ -22,9 +22,8 @@ async def query(sid: str = "", q: str = "") -> str:
|
|||
sid=sid,
|
||||
).load_folder()
|
||||
|
||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
|
||||
chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result = str(db.query(query))
|
||||
print(result)
|
||||
|
||||
return result
|
||||
|
|
|
@ -5,7 +5,7 @@ from starlette.requests import Request
|
|||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.users import UserDB
|
||||
from hellocomputer.db.users import UserDB
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
@ -7,8 +7,8 @@ from starlette.requests import Request
|
|||
|
||||
from ..config import settings
|
||||
from ..db import StorageEngines
|
||||
from ..sessions import SessionDB
|
||||
from ..users import OwnershipDB
|
||||
from ..db.sessions import SessionDB
|
||||
from ..db.users import OwnershipDB
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
from typing import List
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.requests import Request
|
||||
from typing import List
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.users import OwnershipDB
|
||||
from hellocomputer.db.users import OwnershipDB
|
||||
|
||||
from ..config import settings
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@
|
|||
<p class="techie-font">
|
||||
Hola, computer! is a web assistant that allows you to query excel files using natural language. It may
|
||||
not be as powerful as Excel, but it has an efficient query backend that can process your data faster
|
||||
and more efficiently than Excel.
|
||||
than Excel.
|
||||
</p>
|
||||
<a href="/"><button class="btn btn-secondary w-100">Back</button></a>
|
||||
</div>
|
||||
|
|
|
@ -25,9 +25,11 @@
|
|||
</p>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-question-circle"></i> How to</a>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-file-ruled"></i>
|
||||
<a href="/app/templates" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-file-ruled"></i>
|
||||
File templates</a>
|
||||
<a href="#" class="list-group-item list-group-item-action bg-light"><i class="bi bi-info-circle"></i>
|
||||
<a href="/app/about.html" class="list-group-item list-group-item-action bg-light"><i
|
||||
class="bi bi-info-circle"></i>
|
||||
About</a>
|
||||
<a href="/config" class="list-group-item list-group-item-action bg-light"><i class="bi bi-toggles"></i>
|
||||
Config</a>
|
||||
|
|
|
@ -71,17 +71,27 @@ function addAIManualMessage(m) {
|
|||
chatMessages.prepend(newMessage); // Add new message at the top
|
||||
}
|
||||
|
||||
/**
 * Show the user's message as a new block in the chat and reset the input.
 *
 * @param {string} messageContent - Text to display in the new block.
 */
function addUserMessageBlock(messageContent) {
    const bubbleClasses = ['message', 'bg-light', 'p-2', 'mb-2', 'rounded'];
    const bubble = document.createElement('div');
    bubble.classList.add(...bubbleClasses);
    bubble.textContent = messageContent;

    // Newest message is inserted first so it appears at the top.
    chatMessages.prepend(bubble);

    // Empty the input and let it shrink back to its automatic height.
    textarea.value = '';
    const inputStyle = textarea.style;
    inputStyle.height = 'auto';
    inputStyle.overflowY = 'hidden';
};
|
||||
|
||||
/**
 * Handle submission of the text currently typed in the textarea.
 *
 * When the sessionStorage flag "helloComputerSessionLoaded" is 'false',
 * the input is discarded and the user is asked to upload a file or pick a
 * session. Otherwise a non-empty message is rendered in the chat through
 * addUserMessageBlock and forwarded to addAIMessage.
 */
function addUserMessage() {
    const messageContent = textarea.value.trim();

    // sessionStorage keeps values as strings, so the flag is compared
    // against the string 'false' rather than the boolean.
    if (sessionStorage.getItem("helloComputerSessionLoaded") === 'false') {
        textarea.value = '';
        addAIManualMessage('Please upload a data file or select a session first!');
    } else if (messageContent) {
        // Rendering and input reset live in addUserMessageBlock; they are
        // not repeated here.
        addUserMessageBlock(messageContent);
        addAIMessage(messageContent);
    }
};
|
||||
|
||||
|
@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
try {
|
||||
const session_response = await fetch('/new_session');
|
||||
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
|
||||
sessionStorage.setItem("helloComputerSessionLoaded", false);
|
||||
|
||||
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
|
||||
|
||||
|
@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () {
|
|||
|
||||
const data = await response.text();
|
||||
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
|
||||
sessionStorage.setItem("helloComputerSessionLoaded", true);
|
||||
|
||||
addAIManualMessage('File uploaded and processed!');
|
||||
} catch (error) {
|
||||
|
|
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
BIN
src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx
Normal file
Binary file not shown.
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.sessions import SessionDB
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_XLS_PATH = (
|
||||
|
|
|
@ -6,14 +6,17 @@ from hellocomputer.config import settings
|
|||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.models import Chat
|
||||
from hellocomputer.sessions import SessionDB
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
TEST_XLS_PATH = (
|
||||
Path(hellocomputer.__file__).parents[2]
|
||||
/ "test"
|
||||
/ "data"
|
||||
/ "TestExcelHelloComputer.xlsx"
|
||||
)
|
||||
SID = "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -35,9 +38,28 @@ async def test_simple_data_query():
|
|||
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = SessionDB(
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
||||
).load_xls(TEST_XLS_PATH)
|
||||
|
||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
assert query.startswith("SELECT")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_data_query():
    """Ask the LLM for a query over the test session data and execute it.

    The session is loaded from ``TEST_OUTPUT_FOLDER`` with ``load_folder``,
    the model's reply is reduced to its code block, and that query is run
    against the session database. The test is skipped while the API key
    equals the "Awesome API" value.
    """
    q = "find the average score of all the students"

    llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    db = SessionDB(
        storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid=SID
    ).load_folder()

    chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q))
    query = extract_code_block(chat.last_response_content())
    result = db.query(query).pl()

    # The question asks for one aggregate value, so one row is expected.
    assert result.shape[0] == 1
|
||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.users import OwnershipDB, UserDB
|
||||
from hellocomputer.db.users import OwnershipDB, UserDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
|
|
Loading…
Reference in a new issue