diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/analytics.py index 607b2b4..08905ab 100644 --- a/src/hellocomputer/analytics.py +++ b/src/hellocomputer/analytics.py @@ -12,8 +12,16 @@ class StorageEngines(StrEnum): class DDB: - def __init__(self, storage_engine: StorageEngines, **kwargs): - """Write documentation""" + def __init__( + self, + storage_engine: StorageEngines, + sid: str | None = None, + path: Path | None = None, + gcs_access: str | None = None, + gcs_secret: str | None = None, + bucket: str | None = None, + **kwargs, + ): self.db = duckdb.connect() self.db.install_extension("spatial") self.db.install_extension("httpfs") @@ -23,28 +31,28 @@ class DDB: self.loaded = False if storage_engine == StorageEngines.gcs: - if ( - "gcs_access" in kwargs - and "gcs_secret" in kwargs - and "bucketname" in kwargs - and "sid" in kwargs + if all( + gcs_access is not None, + gcs_secret is not None, + bucket is not None, + sid is not None, ): self.db.sql(f""" CREATE SECRET ( TYPE GCS, - KEY_ID '{kwargs["gcs_access"]}', - SECRET '{kwargs["gcs_secret"]}') + KEY_ID '{gcs_access}', + SECRET '{gcs_secret}') """) - self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}" + self.path_prefix = f"gcs://{bucket}/sessions/{sid}" else: raise ValueError( "With GCS storage engine you need to provide " - "the gcs_access, gcs_secret and bucket keyword arguments" + "the gcs_access, gcs_secret, sid, and bucket keyword arguments" ) elif storage_engine == StorageEngines.local: - if "path" in kwargs: - self.path_prefix = kwargs["path"] + if path is not None: + self.path_prefix = path else: raise ValueError( "With local storage you need to provide the path keyword argument" @@ -98,13 +106,23 @@ class DDB: return self def load_folder(self) -> Self: + self.query( + f""" + create table metadata as ( + select + * + from + read_csv_auto('{self.path_prefix}/metadata.csv') + ) + """ + ) self.sheets = tuple( self.query( - f""" + """ select Field2 from - read_csv_auto('{self.path_prefix}/metadata.csv') + metadata where Field1 = 'Sheets' """ @@ -130,11 +148,11 @@ class DDB: def load_description(self) -> Self: return self.query( - f""" + """ select Field2 from - read_csv_auto('{self.path_prefix}/metadata.csv') + metadata where Field1 = 'Description'""" ).fetchall()[0][0] diff --git a/test/test_data.py b/test/test_data.py index 5030eaa..2b8d8fa 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -39,8 +39,6 @@ def test_schema(): for sheet in db.sheets: schema.append(db.table_schema(sheet)) - print(db.schema) - assert db.schema.startswith("The schema of the database") diff --git a/test/test_query.py b/test/test_query.py index 59956a3..0880587 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -3,7 +3,7 @@ from pathlib import Path import hellocomputer import pytest -from hellocomputer.analytics import DDB +from hellocomputer.analytics import DDB, StorageEngines from hellocomputer.config import settings from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat @@ -34,17 +34,10 @@ async def test_simple_data_query(): query = "write a query that finds the average score of all students in the current database" chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = DDB().load_xls(TEST_XLS_PATH) - - prompt = os.linesep.join( - [ - query, - db.schema(), - db.load_description(), - "Return just the SQL statement", - ] + db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls( + TEST_XLS_PATH ) - chat = await chat.eval("You're an expert sql developer", prompt) + chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) query = extract_code_block(chat.last_response_content()) assert query.startswith("SELECT")