Refactored db as well.
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-06-11 14:19:51 +02:00
parent 04351888a8
commit 56dc012e23
7 changed files with 80 additions and 70 deletions

View file

@ -1,65 +1,11 @@
import os import os
from enum import StrEnum
from pathlib import Path from pathlib import Path
import duckdb
from typing_extensions import Self from typing_extensions import Self
from .db import DDB
class StorageEngines(StrEnum): class AnalyticsDB(DDB):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.loaded = False
if storage_engine == StorageEngines.gcs:
if all(
(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
)
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
def load_xls(self, xls_path: Path) -> Self: def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded""" """For some reason, the header is not loaded"""
self.db.sql(f""" self.db.sql(f"""

58
src/hellocomputer/db.py Normal file
View file

@ -0,0 +1,58 @@
from enum import StrEnum
import duckdb
from pathlib import Path
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect()
self.db.install_extension("spatial")
self.db.install_extension("httpfs")
self.db.load_extension("spatial")
self.db.load_extension("httpfs")
self.sheets = tuple()
self.loaded = False
if storage_engine == StorageEngines.gcs:
if all(
(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
)
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)

View file

@ -1,7 +1,8 @@
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from hellocomputer.analytics import DDB, StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from ..config import settings from ..config import settings
@ -13,7 +14,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB( db = AnalyticsDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret, gcs_secret=settings.gcs_secret,

View file

@ -4,7 +4,8 @@ import aiofiles
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines from ..db import StorageEngines
from ..analytics import AnalyticsDB
from ..config import settings from ..config import settings
router = APIRouter() router = APIRouter()
@ -27,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush() await f.flush()
( (
DDB( AnalyticsDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret, gcs_secret=settings.gcs_secret,

View file

@ -0,0 +1,2 @@
class UserManagement:
pass

View file

@ -1,7 +1,8 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.analytics import DDB, StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = ( TEST_XLS_PATH = (
@ -14,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump(): def test_0_dump():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db.load_xls(TEST_XLS_PATH).dump() db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",) assert db.sheets == ("answers",)
@ -22,19 +23,19 @@ def test_0_dump():
def test_load(): def test_load():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
results = db.query("select * from answers").fetchall() results = db.query("select * from answers").fetchall()
assert len(results) == 6 assert len(results) == 6
def test_load_description(): def test_load_description():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
file_description = db.load_description() file_description = db.load_description()
assert file_description.startswith("answers") assert file_description.startswith("answers")
def test_schema(): def test_schema():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
schema = [] schema = []
for sheet in db.sheets: for sheet in db.sheets:
schema.append(db.table_schema(sheet)) schema.append(db.table_schema(sheet))
@ -43,7 +44,7 @@ def test_schema():
def test_query_prompt(): def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith( assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence" "The following sentence"

View file

@ -2,7 +2,8 @@ from pathlib import Path
import hellocomputer import hellocomputer
import pytest import pytest
from hellocomputer.analytics import DDB, StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
@ -33,9 +34,9 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database" query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls( db = AnalyticsDB(
TEST_XLS_PATH storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
) ).load_xls(TEST_XLS_PATH)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())