Further cleanups
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-05-31 23:59:02 +02:00
parent 0c21073e88
commit d7ba280a2f
3 changed files with 39 additions and 30 deletions

View file

@ -12,8 +12,16 @@ class StorageEngines(StrEnum):
class DDB:
def __init__(self, storage_engine: StorageEngines, **kwargs):
"""Write documentation"""
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
@ -23,28 +31,28 @@ class DDB:
self.loaded = False
if storage_engine == StorageEngines.gcs:
if (
"gcs_access" in kwargs
and "gcs_secret" in kwargs
and "bucketname" in kwargs
and "sid" in kwargs
if all(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{kwargs["gcs_access"]}',
SECRET '{kwargs["gcs_secret"]}')
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}"
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret and bucket keyword arguments"
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
if "path" in kwargs:
self.path_prefix = kwargs["path"]
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
@ -98,13 +106,23 @@ class DDB:
return self
def load_folder(self) -> Self:
self.query(
f"""
create table metadata as (
select
*
from
read_csv_auto('{self.path_prefix}/metadata.csv')
)
"""
)
self.sheets = tuple(
self.query(
f"""
"""
select
Field2
from
read_csv_auto('{self.path_prefix}/metadata.csv')
metadata
where
Field1 = 'Sheets'
"""
@ -130,11 +148,11 @@ class DDB:
def load_description(self) -> Self:
return self.query(
f"""
"""
select
Field2
from
read_csv_auto('{self.path_prefix}/metadata.csv')
metadata
where
Field1 = 'Description'"""
).fetchall()[0][0]

View file

@ -39,8 +39,6 @@ def test_schema():
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
print(db.schema)
assert db.schema.startswith("The schema of the database")

View file

@ -3,7 +3,7 @@ from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.analytics import DDB
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
@ -34,17 +34,10 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB().load_xls(TEST_XLS_PATH)
prompt = os.linesep.join(
[
query,
db.schema(),
db.load_description(),
"Return just the SQL statement",
]
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
TEST_XLS_PATH
)
chat = await chat.eval("You're an expert sql developer", prompt)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")