Refactored things a bit

Guillem Borrell 2024-05-21 08:23:03 +02:00
parent b8e4e930d7
commit ea14f8c87e
12 changed files with 234 additions and 33 deletions

.gitignore vendored
View file

@@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store

View file

@@ -3,3 +3,8 @@ langchain-community
openai
fastapi
pydantic-settings
s3fs
aiofiles
duckdb
polars
pyarrow

View file

@@ -5,6 +5,7 @@ class Settings(BaseSettings):
    anyscale_api_key: str = "Awesome API"
    gcs_access: str = "access"
    gcs_secret: str = "secret"
    gcs_bucketname: str = "bucket"
    model_config = SettingsConfigDict(env_file=".env")
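
For reference, a minimal sketch of how the new gcs_bucketname field would be populated from the .env file by pydantic-settings. The module path hellocomputer.config and the placeholder values are assumptions for illustration, not part of this commit:

# Hypothetical .env entries (field names map to environment variables case-insensitively):
#   ANYSCALE_API_KEY=...
#   GCS_ACCESS=...
#   GCS_SECRET=...
#   GCS_BUCKETNAME=my-bucket
from hellocomputer.config import settings  # import path assumed

# Once GCS_BUCKETNAME is set in .env, the Settings instance picks it up.
print(settings.gcs_bucketname)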

View file

@@ -1,15 +1,12 @@
from pathlib import Path
from fastapi import FastAPI, status
from fastapi.responses import PlainTextResponse
from fastapi.staticfiles import StaticFiles
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel
import hellocomputer
from .config import settings
from .routers import files, sessions

static_path = Path(hellocomputer.__file__).parent / "static"
@@ -43,30 +40,8 @@ def get_health() -> HealthCheck:
    return HealthCheck(status="OK")

@app.get("/greetings", response_class=PlainTextResponse)
async def get_greeting() -> str:
    model = "meta-llama/Meta-Llama-3-8B-Instruct"
    chat = ChatAnyscale(
        model_name=model,
        temperature=0.5,
        anyscale_api_key=settings.anyscale_api_key,
    )
    messages = [
        SystemMessage(content="You are a helpful AI that shares everything you know."),
        HumanMessage(
            content="Make a short presentation of yourself "
            "as an assistant in Spanish in about 20 words. "
            "You're capable of analyzing a file that a user "
            "has previously uploaded."
        ),
    ]
    model_response = await chat.ainvoke(messages)
    print(model_response.response_metadata)
    return model_response.content

app.include_router(sessions.router)
app.include_router(files.router)

app.mount(
    "/",
    StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]),

View file

@@ -0,0 +1,52 @@
from enum import StrEnum

from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage


class AvailableModels(StrEnum):
    llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"


class Chat:
    @staticmethod
    def raise_no_key(api_key):
        if api_key:
            return api_key
        elif api_key is None:
            raise ValueError(
                "You need to provide a valid API key in the api_key init argument"
            )
        else:
            raise ValueError("You need to provide a valid API key")

    def __init__(
        self,
        model: AvailableModels = AvailableModels.llama3_8b,
        api_key: str = "",
        temperature: float = 0.5,
    ):
        self.api_key = self.raise_no_key(api_key)
        self.messages = []
        self.responses = []
        # LangChain chat client that eval() awaits on.
        self.model: ChatAnyscale = ChatAnyscale(
            model_name=model, temperature=temperature, anyscale_api_key=self.api_key
        )

    async def eval(self, system: str, human: str):
        self.messages.append(
            [
                SystemMessage(content=system),
                HumanMessage(content=human),
            ]
        )
        self.responses.append(await self.model.ainvoke(self.messages[-1]))
        return self

    def last_response_content(self):
        return self.responses[-1].content

    def last_response_metadata(self):
        return self.responses[-1].response_metadata
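
A minimal usage sketch of the Chat wrapper above, assuming it is importable as hellocomputer.models and that a valid Anyscale API key is available through the settings; the prompt text is only an example:

import asyncio

from hellocomputer.config import settings  # import paths assumed
from hellocomputer.models import Chat


async def main():
    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # eval() stores the system/human message pair and the model response on the instance.
    await chat.eval(
        system="You are a helpful AI that shares everything you know.",
        human="Introduce yourself in about 20 words.",
    )
    print(chat.last_response_content())
    print(chat.last_response_metadata())


asyncio.run(main())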

View file

View file

@@ -0,0 +1,15 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

from ..config import settings
from ..models import Chat

router = APIRouter()


@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "") -> str:
    model = await Chat(api_key=settings.anyscale_api_key).eval(
        system="You're an expert analyst", human="Do some analysis"
    )
    return model.last_response_content()
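
A hedged client-side sketch of calling the new /query route with httpx (not a project dependency, used here only for illustration); the base URL and session id are placeholders:

import httpx

# Assumes the app is served locally on uvicorn's default port.
response = httpx.get(
    "http://localhost:8000/query",
    params={"sid": "some-session-id"},
    timeout=60.0,
)
response.raise_for_status()
print(response.text)  # plain-text answer produced by the Chat model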

View file

@@ -0,0 +1,84 @@
import aiofiles
import duckdb
import polars as pl
import s3fs
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse

from ..config import settings

router = APIRouter()

# Configure the S3FS with your Google Cloud Storage credentials
gcs = s3fs.S3FileSystem(
    key=settings.gcs_access,
    secret=settings.gcs_secret,
    client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
)

bucket_name = settings.gcs_bucketname


@router.post("/upload", tags=["files"])
async def upload_file(file: UploadFile = File(...), sid: str = ""):
    async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
        content = await file.read()
        await f.write(content)
        await f.flush()

        gcs.makedir(f"{settings.gcs_bucketname}/{sid}")

        # The spatial extension provides st_read for parsing the uploaded
        # spreadsheet; httpfs is needed to write to the GCS bucket.
        db = duckdb.connect()
        db.install_extension("spatial")
        db.install_extension("httpfs")
        db.load_extension("httpfs")
        db.load_extension("spatial")
        db.sql(f"""
            CREATE SECRET (
                TYPE GCS,
                KEY_ID '{settings.gcs_access}',
                SECRET '{settings.gcs_secret}')
            """)
        db.sql(f"""
            create table metadata as (
                select
                    *
                from
                    st_read('{f.name}',
                            layer='metadata',
                            open_options=['HEADERS_FORCE', 'FIELD_TYPES=auto']
                            )
            )""")

        metadata = db.query("select * from metadata").pl()
        sheets = metadata.filter(pl.col("Key") == "Sheets")
        print(sheets)

        for sheet in sheets.to_dict():
            print(sheet)

        db.sql(
            f"""
            create table data as (
                select
                    *
                from
                    st_read('{f.name}',
                            layer='data',
                            open_options=['HEADERS_FORCE', 'FIELD_TYPES=auto']
                            )
            )"""
        )

        # Export the data sheet to the session's folder in the bucket as CSV.
        db.sql(f"""
            copy
                data
            to
                'gcs://{settings.gcs_bucketname}/{sid}/data.csv';
            """)

    return JSONResponse(
        content={"message": "File uploaded successfully"}, status_code=200
    )
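
A hedged sketch of exercising the upload route from Python with httpx (again only for illustration); the base URL, session id, and spreadsheet name are placeholders:

import httpx

# Post a local spreadsheet to /upload for a given session id.
with open("example.xlsx", "rb") as fh:
    response = httpx.post(
        "http://localhost:8000/upload",
        params={"sid": "some-session-id"},
        files={"file": ("example.xlsx", fh)},
        timeout=120.0,
    )
response.raise_for_status()
print(response.json())  # {"message": "File uploaded successfully"}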

View file

@@ -0,0 +1,19 @@
from uuid import uuid4

from fastapi import APIRouter
from fastapi.responses import PlainTextResponse

router = APIRouter()


@router.get("/new_session")
async def get_new_session() -> str:
    return str(uuid4())


@router.get("/greetings", response_class=PlainTextResponse)
async def get_greeting() -> str:
    return (
        "Hi! I'm a helpful assistant. Please upload or select a file "
        "and I'll try to analyze it following your orders"
    )
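
The frontend script further down uses these two routes as a handshake: it requests a session id and then fetches the greeting. A rough Python equivalent with httpx, with the base URL assumed:

import httpx

base_url = "http://localhost:8000"

# /new_session returns a JSON-encoded UUID string, so strip the surrounding quotes.
sid = httpx.get(f"{base_url}/new_session").text.strip('"')
greeting = httpx.get(f"{base_url}/greetings", params={"sid": sid}).text
print(sid)
print(greeting)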

Binary file not shown.

View file

@@ -51,6 +51,11 @@
<div class="modal-body">
    <input type="file" class="custom-file-input" id="inputGroupFile01">
</div>
<div class="modal-body" id="uploadButtonDiv">
    <button type="button" class="btn btn-primary" id="uploadButton">Upload</button>
</div>
<div class="modal-body" id="uploadResultDiv">
</div>
<div class="modal-footer">
    <button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
</div>
@@ -89,9 +94,6 @@
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js"
    integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4"
    crossorigin="anonymous"></script>
<script src="https://unpkg.com/htmx.org@1.9.12"
    integrity="sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2"
    crossorigin="anonymous"></script>
<script src="script.js"></script>
</body>

View file

@@ -46,11 +46,18 @@ async function fetchResponse(message) {
function addAIMessage() {
    const newMessage = document.createElement('div');
    newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
    newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner">';
    newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner"></div>';
    chatMessages.prepend(newMessage); // Add new message at the top
    fetchResponse(newMessage);
}

function addAIManualMessage(m) {
    const newMessage = document.createElement('div');
    newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
    newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div>' + m + '</div>';
    chatMessages.prepend(newMessage); // Add new message at the top
}

function addUserMessage() {
    const messageContent = textarea.value.trim();
    if (messageContent) {
@@ -82,7 +89,11 @@ document.addEventListener("DOMContentLoaded", function () {
    async function fetchGreeting() {
        try {
            const response = await fetch('/greetings');
            const session_response = await fetch('/new_session');
            sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
            const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
            if (!response.ok) {
                throw new Error('Network response was not ok ' + response.statusText);
            }
@@ -101,3 +112,39 @@ document.addEventListener("DOMContentLoaded", function () {
    // Call the function to fetch greeting
    fetchGreeting();
});
document.addEventListener("DOMContentLoaded", function () {
    const fileInput = document.getElementById('inputGroupFile01');
    const uploadButton = document.getElementById('uploadButton');
    const uploadResultDiv = document.getElementById('uploadResultDiv');

    uploadButton.addEventListener('click', async function () {
        const file = fileInput.files[0];
        if (!file) {
            uploadResultDiv.textContent = 'Please select a file.';
            return;
        }

        const formData = new FormData();
        formData.append('file', file);

        try {
            const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
                method: 'POST',
                body: formData
            });
            if (!response.ok) {
                throw new Error('Network response was not ok ' + response.statusText);
            }
            const data = await response.text();
            uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
            addAIManualMessage('File uploaded and processed!');
        } catch (error) {
            uploadResultDiv.textContent = 'Error: ' + error.message;
        }
    });
});