diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/analytics.py index 6074c3c..0062b19 100644 --- a/src/hellocomputer/analytics.py +++ b/src/hellocomputer/analytics.py @@ -1,65 +1,11 @@ import os -from enum import StrEnum from pathlib import Path -import duckdb from typing_extensions import Self +from .db import DDB -class StorageEngines(StrEnum): - local = "Local" - gcs = "GCS" - - -class DDB: - def __init__( - self, - storage_engine: StorageEngines, - sid: str | None = None, - path: Path | None = None, - gcs_access: str | None = None, - gcs_secret: str | None = None, - bucket: str | None = None, - **kwargs, - ): - self.db = duckdb.connect() - self.db.install_extension("spatial") - self.db.install_extension("httpfs") - self.db.load_extension("spatial") - self.db.load_extension("httpfs") - self.sheets = tuple() - self.loaded = False - - if storage_engine == StorageEngines.gcs: - if all( - ( - gcs_access is not None, - gcs_secret is not None, - bucket is not None, - sid is not None, - ) - ): - self.db.sql(f""" - CREATE SECRET ( - TYPE GCS, - KEY_ID '{gcs_access}', - SECRET '{gcs_secret}') - """) - self.path_prefix = f"gcs://{bucket}/sessions/{sid}" - else: - raise ValueError( - "With GCS storage engine you need to provide " - "the gcs_access, gcs_secret, sid, and bucket keyword arguments" - ) - - elif storage_engine == StorageEngines.local: - if path is not None: - self.path_prefix = path - else: - raise ValueError( - "With local storage you need to provide the path keyword argument" - ) - +class AnalyticsDB(DDB): def load_xls(self, xls_path: Path) -> Self: """For some reason, the header is not loaded""" self.db.sql(f""" diff --git a/src/hellocomputer/db.py b/src/hellocomputer/db.py new file mode 100644 index 0000000..49bc9c3 --- /dev/null +++ b/src/hellocomputer/db.py @@ -0,0 +1,58 @@ +from enum import StrEnum +import duckdb +from pathlib import Path + + +class StorageEngines(StrEnum): + local = "Local" + gcs = "GCS" + + +class DDB: + def __init__( + self, + storage_engine: StorageEngines, + sid: str | None = None, + path: Path | None = None, + gcs_access: str | None = None, + gcs_secret: str | None = None, + bucket: str | None = None, + **kwargs, + ): + self.db = duckdb.connect() + self.db.install_extension("spatial") + self.db.install_extension("httpfs") + self.db.load_extension("spatial") + self.db.load_extension("httpfs") + self.sheets = tuple() + self.loaded = False + + if storage_engine == StorageEngines.gcs: + if all( + ( + gcs_access is not None, + gcs_secret is not None, + bucket is not None, + sid is not None, + ) + ): + self.db.sql(f""" + CREATE SECRET ( + TYPE GCS, + KEY_ID '{gcs_access}', + SECRET '{gcs_secret}') + """) + self.path_prefix = f"gcs://{bucket}/sessions/{sid}" + else: + raise ValueError( + "With GCS storage engine you need to provide " + "the gcs_access, gcs_secret, sid, and bucket keyword arguments" + ) + + elif storage_engine == StorageEngines.local: + if path is not None: + self.path_prefix = path + else: + raise ValueError( + "With local storage you need to provide the path keyword argument" + ) diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index e0976b9..0e4e41b 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,7 +1,8 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse -from hellocomputer.analytics import DDB, StorageEngines +from hellocomputer.db import StorageEngines +from hellocomputer.analytics import AnalyticsDB from hellocomputer.extraction import extract_code_block from ..config import settings @@ -13,7 +14,7 @@ router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) async def query(sid: str = "", q: str = "") -> str: chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = DDB( + db = AnalyticsDB( StorageEngines.gcs, gcs_access=settings.gcs_access, gcs_secret=settings.gcs_secret, diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index b114e4b..5766e0d 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -4,7 +4,8 @@ import aiofiles from fastapi import APIRouter, File, UploadFile from fastapi.responses import JSONResponse -from ..analytics import DDB, StorageEngines +from ..db import StorageEngines +from ..analytics import AnalyticsDB from ..config import settings router = APIRouter() @@ -27,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""): await f.flush() ( - DDB( + AnalyticsDB( StorageEngines.gcs, gcs_access=settings.gcs_access, gcs_secret=settings.gcs_secret, diff --git a/src/hellocomputer/users.py b/src/hellocomputer/users.py new file mode 100644 index 0000000..96038f0 --- /dev/null +++ b/src/hellocomputer/users.py @@ -0,0 +1,2 @@ +class UserManagement: + pass diff --git a/test/test_data.py b/test/test_data.py index 2b8d8fa..337807d 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,7 +1,8 @@ from pathlib import Path import hellocomputer -from hellocomputer.analytics import DDB, StorageEngines +from hellocomputer.db import StorageEngines +from hellocomputer.analytics import AnalyticsDB TEST_STORAGE = StorageEngines.local TEST_XLS_PATH = ( @@ -14,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" def test_0_dump(): - db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) db.load_xls(TEST_XLS_PATH).dump() assert db.sheets == ("answers",) @@ -22,19 +23,19 @@ def test_0_dump(): def test_load(): - db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() results = db.query("select * from answers").fetchall() assert len(results) == 6 def test_load_description(): - db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() file_description = db.load_description() assert file_description.startswith("answers") def test_schema(): - db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() schema = [] for sheet in db.sheets: schema.append(db.table_schema(sheet)) @@ -43,7 +44,7 @@ def test_schema(): def test_query_prompt(): - db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() assert db.query_prompt("Find the average score of all students").startswith( "The following sentence" diff --git a/test/test_query.py b/test/test_query.py index 3995d2c..35e0d34 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -2,7 +2,8 @@ from pathlib import Path import hellocomputer import pytest -from hellocomputer.analytics import DDB, StorageEngines +from hellocomputer.db import StorageEngines +from hellocomputer.analytics import AnalyticsDB from hellocomputer.config import settings from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat @@ -33,9 +34,9 @@ async def test_simple_data_query(): query = "write a query that finds the average score of all students in the current database" chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls( - TEST_XLS_PATH - ) + db = AnalyticsDB( + storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent + ).load_xls(TEST_XLS_PATH) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) query = extract_code_block(chat.last_response_content())