This commit is contained in:
parent
0c21073e88
commit
d7ba280a2f
|
@ -12,8 +12,16 @@ class StorageEngines(StrEnum):
|
||||||
|
|
||||||
|
|
||||||
class DDB:
|
class DDB:
|
||||||
def __init__(self, storage_engine: StorageEngines, **kwargs):
|
def __init__(
|
||||||
"""Write documentation"""
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
sid: str | None = None,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
self.db = duckdb.connect()
|
self.db = duckdb.connect()
|
||||||
self.db.install_extension("spatial")
|
self.db.install_extension("spatial")
|
||||||
self.db.install_extension("httpfs")
|
self.db.install_extension("httpfs")
|
||||||
|
@ -23,28 +31,28 @@ class DDB:
|
||||||
self.loaded = False
|
self.loaded = False
|
||||||
|
|
||||||
if storage_engine == StorageEngines.gcs:
|
if storage_engine == StorageEngines.gcs:
|
||||||
if (
|
if all(
|
||||||
"gcs_access" in kwargs
|
gcs_access is not None,
|
||||||
and "gcs_secret" in kwargs
|
gcs_secret is not None,
|
||||||
and "bucketname" in kwargs
|
bucket is not None,
|
||||||
and "sid" in kwargs
|
sid is not None,
|
||||||
):
|
):
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
CREATE SECRET (
|
CREATE SECRET (
|
||||||
TYPE GCS,
|
TYPE GCS,
|
||||||
KEY_ID '{kwargs["gcs_access"]}',
|
KEY_ID '{gcs_access}',
|
||||||
SECRET '{kwargs["gcs_secret"]}')
|
SECRET '{gcs_secret}')
|
||||||
""")
|
""")
|
||||||
self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}"
|
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"With GCS storage engine you need to provide "
|
"With GCS storage engine you need to provide "
|
||||||
"the gcs_access, gcs_secret and bucket keyword arguments"
|
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
||||||
)
|
)
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
elif storage_engine == StorageEngines.local:
|
||||||
if "path" in kwargs:
|
if path is not None:
|
||||||
self.path_prefix = kwargs["path"]
|
self.path_prefix = path
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"With local storage you need to provide the path keyword argument"
|
"With local storage you need to provide the path keyword argument"
|
||||||
|
@ -98,13 +106,23 @@ class DDB:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def load_folder(self) -> Self:
|
def load_folder(self) -> Self:
|
||||||
|
self.query(
|
||||||
|
f"""
|
||||||
|
create table metadata as (
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from
|
||||||
|
read_csv_auto('{self.path_prefix}/metadata.csv')
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
self.sheets = tuple(
|
self.sheets = tuple(
|
||||||
self.query(
|
self.query(
|
||||||
f"""
|
"""
|
||||||
select
|
select
|
||||||
Field2
|
Field2
|
||||||
from
|
from
|
||||||
read_csv_auto('{self.path_prefix}/metadata.csv')
|
metadata
|
||||||
where
|
where
|
||||||
Field1 = 'Sheets'
|
Field1 = 'Sheets'
|
||||||
"""
|
"""
|
||||||
|
@ -130,11 +148,11 @@ class DDB:
|
||||||
|
|
||||||
def load_description(self) -> Self:
|
def load_description(self) -> Self:
|
||||||
return self.query(
|
return self.query(
|
||||||
f"""
|
"""
|
||||||
select
|
select
|
||||||
Field2
|
Field2
|
||||||
from
|
from
|
||||||
read_csv_auto('{self.path_prefix}/metadata.csv')
|
metadata
|
||||||
where
|
where
|
||||||
Field1 = 'Description'"""
|
Field1 = 'Description'"""
|
||||||
).fetchall()[0][0]
|
).fetchall()[0][0]
|
||||||
|
|
|
@ -39,8 +39,6 @@ def test_schema():
|
||||||
for sheet in db.sheets:
|
for sheet in db.sheets:
|
||||||
schema.append(db.table_schema(sheet))
|
schema.append(db.table_schema(sheet))
|
||||||
|
|
||||||
print(db.schema)
|
|
||||||
|
|
||||||
assert db.schema.startswith("The schema of the database")
|
assert db.schema.startswith("The schema of the database")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.analytics import DDB
|
from hellocomputer.analytics import DDB, StorageEngines
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import Chat
|
||||||
|
@ -34,17 +34,10 @@ async def test_simple_data_query():
|
||||||
query = "write a query that finds the average score of all students in the current database"
|
query = "write a query that finds the average score of all students in the current database"
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = DDB().load_xls(TEST_XLS_PATH)
|
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
|
||||||
|
TEST_XLS_PATH
|
||||||
prompt = os.linesep.join(
|
|
||||||
[
|
|
||||||
query,
|
|
||||||
db.schema(),
|
|
||||||
db.load_description(),
|
|
||||||
"Return just the SQL statement",
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", prompt)
|
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
assert query.startswith("SELECT")
|
assert query.startswith("SELECT")
|
||||||
|
|
Loading…
Reference in a new issue