Compare commits

...

3 commits

Author SHA1 Message Date
Guillem Borrell 1e881e7537 Also gcs support 2024-05-21 20:55:13 +01:00
Guillem Borrell ff781b6b9c Got to dump the file locally 2024-05-21 20:50:17 +01:00
Guillem Borrell ea14f8c87e Refactored things a bit 2024-05-21 08:23:03 +02:00
16 changed files with 297 additions and 33 deletions

3
.gitignore vendored
View file

@ -160,3 +160,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
test/data/output/*

View file

@ -3,3 +3,9 @@ langchain-community
openai
fastapi
pydantic-settings
s3fs
aiofiles
duckdb
polars
pyarrow
xlsx2csv

View file

@ -0,0 +1,87 @@
import duckdb
class DDB:
    """Fluent wrapper around a DuckDB connection used to export sheets as CSV.

    The source file is read through the spatial extension's ``st_read``; its
    ``metadata`` layer lists the sheets to export. Every method that changes
    state returns ``self`` so calls can be chained.
    """

    def __init__(self):
        self.db = duckdb.connect()
        self.db.install_extension("spatial")
        self.db.install_extension("httpfs")
        self.db.load_extension("spatial")
        self.db.load_extension("httpfs")
        # Sheet names read from the metadata layer; filled by load_metadata().
        self.sheets = tuple()
        # Path of the source file; filled by load_metadata().
        self.path = ""

    @staticmethod
    def _quote(value) -> str:
        """Return ``value`` as a single-quoted SQL string literal.

        Embedded single quotes are doubled so that paths, sheet names,
        credentials or session ids cannot terminate the literal early.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def _dump_sheets(self, destination: str) -> None:
        """Copy every known sheet to ``<destination>/<sheet>.csv``."""
        source = self._quote(self.path)
        for sheet in self.sheets:
            target = self._quote(f"{destination}/{sheet}.csv")
            self.db.query(f"""
                copy
                (
                    select
                        *
                    from
                        st_read
                        (
                            {source},
                            layer = {self._quote(sheet)}
                        )
                )
                to {target}
            """)

    def gcs_secret(self, gcs_access: str, gcs_secret: str):
        """Register GCS credentials (key id, then secret) with DuckDB."""
        self.db.sql(f"""
            CREATE SECRET (
                TYPE GCS,
                KEY_ID {self._quote(gcs_access)},
                SECRET {self._quote(gcs_secret)})
        """)
        return self

    def load_metadata(self, path: str = ""):
        """Load the ``metadata`` layer of ``path`` and record its sheet list.

        The header row of that layer is not picked up, so its columns are
        addressed as ``Field1``/``Field2``. The row whose ``Field1`` is
        ``'Sheets'`` holds a comma-separated list of sheet names.
        """
        self.db.sql(f"""
            create table metadata as (
                select
                    *
                from
                    st_read({self._quote(path)},
                        layer='metadata'
                    )
            )""")
        listed = (
            self.db.query("select Field2 from metadata where Field1 = 'Sheets'")
            .fetchall()[0][0]
        )
        # Tolerate "a, b" style lists: surrounding blanks are not part of the
        # sheet name, and empty entries cannot be exported.
        self.sheets = tuple(
            name.strip() for name in listed.split(",") if name.strip()
        )
        self.path = path
        return self

    def dump_local(self, path):
        """Write ``metadata.csv`` and one CSV per sheet into directory ``path``."""
        # TODO: Port to fsspec and have a single dump file
        metadata_target = self._quote(f"{path}/metadata.csv")
        self.db.query(f"copy metadata to {metadata_target}")
        self._dump_sheets(f"{path}")
        return self

    def dump_gcs(self, bucketname, sid):
        """Write ``metadata.csv`` and one CSV per sheet to ``gcs://bucketname/sid``."""
        prefix = f"gcs://{bucketname}/{sid}"
        # Named metadata.csv (not data.csv) to match dump_local and to avoid
        # clashing with a sheet that happens to be called "data".
        metadata_target = self._quote(f"{prefix}/metadata.csv")
        self.db.sql(f"copy metadata to {metadata_target}")
        self._dump_sheets(prefix)
        return self

    def query(self, sql):
        """Run ``sql`` on the underlying connection and return its result."""
        return self.db.query(sql)

View file

@ -5,6 +5,7 @@ class Settings(BaseSettings):
anyscale_api_key: str = "Awesome API"
gcs_access: str = "access"
gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"
model_config = SettingsConfigDict(env_file=".env")

View file

@ -1,15 +1,12 @@
from pathlib import Path
from fastapi import FastAPI, status
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel
import hellocomputer
from .config import settings
from .routers import files, sessions
static_path = Path(hellocomputer.__file__).parent / "static"
@ -43,30 +40,8 @@ def get_health() -> HealthCheck:
return HealthCheck(status="OK")
@app.get("/greetings", response_class=PlainTextResponse)
async def get_greeting() -> str:
model = "meta-llama/Meta-Llama-3-8B-Instruct"
chat = ChatAnyscale(
model_name=model,
temperature=0.5,
anyscale_api_key=settings.anyscale_api_key,
)
messages = [
SystemMessage(content="You are a helpful AI that shares everything you know."),
HumanMessage(
content="Make a short presentation of yourself "
"as an assistant in Spanish in about 20 words. "
"You're capable of analyzing a file that a user "
"has previously uploaded."
),
]
model_response = await chat.ainvoke(messages)
print(model_response.response_metadata)
return model_response.content
app.include_router(sessions.router)
app.include_router(files.router)
app.mount(
"/",
StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]),

View file

@ -0,0 +1,52 @@
from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
class AvailableModels(StrEnum):
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
class Chat:
@staticmethod
def raise_no_key(api_key):
if api_key:
return api_key
elif api_key is None:
raise ValueError(
"You need to provide a valid API in the api_key init argument"
)
else:
raise ValueError("You need to provide a valid API key")
def __init__(
self,
model: AvailableModels = AvailableModels.llama3_8b,
api_key: str = "",
temperature: float = 0.5,
):
self.model = model
self.api_key = self.raise_no_key(api_key)
self.messages = []
self.responses = []
model: ChatAnyscale = ChatAnyscale(
model_name=model, temperature=temperature, anyscale_api_key=self.api_key
)
async def eval(self, system: str, human: str):
self.messages.append(
[
SystemMessage(content=system),
HumanMessage(content=human),
]
)
self.responses.append(await self.model.ainvoke(self.messages[-1]))
return self
def last_response_content(self):
return self.responses[-1].content
def last_response_metadata(self):
return self.responses[-1].response_metadata

View file

View file

@ -0,0 +1,15 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from ..config import settings
from ..models import Chat
router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "") -> str:
    """Run a fixed analysis prompt through the chat model and return its text.

    ``sid`` is accepted as a query parameter but is not used yet.
    """
    # Chat.eval is a coroutine: it has to be awaited before the response
    # accessors can be used on the returned Chat instance.
    chat = await Chat(api_key=settings.anyscale_api_key).eval(
        system="You're an expert analyst", human="Do some analysis"
    )
    return chat.last_response_content()

View file

@ -0,0 +1,39 @@
import aiofiles
import s3fs
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..config import settings
from ..analytics import DDB
router = APIRouter()
# Bucket that uploads are written to, taken from the application settings.
bucket_name = settings.gcs_bucketname

# Google Cloud Storage is reached through s3fs by pointing the client at the
# storage.googleapis.com endpoint; key and secret come from the settings.
gcs = s3fs.S3FileSystem(
    client_kwargs=dict(endpoint_url="https://storage.googleapis.com"),
    secret=settings.gcs_secret,
    key=settings.gcs_access,
)
@router.post("/upload", tags=["files"])
async def upload_file(file: UploadFile = File(...), sid: str = ""):
    """Store an uploaded file's metadata and sheets as CSV files in GCS.

    The upload is spooled to a named temporary file so that DuckDB can read
    it by path; the exported CSVs go under ``<bucket>/<sid>/``.
    """
    async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
        content = await file.read()
        await f.write(content)
        # Make sure the bytes are on disk before DuckDB opens the path.
        await f.flush()

        gcs.makedir(f"{settings.gcs_bucketname}/{sid}")

        (
            DDB()
            # Key id first, secret second; the key id is settings.gcs_access.
            .gcs_secret(settings.gcs_access, settings.gcs_secret)
            .load_metadata(f.name)
            .dump_gcs(settings.gcs_bucketname, sid)
        )

        return JSONResponse(
            content={"message": "File uploaded successfully"}, status_code=200
        )

View file

@ -0,0 +1,19 @@
from uuid import uuid4
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
router = APIRouter()
@router.get("/new_session")
async def get_new_session() -> str:
    """Return a newly generated random UUID, rendered as a string."""
    session_id = uuid4()
    return str(session_id)
@router.get("/greetings", response_class=PlainTextResponse)
async def get_greeting() -> str:
    """Return the fixed plain-text greeting shown when a chat starts."""
    greeting = (
        "Hi! I'm a helpful assistant. Please upload or select a file "
        "and I'll try to analyze it following your orders"
    )
    return greeting

Binary file not shown.

View file

@ -51,6 +51,11 @@
<div class="modal-body">
<input type="file" class="custom-file-input" id="inputGroupFile01">
</div>
<div class="modal-body" id="uploadButtonDiv">
<button type="button" class="btn btn-primary" id="uploadButton">Upload</button>
</div>
<div class="modal-body" id="uploadResultDiv">
</div>
<div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
</div>
@ -89,9 +94,6 @@
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4"
crossorigin="anonymous"></script>
<script src="https://unpkg.com/htmx.org@1.9.12"
integrity="sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2"
crossorigin="anonymous"></script>
<script src="script.js"></script>
</body>

View file

@ -46,11 +46,18 @@ async function fetchResponse(message) {
function addAIMessage() {
const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner">';
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner"></div>';
chatMessages.prepend(newMessage); // Add new message at the top
fetchResponse(newMessage);
}
// Prepend an assistant bubble to the chat whose body is `m`.
// Note: `m` is inserted as raw HTML, so callers must pass trusted markup.
function addAIManualMessage(m) {
    const bubble = document.createElement('div');
    const bubbleClasses = ['message', 'bg-white', 'p-2', 'mb-2', 'rounded'];
    bubble.classList.add(...bubbleClasses);
    bubble.innerHTML = `<img src="/img/assistant.webp" width="50px"> <div>${m}</div>`;
    // Newest message goes first in the container.
    chatMessages.prepend(bubble);
}
function addUserMessage() {
const messageContent = textarea.value.trim();
if (messageContent) {
@ -82,7 +89,11 @@ document.addEventListener("DOMContentLoaded", function () {
async function fetchGreeting() {
try {
const response = await fetch('/greetings');
const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
if (!response.ok) {
throw new Error('Network response was not ok ' + response.statusText);
}
@ -101,3 +112,39 @@ document.addEventListener("DOMContentLoaded", function () {
// Call the function to fetch greeting
fetchGreeting();
});
// Wire up the upload modal once the DOM is ready: clicking the upload button
// posts the selected file to /upload for the current session and reports the
// outcome in the result area.
document.addEventListener("DOMContentLoaded", () => {
    const picker = document.getElementById('inputGroupFile01');
    const trigger = document.getElementById('uploadButton');
    const resultBox = document.getElementById('uploadResultDiv');

    trigger.addEventListener('click', async () => {
        const selected = picker.files[0];

        // Nothing chosen yet: tell the user and stop.
        if (!selected) {
            resultBox.textContent = 'Please select a file.';
            return;
        }

        const payload = new FormData();
        payload.append('file', selected);

        try {
            const sid = sessionStorage.getItem('helloComputerSession');
            const response = await fetch(`/upload?sid=${sid}`, {
                method: 'POST',
                body: payload
            });

            if (!response.ok) {
                throw new Error(`Network response was not ok ${response.statusText}`);
            }

            // The endpoint answers with a JSON object carrying a `message` field.
            const parsed = JSON.parse(await response.text());
            resultBox.textContent = `Upload successful: ${parsed['message']}`;

            addAIManualMessage('File uploaded and processed!');
        } catch (error) {
            resultBox.textContent = `Error: ${error.message}`;
        }
    });
});

Binary file not shown.

1
test/output/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
*.csv

17
test/test_load.py Normal file
View file

@ -0,0 +1,17 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_load_data():
    """Load the sample workbook, dump it locally and check the exported sheet."""
    workbook = TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx"
    db = DDB().load_metadata(workbook).dump_local(TEST_OUTPUT_FOLDER)

    # The sample workbook's metadata lists a single sheet.
    assert db.sheets == ("answers",)

    exported = TEST_OUTPUT_FOLDER / "answers.csv"
    assert exported.exists()