Refactored analytics class

This commit is contained in:
Guillem Borrell 2024-05-28 21:23:11 +01:00
parent 06ac295e17
commit 0c21073e88
9 changed files with 156 additions and 163 deletions

View file

@ -1,36 +1,63 @@
import duckdb
import os
from enum import StrEnum
from pathlib import Path
import duckdb
from typing_extensions import Self
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(self):
def __init__(self, storage_engine: StorageEngines, **kwargs):
"""Write documentation"""
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.path = ""
self.loaded = False
def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self:
if storage_engine == StorageEngines.gcs:
if (
"gcs_access" in kwargs
and "gcs_secret" in kwargs
and "bucketname" in kwargs
and "sid" in kwargs
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
KEY_ID '{kwargs["gcs_access"]}',
SECRET '{kwargs["gcs_secret"]}')
""")
self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret and bucket keyword arguments"
)
return self
elif storage_engine == StorageEngines.local:
if "path" in kwargs:
self.path_prefix = kwargs["path"]
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
def load_metadata(self, path: str = "") -> Self:
def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded"""
self.db.sql(f"""
create table metadata as (
select
*
from
st_read('{path}',
st_read('{xls_path}',
layer='metadata'
)
)""")
@ -39,62 +66,45 @@ class DDB:
.fetchall()[0][0]
.split(";")
)
self.path = path
return self
def dump_local(self, path) -> Self:
# TODO: Port to fsspec and have a single dump file
self.db.query(f"copy metadata to '{path}/metadata.csv'")
for sheet in self.sheets:
self.db.query(f"""
copy
create table {sheet} as
(
select
*
from
st_read
(
'{self.path}',
'{xls_path}',
layer = '{sheet}'
)
)
to '{path}/{sheet}.csv'
""")
self.loaded = True
return self
def dump_gcs(self, bucketname, sid) -> Self:
self.db.sql(
f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'"
)
def dump(self) -> Self:
# TODO: Create a decorator
if not self.loaded:
raise ValueError("Data should be loaded first")
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
for sheet in self.sheets:
self.db.query(f"""
copy
(
select
*
from
st_read
(
'{self.path}',
layer = '{sheet}'
)
)
to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
""")
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
return self
def load_folder_local(self, path: str) -> Self:
def load_folder(self) -> Self:
self.sheets = tuple(
self.query(
f"""
select
Field2
from
read_csv_auto('{path}/metadata.csv')
read_csv_auto('{self.path_prefix}/metadata.csv')
where
Field1 = 'Sheets'
"""
@ -110,61 +120,21 @@ class DDB:
select
*
from
read_csv_auto('{path}/{sheet}.csv')
read_csv_auto('{self.path_prefix}/{sheet}.csv')
)
""")
return self
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
self.sheets = tuple(
self.query(
f"""
select
Field2
from
read_csv_auto(
'gcs://{bucketname}/sessions/{sid}/metadata.csv'
)
where
Field1 = 'Sheets'
"""
)
.fetchall()[0][0]
.split(";")
)
# Load all the tables into the database
for sheet in self.sheets:
self.db.query(f"""
create table {sheet} as (
select
*
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
)
""")
self.loaded = True
return self
def load_description_local(self, path: str) -> Self:
def load_description(self) -> Self:
return self.query(
f"""
select
Field2
from
read_csv_auto('{path}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
def load_description_gcs(self, bucketname: str, sid: str) -> Self:
return self.query(
f"""
select
Field2
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
read_csv_auto('{self.path_prefix}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
@ -182,9 +152,11 @@ class DDB:
f"select column_name, column_type from (describe {table})"
).fetchall()
)
+ [os.linesep]
)
def db_schema(self):
@property
def schema(self):
return os.linesep.join(
[
"The schema of the database is the following:",
@ -194,3 +166,18 @@ class DDB:
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
def query_prompt(self, user_prompt: str) -> str:
query = (
f"The following sentence is the description of a query that "
f"needs to be executed in a database: {user_prompt}"
)
return os.linesep.join(
[
query,
self.schema,
self.load_description(),
"Return just the SQL statement",
]
)

View file

@ -6,7 +6,7 @@ class Settings(BaseSettings):
gcs_access: str = "access"
gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"
auth: bool = False
auth: bool = True
model_config = SettingsConfigDict(env_file=".env")

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
import hellocomputer
from .routers import files, sessions, analysis
from .routers import analysis, files, sessions
static_path = Path(hellocomputer.__file__).parent / "static"

View file

@ -1,4 +1,5 @@
from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage

View file

@ -1,41 +1,29 @@
import os
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from ..config import settings
from ..models import Chat
from hellocomputer.analytics import DDB
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.extraction import extract_code_block
import os
from ..config import settings
from ..models import Chat
router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
print(q)
query = f"Write a query that {q} in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = (
DDB()
.gcs_secret(settings.gcs_access, settings.gcs_secret)
.load_folder_gcs(settings.gcs_bucketname, sid)
)
db = DDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
).load_folder()
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_gcs(settings.gcs_bucketname, sid),
"Return just the SQL statement",
]
)
print(prompt)
chat = await chat.eval("You're an expert sql developer", prompt)
chat = await chat.eval("You're an expert sql developer", db.query_prompt())
query = extract_code_block(chat.last_response_content())
result = str(db.query(query))
print(result)

View file

@ -1,21 +1,22 @@
import aiofiles
import s3fs
# import s3fs
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines
from ..config import settings
from ..analytics import DDB
router = APIRouter()
# Configure the S3FS with your Google Cloud Storage credentials
gcs = s3fs.S3FileSystem(
key=settings.gcs_access,
secret=settings.gcs_secret,
client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
)
bucket_name = settings.gcs_bucketname
# gcs = s3fs.S3FileSystem(
# key=settings.gcs_access,
# secret=settings.gcs_secret,
# client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
# )
# bucket_name = settings.gcs_bucketname
@router.post("/upload", tags=["files"])
@ -26,10 +27,15 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush()
(
DDB()
.gcs_secret(settings.gcs_access, settings.gcs_secret)
.load_metadata(f.name)
.dump_gcs(settings.gcs_bucketname, sid)
DDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
)
.load_xls(f.name)
.dump()
)
return JSONResponse(

View file

@ -1,22 +1,16 @@
from uuid import uuid4
from typing import Annotated
from uuid import uuid4
from fastapi import APIRouter, Depends
from fastapi.responses import PlainTextResponse
from ..config import settings
from ..security import oauth2_scheme
# Scheme for the Authorization header
router = APIRouter()
if settings.auth:
@router.get("/token")
async def get_token() -> str:
return str(uuid4())
@router.get("/new_session")
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return str(uuid4())

View file

@ -1,40 +1,52 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
import hellocomputer
from hellocomputer.analytics import DDB, StorageEngines
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_dump():
db = (
DDB()
.load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx")
.dump_local(TEST_OUTPUT_FOLDER)
)
def test_0_dump():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",)
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()
def test_load():
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
assert db.sheets == ("answers",)
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
results = db.query("select * from answers").fetchall()
assert len(results) == 6
def test_load_description():
file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER)
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
file_description = db.load_description()
assert file_description.startswith("answers")
def test_schema():
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
schema = []
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
assert schema[0].startswith("Table name:")
print(db.schema)
assert db.schema.startswith("The schema of the database")
def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"
)

View file

@ -1,14 +1,19 @@
import pytest
import os
import hellocomputer
from hellocomputer.config import settings
from hellocomputer.models import Chat
from hellocomputer.extraction import extract_code_block
from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.analytics import DDB
from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
@pytest.mark.asyncio
@ -29,13 +34,13 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
db = DDB().load_xls(TEST_XLS_PATH)
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_local(TEST_OUTPUT_FOLDER),
db.schema(),
db.load_description(),
"Return just the SQL statement",
]
)