This commit is contained in:
parent
73ee66db44
commit
1cc59d3707
|
@ -1,7 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.sessions import SessionDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
||||
|
@ -14,7 +14,7 @@ router = APIRouter()
|
|||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||
async def query(sid: str = "", q: str = "") -> str:
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
|
|
|
@ -4,7 +4,7 @@ import aiofiles
|
|||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..analytics import AnalyticsDB
|
||||
from ..sessions import SessionDB
|
||||
from ..config import settings
|
||||
from ..db import StorageEngines
|
||||
|
||||
|
@ -28,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
|||
await f.flush()
|
||||
|
||||
(
|
||||
AnalyticsDB(
|
||||
SessionDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
|
|
|
@ -3,6 +3,9 @@ from uuid import uuid4
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.requests import Request
|
||||
from hellocomputer.users import OwnershipDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
from ..config import settings
|
||||
|
||||
# Scheme for the Authorization header
|
||||
|
||||
|
@ -11,9 +14,16 @@ router = APIRouter()
|
|||
|
||||
@router.get("/new_session")
|
||||
async def get_new_session(request: Request) -> str:
|
||||
user = request.session.get("user")
|
||||
print(user)
|
||||
return str(uuid4())
|
||||
user_email = request.session.get("user").get("email")
|
||||
ownership = OwnershipDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
)
|
||||
sid = str(uuid4())
|
||||
|
||||
return ownership.set_ownersip(user_email, sid)
|
||||
|
||||
|
||||
@router.get("/greetings", response_class=PlainTextResponse)
|
||||
|
|
|
@ -9,7 +9,7 @@ from hellocomputer.db import StorageEngines
|
|||
from .db import DDB
|
||||
|
||||
|
||||
class AnalyticsDB(DDB):
|
||||
class SessionDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
|
@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
|||
|
||||
import duckdb
|
||||
import polars as pl
|
||||
from datetime import datetime
|
||||
|
||||
from .db import DDB, StorageEngines
|
||||
|
||||
|
@ -47,3 +48,43 @@ class UserDB(DDB):
|
|||
@staticmethod
|
||||
def email(record: str) -> str:
|
||||
return json.loads(record)["email"]
|
||||
|
||||
|
||||
class OwnershipDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gcs://{bucket}/owners"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "owners"
|
||||
|
||||
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
||||
now = datetime.now().isoformat()
|
||||
record_id = uuid4() if record_id is None else record_id
|
||||
query = f"""
|
||||
COPY
|
||||
(
|
||||
SELECT
|
||||
'{user_email}' as email,
|
||||
'{sid}' as sid,
|
||||
'{now}' as timestamp
|
||||
)
|
||||
TO '{self.path_prefix}/{record_id}.csv' (FORMAT JSON)"""
|
||||
|
||||
try:
|
||||
self.db.sql(query)
|
||||
except duckdb.duckdb.IOException:
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.sql(query)
|
||||
|
||||
return sid
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.sessions import SessionDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
|
@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
|||
|
||||
|
||||
def test_0_dump():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||
db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||
db.load_xls(TEST_XLS_PATH).dump()
|
||||
|
||||
assert db.sheets == ("answers",)
|
||||
|
@ -23,7 +23,7 @@ def test_0_dump():
|
|||
|
||||
|
||||
def test_load():
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
results = db.query("select * from answers").fetchall()
|
||||
|
@ -31,7 +31,7 @@ def test_load():
|
|||
|
||||
|
||||
def test_load_description():
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
file_description = db.load_description()
|
||||
|
@ -39,7 +39,7 @@ def test_load_description():
|
|||
|
||||
|
||||
def test_schema():
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
schema = []
|
||||
|
@ -50,7 +50,7 @@ def test_schema():
|
|||
|
||||
|
||||
def test_query_prompt():
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
import pytest
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.sessions import SessionDB
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
@ -34,7 +34,7 @@ async def test_simple_data_query():
|
|||
query = "write a query that finds the average score of all students in the current database"
|
||||
|
||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||
db = AnalyticsDB(
|
||||
db = SessionDB(
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
||||
).load_xls(TEST_XLS_PATH)
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.users import UserDB
|
||||
from hellocomputer.users import UserDB, OwnershipDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
|
@ -23,3 +23,12 @@ def test_user_exists():
|
|||
|
||||
assert user.user_exists("[email protected]")
|
||||
assert not user.user_exists("notpresent")
|
||||
|
||||
|
||||
def test_assign_owner():
|
||||
assert (
|
||||
OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip(
|
||||
"something.something@something", "1234", "test"
|
||||
)
|
||||
== "1234"
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue