Refactored db as well.
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-06-11 14:19:51 +02:00
parent 04351888a8
commit 56dc012e23
7 changed files with 80 additions and 70 deletions

View file

@ -1,65 +1,11 @@
import os
from enum import StrEnum
from pathlib import Path
import duckdb
from typing_extensions import Self
from .db import DDB
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.loaded = False
if storage_engine == StorageEngines.gcs:
if all(
(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
)
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
class AnalyticsDB(DDB):
def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded"""
self.db.sql(f"""

58
src/hellocomputer/db.py Normal file
View file

@ -0,0 +1,58 @@
from enum import StrEnum
import duckdb
from pathlib import Path
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.loaded = False
if storage_engine == StorageEngines.gcs:
if all(
(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
)
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)

View file

@ -1,7 +1,8 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.extraction import extract_code_block
from ..config import settings
@ -13,7 +14,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(
db = AnalyticsDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,

View file

@ -4,7 +4,8 @@ import aiofiles
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines
from ..db import StorageEngines
from ..analytics import AnalyticsDB
from ..config import settings
router = APIRouter()
@ -27,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush()
(
DDB(
AnalyticsDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,

View file

@ -0,0 +1,2 @@
class UserManagement:
pass

View file

@ -1,7 +1,8 @@
from pathlib import Path
import hellocomputer
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
@ -14,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",)
@ -22,19 +23,19 @@ def test_0_dump():
def test_load():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
results = db.query("select * from answers").fetchall()
assert len(results) == 6
def test_load_description():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
file_description = db.load_description()
assert file_description.startswith("answers")
def test_schema():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
schema = []
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
@ -43,7 +44,7 @@ def test_schema():
def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"

View file

@ -2,7 +2,8 @@ from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
@ -33,9 +34,9 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
TEST_XLS_PATH
)
db = AnalyticsDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content())