Compare commits

...

3 commits

Author SHA1 Message Date
Guillem Borrell 1e881e7537 Also gcs support 2024-05-21 20:55:13 +01:00
Guillem Borrell ff781b6b9c Got to dump the file locally 2024-05-21 20:50:17 +01:00
Guillem Borrell ea14f8c87e Refactored things a bit 2024-05-21 08:23:03 +02:00
16 changed files with 297 additions and 33 deletions

3
.gitignore vendored
View file

@ -160,3 +160,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
.DS_Store
test/data/output/*

View file

@ -3,3 +3,9 @@ langchain-community
openai openai
fastapi fastapi
pydantic-settings pydantic-settings
s3fs
aiofiles
duckdb
polars
pyarrow
xlsx2csv

View file

@ -0,0 +1,87 @@
import duckdb
class DDB:
def __init__(self):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.path = ""
def gcs_secret(self, gcs_access: str, gcs_secret: str):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
return self
def load_metadata(self, path: str = ""):
"""For some reason, the header is not loaded"""
self.db.sql(f"""
create table metadata as (
select
*
from
st_read('{path}',
layer='metadata'
)
)""")
self.sheets = tuple(
self.db.query("select Field2 from metadata where Field1 = 'Sheets'")
.fetchall()[0][0]
.split(",")
)
self.path = path
return self
def dump_local(self, path):
# TODO: Port to fsspec and have a single dump file
self.db.query(f"copy metadata to '{path}/metadata.csv'")
for sheet in self.sheets:
self.db.query(f"""
copy
(
select
*
from
st_read
(
'{self.path}',
layer = '{sheet}'
)
)
to '{path}/{sheet}.csv'
""")
return self
def dump_gcs(self, bucketname, sid):
self.db.sql(f"copy metadata to 'gcs://{bucketname}/{sid}/data.csv'")
for sheet in self.sheets:
self.db.query(f"""
copy
(
select
*
from
st_read
(
'{self.path}',
layer = '{sheet}'
)
)
to 'gcs://{bucketname}/{sid}/{sheet}.csv'
""")
return self
def query(self, sql):
return self.db.query(sql)

View file

@ -5,6 +5,7 @@ class Settings(BaseSettings):
anyscale_api_key: str = "Awesome API" anyscale_api_key: str = "Awesome API"
gcs_access: str = "access" gcs_access: str = "access"
gcs_secret: str = "secret" gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")

View file

@ -1,15 +1,12 @@
from pathlib import Path from pathlib import Path
from fastapi import FastAPI, status from fastapi import FastAPI, status
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel from pydantic import BaseModel
import hellocomputer import hellocomputer
from .config import settings from .routers import files, sessions
static_path = Path(hellocomputer.__file__).parent / "static" static_path = Path(hellocomputer.__file__).parent / "static"
@ -43,30 +40,8 @@ def get_health() -> HealthCheck:
return HealthCheck(status="OK") return HealthCheck(status="OK")
@app.get("/greetings", response_class=PlainTextResponse) app.include_router(sessions.router)
async def get_greeting() -> str: app.include_router(files.router)
model = "meta-llama/Meta-Llama-3-8B-Instruct"
chat = ChatAnyscale(
model_name=model,
temperature=0.5,
anyscale_api_key=settings.anyscale_api_key,
)
messages = [
SystemMessage(content="You are a helpful AI that shares everything you know."),
HumanMessage(
content="Make a short presentation of yourself "
"as an assistant in Spanish in about 20 words. "
"You're capable of analyzing a file that a user "
"has previously uploaded."
),
]
model_response = await chat.ainvoke(messages)
print(model_response.response_metadata)
return model_response.content
app.mount( app.mount(
"/", "/",
StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]), StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]),

View file

@ -0,0 +1,52 @@
from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
class AvailableModels(StrEnum):
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
class Chat:
@staticmethod
def raise_no_key(api_key):
if api_key:
return api_key
elif api_key is None:
raise ValueError(
"You need to provide a valid API in the api_key init argument"
)
else:
raise ValueError("You need to provide a valid API key")
def __init__(
self,
model: AvailableModels = AvailableModels.llama3_8b,
api_key: str = "",
temperature: float = 0.5,
):
self.model = model
self.api_key = self.raise_no_key(api_key)
self.messages = []
self.responses = []
model: ChatAnyscale = ChatAnyscale(
model_name=model, temperature=temperature, anyscale_api_key=self.api_key
)
async def eval(self, system: str, human: str):
self.messages.append(
[
SystemMessage(content=system),
HumanMessage(content=human),
]
)
self.responses.append(await self.model.ainvoke(self.messages[-1]))
return self
def last_response_content(self):
return self.responses[-1].content
def last_response_metadata(self):
return self.responses[-1].response_metadata

View file

View file

@ -0,0 +1,15 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from ..config import settings
from ..models import Chat
router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "") -> str:
model = Chat(api_key=settings.anyscale_api_key).eval(
system="You're an expert analyst", human="Do some analysis"
)
return model.last_response_content()

View file

@ -0,0 +1,39 @@
import aiofiles
import s3fs
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..config import settings
from ..analytics import DDB
router = APIRouter()
# Configure the S3FS with your Google Cloud Storage credentials
gcs = s3fs.S3FileSystem(
key=settings.gcs_access,
secret=settings.gcs_secret,
client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
)
bucket_name = settings.gcs_bucketname
@router.post("/upload", tags=["files"])
async def upload_file(file: UploadFile = File(...), sid: str = ""):
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
content = await file.read()
await f.write(content)
await f.flush()
gcs.makedir(f"{settings.gcs_bucketname}/{sid}")
(
DDB()
.gcs_secret(settings.gcs_secret, settings.gcs_secret)
.load_metadata(f.name)
.dump_gcs(settings.gcs_bucketname, sid)
)
return JSONResponse(
content={"message": "File uploaded successfully"}, status_code=200
)

View file

@ -0,0 +1,19 @@
from uuid import uuid4
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
router = APIRouter()
@router.get("/new_session")
async def get_new_session() -> str:
return str(uuid4())
@router.get("/greetings", response_class=PlainTextResponse)
async def get_greeting() -> str:
return (
"Hi! I'm a helpful assistant. Please upload or select a file "
"and I'll try to analyze it following your orders"
)

Binary file not shown.

View file

@ -51,6 +51,11 @@
<div class="modal-body"> <div class="modal-body">
<input type="file" class="custom-file-input" id="inputGroupFile01"> <input type="file" class="custom-file-input" id="inputGroupFile01">
</div> </div>
<div class="modal-body" id="uploadButtonDiv">
<button type="button" class="btn btn-primary" id="uploadButton">Upload</button>
</div>
<div class="modal-body" id="uploadResultDiv">
</div>
<div class="modal-footer"> <div class="modal-footer">
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button> <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
</div> </div>
@ -89,9 +94,6 @@
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js" <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js"
integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4" integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4"
crossorigin="anonymous"></script> crossorigin="anonymous"></script>
<script src="https://unpkg.com/htmx.org@1.9.12"
integrity="sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2"
crossorigin="anonymous"></script>
<script src="script.js"></script> <script src="script.js"></script>
</body> </body>

View file

@ -46,11 +46,18 @@ async function fetchResponse(message) {
function addAIMessage() { function addAIMessage() {
const newMessage = document.createElement('div'); const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded'); newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner">'; newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner"></div>';
chatMessages.prepend(newMessage); // Add new message at the top chatMessages.prepend(newMessage); // Add new message at the top
fetchResponse(newMessage); fetchResponse(newMessage);
} }
function addAIManualMessage(m) {
const newMessage = document.createElement('div');
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div>' + m + '</div>';
chatMessages.prepend(newMessage); // Add new message at the top
}
function addUserMessage() { function addUserMessage() {
const messageContent = textarea.value.trim(); const messageContent = textarea.value.trim();
if (messageContent) { if (messageContent) {
@ -82,7 +89,11 @@ document.addEventListener("DOMContentLoaded", function () {
async function fetchGreeting() { async function fetchGreeting() {
try { try {
const response = await fetch('/greetings'); const session_response = await fetch('/new_session');
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
if (!response.ok) { if (!response.ok) {
throw new Error('Network response was not ok ' + response.statusText); throw new Error('Network response was not ok ' + response.statusText);
} }
@ -101,3 +112,39 @@ document.addEventListener("DOMContentLoaded", function () {
// Call the function to fetch greeting // Call the function to fetch greeting
fetchGreeting(); fetchGreeting();
}); });
document.addEventListener("DOMContentLoaded", function () {
const fileInput = document.getElementById('inputGroupFile01');
const uploadButton = document.getElementById('uploadButton');
const uploadResultDiv = document.getElementById('uploadResultDiv');
uploadButton.addEventListener('click', async function () {
const file = fileInput.files[0];
if (!file) {
uploadResultDiv.textContent = 'Please select a file.';
return;
}
const formData = new FormData();
formData.append('file', file);
try {
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error('Network response was not ok ' + response.statusText);
}
const data = await response.text();
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
addAIManualMessage('File uploaded and processed!');
} catch (error) {
uploadResultDiv.textContent = 'Error: ' + error.message;
}
});
});

Binary file not shown.

1
test/output/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
*.csv

17
test/test_load.py Normal file
View file

@ -0,0 +1,17 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_load_data():
db = (
DDB()
.load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx")
.dump_local(TEST_OUTPUT_FOLDER)
)
assert db.sheets == ("answers",)
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()