This commit is contained in:
parent
04351888a8
commit
56dc012e23
|
@ -1,65 +1,11 @@
|
||||||
import os
|
import os
|
||||||
from enum import StrEnum
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import duckdb
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
from .db import DDB
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
class AnalyticsDB(DDB):
|
||||||
local = "Local"
|
|
||||||
gcs = "GCS"
|
|
||||||
|
|
||||||
|
|
||||||
class DDB:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
storage_engine: StorageEngines,
|
|
||||||
sid: str | None = None,
|
|
||||||
path: Path | None = None,
|
|
||||||
gcs_access: str | None = None,
|
|
||||||
gcs_secret: str | None = None,
|
|
||||||
bucket: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.db = duckdb.connect()
|
|
||||||
self.db.install_extension("spatial")
|
|
||||||
self.db.install_extension("httpfs")
|
|
||||||
self.db.load_extension("spatial")
|
|
||||||
self.db.load_extension("httpfs")
|
|
||||||
self.sheets = tuple()
|
|
||||||
self.loaded = False
|
|
||||||
|
|
||||||
if storage_engine == StorageEngines.gcs:
|
|
||||||
if all(
|
|
||||||
(
|
|
||||||
gcs_access is not None,
|
|
||||||
gcs_secret is not None,
|
|
||||||
bucket is not None,
|
|
||||||
sid is not None,
|
|
||||||
)
|
|
||||||
):
|
|
||||||
self.db.sql(f"""
|
|
||||||
CREATE SECRET (
|
|
||||||
TYPE GCS,
|
|
||||||
KEY_ID '{gcs_access}',
|
|
||||||
SECRET '{gcs_secret}')
|
|
||||||
""")
|
|
||||||
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"With GCS storage engine you need to provide "
|
|
||||||
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
|
||||||
if path is not None:
|
|
||||||
self.path_prefix = path
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"With local storage you need to provide the path keyword argument"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_xls(self, xls_path: Path) -> Self:
|
def load_xls(self, xls_path: Path) -> Self:
|
||||||
"""For some reason, the header is not loaded"""
|
"""For some reason, the header is not loaded"""
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
|
|
58
src/hellocomputer/db.py
Normal file
58
src/hellocomputer/db.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
from enum import StrEnum
|
||||||
|
import duckdb
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class StorageEngines(StrEnum):
|
||||||
|
local = "Local"
|
||||||
|
gcs = "GCS"
|
||||||
|
|
||||||
|
|
||||||
|
class DDB:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
sid: str | None = None,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.db = duckdb.connect()
|
||||||
|
self.db.install_extension("spatial")
|
||||||
|
self.db.install_extension("httpfs")
|
||||||
|
self.db.load_extension("spatial")
|
||||||
|
self.db.load_extension("httpfs")
|
||||||
|
self.sheets = tuple()
|
||||||
|
self.loaded = False
|
||||||
|
|
||||||
|
if storage_engine == StorageEngines.gcs:
|
||||||
|
if all(
|
||||||
|
(
|
||||||
|
gcs_access is not None,
|
||||||
|
gcs_secret is not None,
|
||||||
|
bucket is not None,
|
||||||
|
sid is not None,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
self.db.sql(f"""
|
||||||
|
CREATE SECRET (
|
||||||
|
TYPE GCS,
|
||||||
|
KEY_ID '{gcs_access}',
|
||||||
|
SECRET '{gcs_secret}')
|
||||||
|
""")
|
||||||
|
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"With GCS storage engine you need to provide "
|
||||||
|
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif storage_engine == StorageEngines.local:
|
||||||
|
if path is not None:
|
||||||
|
self.path_prefix = path
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"With local storage you need to provide the path keyword argument"
|
||||||
|
)
|
|
@ -1,7 +1,8 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from hellocomputer.analytics import DDB, StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
@ -13,7 +14,7 @@ router = APIRouter()
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
async def query(sid: str = "", q: str = "") -> str:
|
async def query(sid: str = "", q: str = "") -> str:
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = DDB(
|
db = AnalyticsDB(
|
||||||
StorageEngines.gcs,
|
StorageEngines.gcs,
|
||||||
gcs_access=settings.gcs_access,
|
gcs_access=settings.gcs_access,
|
||||||
gcs_secret=settings.gcs_secret,
|
gcs_secret=settings.gcs_secret,
|
||||||
|
|
|
@ -4,7 +4,8 @@ import aiofiles
|
||||||
from fastapi import APIRouter, File, UploadFile
|
from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..analytics import DDB, StorageEngines
|
from ..db import StorageEngines
|
||||||
|
from ..analytics import AnalyticsDB
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -27,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
||||||
await f.flush()
|
await f.flush()
|
||||||
|
|
||||||
(
|
(
|
||||||
DDB(
|
AnalyticsDB(
|
||||||
StorageEngines.gcs,
|
StorageEngines.gcs,
|
||||||
gcs_access=settings.gcs_access,
|
gcs_access=settings.gcs_access,
|
||||||
gcs_secret=settings.gcs_secret,
|
gcs_secret=settings.gcs_secret,
|
||||||
|
|
2
src/hellocomputer/users.py
Normal file
2
src/hellocomputer/users.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
class UserManagement:
|
||||||
|
pass
|
|
@ -1,7 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.analytics import DDB, StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
TEST_STORAGE = StorageEngines.local
|
||||||
TEST_XLS_PATH = (
|
TEST_XLS_PATH = (
|
||||||
|
@ -14,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
|
||||||
|
|
||||||
def test_0_dump():
|
def test_0_dump():
|
||||||
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||||
db.load_xls(TEST_XLS_PATH).dump()
|
db.load_xls(TEST_XLS_PATH).dump()
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
assert db.sheets == ("answers",)
|
||||||
|
@ -22,19 +23,19 @@ def test_0_dump():
|
||||||
|
|
||||||
|
|
||||||
def test_load():
|
def test_load():
|
||||||
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
assert len(results) == 6
|
assert len(results) == 6
|
||||||
|
|
||||||
|
|
||||||
def test_load_description():
|
def test_load_description():
|
||||||
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
file_description = db.load_description()
|
file_description = db.load_description()
|
||||||
assert file_description.startswith("answers")
|
assert file_description.startswith("answers")
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
schema = []
|
schema = []
|
||||||
for sheet in db.sheets:
|
for sheet in db.sheets:
|
||||||
schema.append(db.table_schema(sheet))
|
schema.append(db.table_schema(sheet))
|
||||||
|
@ -43,7 +44,7 @@ def test_schema():
|
||||||
|
|
||||||
|
|
||||||
def test_query_prompt():
|
def test_query_prompt():
|
||||||
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
|
|
||||||
assert db.query_prompt("Find the average score of all students").startswith(
|
assert db.query_prompt("Find the average score of all students").startswith(
|
||||||
"The following sentence"
|
"The following sentence"
|
||||||
|
|
|
@ -2,7 +2,8 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.analytics import DDB, StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import Chat
|
||||||
|
@ -33,9 +34,9 @@ async def test_simple_data_query():
|
||||||
query = "write a query that finds the average score of all students in the current database"
|
query = "write a query that finds the average score of all students in the current database"
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
|
db = AnalyticsDB(
|
||||||
TEST_XLS_PATH
|
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
||||||
)
|
).load_xls(TEST_XLS_PATH)
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
|
|
Loading…
Reference in a new issue