Refactored things a bit
This commit is contained in:
parent
b8e4e930d7
commit
ea14f8c87e
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -160,3 +160,4 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
.DS_Store
|
|
@ -3,3 +3,8 @@ langchain-community
|
||||||
openai
|
openai
|
||||||
fastapi
|
fastapi
|
||||||
pydantic-settings
|
pydantic-settings
|
||||||
|
s3fs
|
||||||
|
aiofiles
|
||||||
|
duckdb
|
||||||
|
polars
|
||||||
|
pyarrow
|
|
@ -5,6 +5,7 @@ class Settings(BaseSettings):
|
||||||
anyscale_api_key: str = "Awesome API"
|
anyscale_api_key: str = "Awesome API"
|
||||||
gcs_access: str = "access"
|
gcs_access: str = "access"
|
||||||
gcs_secret: str = "secret"
|
gcs_secret: str = "secret"
|
||||||
|
gcs_bucketname: str = "bucket"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
|
@ -1,15 +1,12 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import FastAPI, status
|
from fastapi import FastAPI, status
|
||||||
from fastapi.responses import PlainTextResponse
|
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from langchain_community.chat_models import ChatAnyscale
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
|
|
||||||
from .config import settings
|
from .routers import files, sessions
|
||||||
|
|
||||||
static_path = Path(hellocomputer.__file__).parent / "static"
|
static_path = Path(hellocomputer.__file__).parent / "static"
|
||||||
|
|
||||||
|
@ -43,30 +40,8 @@ def get_health() -> HealthCheck:
|
||||||
return HealthCheck(status="OK")
|
return HealthCheck(status="OK")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/greetings", response_class=PlainTextResponse)
|
app.include_router(sessions.router)
|
||||||
async def get_greeting() -> str:
|
app.include_router(files.router)
|
||||||
model = "meta-llama/Meta-Llama-3-8B-Instruct"
|
|
||||||
chat = ChatAnyscale(
|
|
||||||
model_name=model,
|
|
||||||
temperature=0.5,
|
|
||||||
anyscale_api_key=settings.anyscale_api_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content="You are a helpful AI that shares everything you know."),
|
|
||||||
HumanMessage(
|
|
||||||
content="Make a short presentation of yourself "
|
|
||||||
"as an assistant in Spanish in about 20 words. "
|
|
||||||
"You're capable of analyzing a file that a user "
|
|
||||||
"has previously uploaded."
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
model_response = await chat.ainvoke(messages)
|
|
||||||
print(model_response.response_metadata)
|
|
||||||
return model_response.content
|
|
||||||
|
|
||||||
|
|
||||||
app.mount(
|
app.mount(
|
||||||
"/",
|
"/",
|
||||||
StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]),
|
StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]),
|
||||||
|
|
52
src/hellocomputer/models.py
Normal file
52
src/hellocomputer/models.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
from enum import StrEnum
|
||||||
|
from langchain_community.chat_models import ChatAnyscale
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
|
class AvailableModels(StrEnum):
|
||||||
|
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
|
class Chat:
|
||||||
|
@staticmethod
|
||||||
|
def raise_no_key(api_key):
|
||||||
|
if api_key:
|
||||||
|
return api_key
|
||||||
|
elif api_key is None:
|
||||||
|
raise ValueError(
|
||||||
|
"You need to provide a valid API in the api_key init argument"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("You need to provide a valid API key")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: AvailableModels = AvailableModels.llama3_8b,
|
||||||
|
api_key: str = "",
|
||||||
|
temperature: float = 0.5,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.api_key = self.raise_no_key(api_key)
|
||||||
|
self.messages = []
|
||||||
|
self.responses = []
|
||||||
|
|
||||||
|
model: ChatAnyscale = ChatAnyscale(
|
||||||
|
model_name=model, temperature=temperature, anyscale_api_key=self.api_key
|
||||||
|
)
|
||||||
|
|
||||||
|
async def eval(self, system: str, human: str):
|
||||||
|
self.messages.append(
|
||||||
|
[
|
||||||
|
SystemMessage(content=system),
|
||||||
|
HumanMessage(content=human),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.responses.append(await self.model.ainvoke(self.messages[-1]))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def last_response_content(self):
|
||||||
|
return self.responses[-1].content
|
||||||
|
|
||||||
|
def last_response_metadata(self):
|
||||||
|
return self.responses[-1].response_metadata
|
0
src/hellocomputer/routers/__init__.py
Normal file
0
src/hellocomputer/routers/__init__.py
Normal file
15
src/hellocomputer/routers/analysis.py
Normal file
15
src/hellocomputer/routers/analysis.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
|
from ..models import Chat
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
|
async def query(sid: str = "") -> str:
|
||||||
|
model = Chat(api_key=settings.anyscale_api_key).eval(
|
||||||
|
system="You're an expert analyst", human="Do some analysis"
|
||||||
|
)
|
||||||
|
return model.last_response_content()
|
84
src/hellocomputer/routers/files.py
Normal file
84
src/hellocomputer/routers/files.py
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
import aiofiles
|
||||||
|
import duckdb
|
||||||
|
import polars as pl
|
||||||
|
import s3fs
|
||||||
|
from fastapi import APIRouter, File, UploadFile
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from ..config import settings
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# Configure the S3FS with your Google Cloud Storage credentials
|
||||||
|
gcs = s3fs.S3FileSystem(
|
||||||
|
key=settings.gcs_access,
|
||||||
|
secret=settings.gcs_secret,
|
||||||
|
client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
|
||||||
|
)
|
||||||
|
bucket_name = settings.gcs_bucketname
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", tags=["files"])
|
||||||
|
async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
||||||
|
async with aiofiles.tempfile.NamedTemporaryFile("wb") as f:
|
||||||
|
content = await file.read()
|
||||||
|
await f.write(content)
|
||||||
|
await f.flush()
|
||||||
|
|
||||||
|
gcs.makedir(f"{settings.gcs_bucketname}/{sid}")
|
||||||
|
|
||||||
|
db = duckdb.connect()
|
||||||
|
db.install_extension("spatial")
|
||||||
|
db.install_extension("httpfs")
|
||||||
|
db.load_extension("httpfs")
|
||||||
|
db.load_extension("spatial")
|
||||||
|
|
||||||
|
db.sql(f"""
|
||||||
|
CREATE SECRET (
|
||||||
|
TYPE GCS,
|
||||||
|
KEY_ID '{settings.gcs_access}',
|
||||||
|
SECRET '{settings.gcs_secret}')
|
||||||
|
""")
|
||||||
|
|
||||||
|
db.sql(f"""
|
||||||
|
create table metadata as (
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from
|
||||||
|
st_read('{f.name}',
|
||||||
|
layer='metadata',
|
||||||
|
open_options=['HEADERS_FORCE', 'FIELD_TYPES=auto']
|
||||||
|
)
|
||||||
|
)""")
|
||||||
|
|
||||||
|
metadata = db.query("select * from metadata").pl()
|
||||||
|
sheets = metadata.select(pl.col("Key") == "Sheets")
|
||||||
|
print(sheets)
|
||||||
|
|
||||||
|
for sheet in sheets.to_dict():
|
||||||
|
print(sheet)
|
||||||
|
|
||||||
|
db.sql(
|
||||||
|
f"""
|
||||||
|
create table data as (
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from
|
||||||
|
st_read('{f.name}',
|
||||||
|
layer='data',
|
||||||
|
open_options=['HEADERS_FORCE', 'FIELD_TYPES=auto']
|
||||||
|
)
|
||||||
|
)"""
|
||||||
|
)
|
||||||
|
|
||||||
|
db.sql(f"""
|
||||||
|
copy
|
||||||
|
data
|
||||||
|
to
|
||||||
|
'gcs://{settings.gcs_bucketname}/{sid}/data.csv';
|
||||||
|
""")
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
content={"message": "File uploaded successfully"}, status_code=200
|
||||||
|
)
|
19
src/hellocomputer/routers/sessions.py
Normal file
19
src/hellocomputer/routers/sessions.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/new_session")
|
||||||
|
async def get_new_session() -> str:
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/greetings", response_class=PlainTextResponse)
|
||||||
|
async def get_greeting() -> str:
|
||||||
|
return (
|
||||||
|
"Hi! I'm a helpful assistant. Please upload or select a file "
|
||||||
|
"and I'll try to analyze it following your orders"
|
||||||
|
)
|
BIN
src/hellocomputer/static/.DS_Store
vendored
BIN
src/hellocomputer/static/.DS_Store
vendored
Binary file not shown.
|
@ -51,6 +51,11 @@
|
||||||
<div class="modal-body">
|
<div class="modal-body">
|
||||||
<input type="file" class="custom-file-input" id="inputGroupFile01">
|
<input type="file" class="custom-file-input" id="inputGroupFile01">
|
||||||
</div>
|
</div>
|
||||||
|
<div class="modal-body" id="uploadButtonDiv">
|
||||||
|
<button type="button" class="btn btn-primary" id="uploadButton">Upload</button>
|
||||||
|
</div>
|
||||||
|
<div class="modal-body" id="uploadResultDiv">
|
||||||
|
</div>
|
||||||
<div class="modal-footer">
|
<div class="modal-footer">
|
||||||
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
|
<button type="button" class="btn btn-secondary" data-bs-dismiss="modal">Close</button>
|
||||||
</div>
|
</div>
|
||||||
|
@ -89,9 +94,6 @@
|
||||||
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js"
|
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.2.3/dist/js/bootstrap.bundle.min.js"
|
||||||
integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4"
|
integrity="sha384-kenU1KFdBIe4zVF0s0G1M5b4hcpxyD9F7jL+jjXkk+Q2h455rYXK/7HAuoJl+0I4"
|
||||||
crossorigin="anonymous"></script>
|
crossorigin="anonymous"></script>
|
||||||
<script src="https://unpkg.com/htmx.org@1.9.12"
|
|
||||||
integrity="sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2"
|
|
||||||
crossorigin="anonymous"></script>
|
|
||||||
<script src="script.js"></script>
|
<script src="script.js"></script>
|
||||||
</body>
|
</body>
|
||||||
|
|
||||||
|
|
|
@ -46,11 +46,18 @@ async function fetchResponse(message) {
|
||||||
function addAIMessage() {
|
function addAIMessage() {
|
||||||
const newMessage = document.createElement('div');
|
const newMessage = document.createElement('div');
|
||||||
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
|
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
|
||||||
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner">';
|
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div id="spinner" class="spinner"></div>';
|
||||||
chatMessages.prepend(newMessage); // Add new message at the top
|
chatMessages.prepend(newMessage); // Add new message at the top
|
||||||
fetchResponse(newMessage);
|
fetchResponse(newMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function addAIManualMessage(m) {
|
||||||
|
const newMessage = document.createElement('div');
|
||||||
|
newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded');
|
||||||
|
newMessage.innerHTML = '<img src="/img/assistant.webp" width="50px"> <div>' + m + '</div>';
|
||||||
|
chatMessages.prepend(newMessage); // Add new message at the top
|
||||||
|
}
|
||||||
|
|
||||||
function addUserMessage() {
|
function addUserMessage() {
|
||||||
const messageContent = textarea.value.trim();
|
const messageContent = textarea.value.trim();
|
||||||
if (messageContent) {
|
if (messageContent) {
|
||||||
|
@ -82,7 +89,11 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
|
||||||
async function fetchGreeting() {
|
async function fetchGreeting() {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/greetings');
|
const session_response = await fetch('/new_session');
|
||||||
|
sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text()));
|
||||||
|
|
||||||
|
const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession'));
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error('Network response was not ok ' + response.statusText);
|
throw new Error('Network response was not ok ' + response.statusText);
|
||||||
}
|
}
|
||||||
|
@ -101,3 +112,39 @@ document.addEventListener("DOMContentLoaded", function () {
|
||||||
// Call the function to fetch greeting
|
// Call the function to fetch greeting
|
||||||
fetchGreeting();
|
fetchGreeting();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
const fileInput = document.getElementById('inputGroupFile01');
|
||||||
|
const uploadButton = document.getElementById('uploadButton');
|
||||||
|
const uploadResultDiv = document.getElementById('uploadResultDiv');
|
||||||
|
|
||||||
|
uploadButton.addEventListener('click', async function () {
|
||||||
|
const file = fileInput.files[0];
|
||||||
|
|
||||||
|
if (!file) {
|
||||||
|
uploadResultDiv.textContent = 'Please select a file.';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('file', file);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/upload?sid=' + sessionStorage.getItem('helloComputerSession'), {
|
||||||
|
method: 'POST',
|
||||||
|
body: formData
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Network response was not ok ' + response.statusText);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.text();
|
||||||
|
uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message'];
|
||||||
|
|
||||||
|
addAIManualMessage('File uploaded and processed!');
|
||||||
|
} catch (error) {
|
||||||
|
uploadResultDiv.textContent = 'Error: ' + error.message;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
Loading…
Reference in a new issue