Refactored analytics class

This commit is contained in:
Guillem Borrell 2024-05-28 21:23:11 +01:00
parent 06ac295e17
commit 0c21073e88
9 changed files with 156 additions and 163 deletions

View file

@ -1,36 +1,63 @@
import duckdb
import os import os
from enum import StrEnum
from pathlib import Path
import duckdb
from typing_extensions import Self from typing_extensions import Self
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB: class DDB:
def __init__(self): def __init__(self, storage_engine: StorageEngines, **kwargs):
"""Write documentation"""
self.db = duckdb.connect() self.db = duckdb.connect()
self.db.install_extension("spatial") self.db.install_extension("spatial")
self.db.install_extension("httpfs") self.db.install_extension("httpfs")
self.db.load_extension("spatial") self.db.load_extension("spatial")
self.db.load_extension("httpfs") self.db.load_extension("httpfs")
self.sheets = tuple() self.sheets = tuple()
self.path = "" self.loaded = False
def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self: if storage_engine == StorageEngines.gcs:
self.db.sql(f""" if (
CREATE SECRET ( "gcs_access" in kwargs
TYPE GCS, and "gcs_secret" in kwargs
KEY_ID '{gcs_access}', and "bucket" in kwargs
SECRET '{gcs_secret}') and "sid" in kwargs
""") ):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{kwargs["gcs_access"]}',
SECRET '{kwargs["gcs_secret"]}')
""")
self.path_prefix = f"gcs://{kwargs['bucket']}/sessions/{kwargs['sid']}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, bucket and sid keyword arguments"
)
return self elif storage_engine == StorageEngines.local:
if "path" in kwargs:
self.path_prefix = kwargs["path"]
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
def load_metadata(self, path: str = "") -> Self: def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded""" """For some reason, the header is not loaded"""
self.db.sql(f""" self.db.sql(f"""
create table metadata as ( create table metadata as (
select select
* *
from from
st_read('{path}', st_read('{xls_path}',
layer='metadata' layer='metadata'
) )
)""") )""")
@ -39,62 +66,45 @@ class DDB:
.fetchall()[0][0] .fetchall()[0][0]
.split(";") .split(";")
) )
self.path = path
return self
def dump_local(self, path) -> Self:
# TODO: Port to fsspec and have a single dump file
self.db.query(f"copy metadata to '{path}/metadata.csv'")
for sheet in self.sheets: for sheet in self.sheets:
self.db.query(f""" self.db.query(f"""
copy create table {sheet} as
( (
select select
* *
from from
st_read st_read
( (
'{self.path}', '{xls_path}',
layer = '{sheet}' layer = '{sheet}'
) )
) )
to '{path}/{sheet}.csv'
""") """)
self.loaded = True
return self return self
def dump_gcs(self, bucketname, sid) -> Self: def dump(self) -> Self:
self.db.sql( # TODO: Create a decorator
f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'" if not self.loaded:
) raise ValueError("Data should be loaded first")
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
for sheet in self.sheets: for sheet in self.sheets:
self.db.query(f""" self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
copy
(
select
*
from
st_read
(
'{self.path}',
layer = '{sheet}'
)
)
to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
""")
return self return self
def load_folder_local(self, path: str) -> Self: def load_folder(self) -> Self:
self.sheets = tuple( self.sheets = tuple(
self.query( self.query(
f""" f"""
select select
Field2 Field2
from from
read_csv_auto('{path}/metadata.csv') read_csv_auto('{self.path_prefix}/metadata.csv')
where where
Field1 = 'Sheets' Field1 = 'Sheets'
""" """
@ -110,61 +120,21 @@ class DDB:
select select
* *
from from
read_csv_auto('{path}/{sheet}.csv') read_csv_auto('{self.path_prefix}/{sheet}.csv')
) )
""") """)
return self self.loaded = True
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
self.sheets = tuple(
self.query(
f"""
select
Field2
from
read_csv_auto(
'gcs://{bucketname}/sessions/{sid}/metadata.csv'
)
where
Field1 = 'Sheets'
"""
)
.fetchall()[0][0]
.split(";")
)
# Load all the tables into the database
for sheet in self.sheets:
self.db.query(f"""
create table {sheet} as (
select
*
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
)
""")
return self return self
def load_description_local(self, path: str) -> Self: def load_description(self) -> Self:
return self.query( return self.query(
f""" f"""
select select
Field2 Field2
from from
read_csv_auto('{path}/metadata.csv') read_csv_auto('{self.path_prefix}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
def load_description_gcs(self, bucketname: str, sid: str) -> Self:
return self.query(
f"""
select
Field2
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
where where
Field1 = 'Description'""" Field1 = 'Description'"""
).fetchall()[0][0] ).fetchall()[0][0]
@ -182,9 +152,11 @@ class DDB:
f"select column_name, column_type from (describe {table})" f"select column_name, column_type from (describe {table})"
).fetchall() ).fetchall()
) )
+ [os.linesep]
) )
def db_schema(self): @property
def schema(self):
return os.linesep.join( return os.linesep.join(
[ [
"The schema of the database is the following:", "The schema of the database is the following:",
@ -194,3 +166,18 @@ class DDB:
def query(self, sql, *args, **kwargs): def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs) return self.db.query(sql, *args, **kwargs)
def query_prompt(self, user_prompt: str) -> str:
query = (
f"The following sentence is the description of a query that "
f"needs to be executed in a database: {user_prompt}"
)
return os.linesep.join(
[
query,
self.schema,
self.load_description(),
"Return just the SQL statement",
]
)

View file

@ -6,7 +6,7 @@ class Settings(BaseSettings):
gcs_access: str = "access" gcs_access: str = "access"
gcs_secret: str = "secret" gcs_secret: str = "secret"
gcs_bucketname: str = "bucket" gcs_bucketname: str = "bucket"
auth: bool = False auth: bool = True
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
import hellocomputer import hellocomputer
from .routers import files, sessions, analysis from .routers import analysis, files, sessions
static_path = Path(hellocomputer.__file__).parent / "static" static_path = Path(hellocomputer.__file__).parent / "static"

View file

@ -1,4 +1,5 @@
from enum import StrEnum from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage

View file

@ -1,41 +1,29 @@
import os
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from ..config import settings from hellocomputer.analytics import DDB, StorageEngines
from ..models import Chat
from hellocomputer.analytics import DDB
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
import os from ..config import settings
from ..models import Chat
router = APIRouter() router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
print(q)
query = f"Write a query that {q} in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = ( db = DDB(
DDB() StorageEngines.gcs,
.gcs_secret(settings.gcs_access, settings.gcs_secret) gcs_access=settings.gcs_access,
.load_folder_gcs(settings.gcs_bucketname, sid) gcs_secret=settings.gcs_secret,
) bucket=settings.gcs_bucketname,
sid=sid,
).load_folder()
prompt = os.linesep.join( chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
[
query,
db.db_schema(),
db.load_description_gcs(settings.gcs_bucketname, sid),
"Return just the SQL statement",
]
)
print(prompt)
chat = await chat.eval("You're an expert sql developer", prompt)
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
result = str(db.query(query)) result = str(db.query(query))
print(result) print(result)

View file

@ -1,21 +1,22 @@
import aiofiles import aiofiles
import s3fs
# import s3fs
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines
from ..config import settings from ..config import settings
from ..analytics import DDB
router = APIRouter() router = APIRouter()
# Configure the S3FS with your Google Cloud Storage credentials # Configure the S3FS with your Google Cloud Storage credentials
gcs = s3fs.S3FileSystem( # gcs = s3fs.S3FileSystem(
key=settings.gcs_access, # key=settings.gcs_access,
secret=settings.gcs_secret, # secret=settings.gcs_secret,
client_kwargs={"endpoint_url": "https://storage.googleapis.com"}, # client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
) # )
bucket_name = settings.gcs_bucketname # bucket_name = settings.gcs_bucketname
@router.post("/upload", tags=["files"]) @router.post("/upload", tags=["files"])
@ -26,10 +27,15 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush() await f.flush()
( (
DDB() DDB(
.gcs_secret(settings.gcs_access, settings.gcs_secret) StorageEngines.gcs,
.load_metadata(f.name) gcs_access=settings.gcs_access,
.dump_gcs(settings.gcs_bucketname, sid) gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
)
.load_xls(f.name)
.dump()
) )
return JSONResponse( return JSONResponse(

View file

@ -1,22 +1,16 @@
from uuid import uuid4
from typing import Annotated from typing import Annotated
from uuid import uuid4
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from ..config import settings
from ..security import oauth2_scheme from ..security import oauth2_scheme
# Scheme for the Authorization header # Scheme for the Authorization header
router = APIRouter() router = APIRouter()
if settings.auth:
@router.get("/token")
async def get_token() -> str:
return str(uuid4())
@router.get("/new_session") @router.get("/new_session")
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str: async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return str(uuid4()) return str(uuid4())

View file

@ -1,40 +1,52 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path from pathlib import Path
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data" import hellocomputer
from hellocomputer.analytics import DDB, StorageEngines
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_dump(): def test_0_dump():
db = ( db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
DDB() db.load_xls(TEST_XLS_PATH).dump()
.load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx")
.dump_local(TEST_OUTPUT_FOLDER)
)
assert db.sheets == ("answers",) assert db.sheets == ("answers",)
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists() assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()
def test_load(): def test_load():
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.sheets == ("answers",)
results = db.query("select * from answers").fetchall() results = db.query("select * from answers").fetchall()
assert len(results) == 6 assert len(results) == 6
def test_load_description(): def test_load_description():
file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER) db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
file_description = db.load_description()
assert file_description.startswith("answers") assert file_description.startswith("answers")
def test_schema(): def test_schema():
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
schema = [] schema = []
for sheet in db.sheets: for sheet in db.sheets:
schema.append(db.table_schema(sheet)) schema.append(db.table_schema(sheet))
assert schema[0].startswith("Table name:") print(db.schema)
assert db.schema.startswith("The schema of the database")
def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"
)

View file

@ -1,14 +1,19 @@
import pytest
import os import os
import hellocomputer
from hellocomputer.config import settings
from hellocomputer.models import Chat
from hellocomputer.extraction import extract_code_block
from pathlib import Path from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.analytics import DDB from hellocomputer.analytics import DDB
from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
TEST_XLS_PATH = (
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -29,13 +34,13 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database" query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) db = DDB().load_xls(TEST_XLS_PATH)
prompt = os.linesep.join( prompt = os.linesep.join(
[ [
query, query,
db.db_schema(), db.schema,
db.load_description_local(TEST_OUTPUT_FOLDER), db.load_description(),
"Return just the SQL statement", "Return just the SQL statement",
] ]
) )