Further cleanups
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-05-31 23:59:02 +02:00
parent 0c21073e88
commit d7ba280a2f
3 changed files with 39 additions and 30 deletions

View file

@ -12,8 +12,16 @@ class StorageEngines(StrEnum):
class DDB: class DDB:
def __init__(self, storage_engine: StorageEngines, **kwargs): def __init__(
"""Write documentation""" self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect() self.db = duckdb.connect()
self.db.install_extension("spatial") self.db.install_extension("spatial")
self.db.install_extension("httpfs") self.db.install_extension("httpfs")
@ -23,28 +31,28 @@ class DDB:
self.loaded = False self.loaded = False
if storage_engine == StorageEngines.gcs: if storage_engine == StorageEngines.gcs:
if ( if all(
"gcs_access" in kwargs gcs_access is not None,
and "gcs_secret" in kwargs gcs_secret is not None,
and "bucketname" in kwargs bucket is not None,
and "sid" in kwargs sid is not None,
): ):
self.db.sql(f""" self.db.sql(f"""
CREATE SECRET ( CREATE SECRET (
TYPE GCS, TYPE GCS,
KEY_ID '{kwargs["gcs_access"]}', KEY_ID '{gcs_access}',
SECRET '{kwargs["gcs_secret"]}') SECRET '{gcs_secret}')
""") """)
self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}" self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else: else:
raise ValueError( raise ValueError(
"With GCS storage engine you need to provide " "With GCS storage engine you need to provide "
"the gcs_access, gcs_secret and bucket keyword arguments" "the gcs_access, gcs_secret, sid, and bucket keyword arguments"
) )
elif storage_engine == StorageEngines.local: elif storage_engine == StorageEngines.local:
if "path" in kwargs: if path is not None:
self.path_prefix = kwargs["path"] self.path_prefix = path
else: else:
raise ValueError( raise ValueError(
"With local storage you need to provide the path keyword argument" "With local storage you need to provide the path keyword argument"
@ -98,13 +106,23 @@ class DDB:
return self return self
def load_folder(self) -> Self: def load_folder(self) -> Self:
self.query(
f"""
create table metadata as (
select
*
from
read_csv_auto('{self.path_prefix}/metadata.csv')
)
"""
)
self.sheets = tuple( self.sheets = tuple(
self.query( self.query(
f""" """
select select
Field2 Field2
from from
read_csv_auto('{self.path_prefix}/metadata.csv') metadata
where where
Field1 = 'Sheets' Field1 = 'Sheets'
""" """
@ -130,11 +148,11 @@ class DDB:
def load_description(self) -> Self: def load_description(self) -> Self:
return self.query( return self.query(
f""" """
select select
Field2 Field2
from from
read_csv_auto('{self.path_prefix}/metadata.csv') metadata
where where
Field1 = 'Description'""" Field1 = 'Description'"""
).fetchall()[0][0] ).fetchall()[0][0]

View file

@ -39,8 +39,6 @@ def test_schema():
for sheet in db.sheets: for sheet in db.sheets:
schema.append(db.table_schema(sheet)) schema.append(db.table_schema(sheet))
print(db.schema)
assert db.schema.startswith("The schema of the database") assert db.schema.startswith("The schema of the database")

View file

@ -3,7 +3,7 @@ from pathlib import Path
import hellocomputer import hellocomputer
import pytest import pytest
from hellocomputer.analytics import DDB from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
@ -34,17 +34,10 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database" query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB().load_xls(TEST_XLS_PATH) db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
TEST_XLS_PATH
prompt = os.linesep.join(
[
query,
db.schema(),
db.load_description(),
"Return just the SQL statement",
]
) )
chat = await chat.eval("You're an expert sql developer", prompt) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT") assert query.startswith("SELECT")