From 1cc59d3707e4dd3a0b7fe160d09b6476143d48cd Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Tue, 11 Jun 2024 18:54:10 +0200 Subject: [PATCH] Record ownership --- src/hellocomputer/routers/analysis.py | 4 +- src/hellocomputer/routers/files.py | 4 +- src/hellocomputer/routers/sessions.py | 16 ++++++-- .../{analytics.py => sessions.py} | 2 +- src/hellocomputer/users.py | 41 +++++++++++++++++++ test/test_data.py | 12 +++--- test/test_query.py | 4 +- test/test_user.py | 11 ++++- 8 files changed, 77 insertions(+), 17 deletions(-) rename src/hellocomputer/{analytics.py => sessions.py} (99%) diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 8d07096..08c763c 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,7 +1,7 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse -from hellocomputer.analytics import AnalyticsDB +from hellocomputer.sessions import SessionDB from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block @@ -14,7 +14,7 @@ router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) async def query(sid: str = "", q: str = "") -> str: chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = AnalyticsDB( + db = SessionDB( StorageEngines.gcs, gcs_access=settings.gcs_access, gcs_secret=settings.gcs_secret, diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index fcac0b4..74542b3 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -4,7 +4,7 @@ import aiofiles from fastapi import APIRouter, File, UploadFile from fastapi.responses import JSONResponse -from ..analytics import AnalyticsDB +from ..sessions import SessionDB from ..config import settings from ..db import StorageEngines @@ -28,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""): await f.flush() ( - AnalyticsDB( + SessionDB( StorageEngines.gcs, gcs_access=settings.gcs_access, gcs_secret=settings.gcs_secret, diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index bee8fd1..98ec2c7 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -3,6 +3,9 @@ from uuid import uuid4 from fastapi import APIRouter from fastapi.responses import PlainTextResponse from starlette.requests import Request +from hellocomputer.users import OwnershipDB +from hellocomputer.db import StorageEngines +from ..config import settings # Scheme for the Authorization header @@ -11,9 +14,16 @@ router = APIRouter() @router.get("/new_session") async def get_new_session(request: Request) -> str: - user = request.session.get("user") - print(user) - return str(uuid4()) + user_email = request.session.get("user").get("email") + ownership = OwnershipDB( + StorageEngines.gcs, + gcs_access=settings.gcs_access, + gcs_secret=settings.gcs_secret, + bucket=settings.gcs_bucketname, + ) + sid = str(uuid4()) + + return ownership.set_ownersip(user_email, sid) @router.get("/greetings", response_class=PlainTextResponse) diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/sessions.py similarity index 99% rename from src/hellocomputer/analytics.py rename to src/hellocomputer/sessions.py index d09bf20..6e53c2c 100644 --- a/src/hellocomputer/analytics.py +++ b/src/hellocomputer/sessions.py @@ -9,7 +9,7 @@ from hellocomputer.db import StorageEngines from .db import DDB -class AnalyticsDB(DDB): +class SessionDB(DDB): def __init__( self, storage_engine: StorageEngines, diff --git a/src/hellocomputer/users.py b/src/hellocomputer/users.py index 0c2142c..f64d32d 100644 --- a/src/hellocomputer/users.py +++ b/src/hellocomputer/users.py @@ -5,6 +5,7 @@ from uuid import UUID, uuid4 import duckdb import polars as pl +from datetime import datetime from .db import DDB, StorageEngines @@ -47,3 +48,43 @@ class UserDB(DDB): @staticmethod def email(record: str) -> str: return json.loads(record)["email"] + + +class OwnershipDB(DDB): + def __init__( + self, + storage_engine: StorageEngines, + path: Path | None = None, + gcs_access: str | None = None, + gcs_secret: str | None = None, + bucket: str | None = None, + **kwargs, + ): + super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + + if storage_engine == StorageEngines.gcs: + self.path_prefix = f"gcs://{bucket}/owners" + + elif storage_engine == StorageEngines.local: + self.path_prefix = path / "owners" + + def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None): + now = datetime.now().isoformat() + record_id = uuid4() if record_id is None else record_id + query = f""" + COPY + ( + SELECT + '{user_email}' as email, + '{sid}' as sid, + '{now}' as timestamp + ) + TO '{self.path_prefix}/{record_id}.csv' (FORMAT JSON)""" + + try: + self.db.sql(query) + except duckdb.duckdb.IOException: + os.makedirs(self.path_prefix) + self.db.sql(query) + + return sid diff --git a/test/test_data.py b/test/test_data.py index 926b6da..56538c8 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,7 +1,7 @@ from pathlib import Path import hellocomputer -from hellocomputer.analytics import AnalyticsDB +from hellocomputer.sessions import SessionDB from hellocomputer.db import StorageEngines TEST_STORAGE = StorageEngines.local @@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" def test_0_dump(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test") + db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test") db.load_xls(TEST_XLS_PATH).dump() assert db.sheets == ("answers",) @@ -23,7 +23,7 @@ def test_0_dump(): def test_load(): - db = AnalyticsDB( + db = SessionDB( storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" ).load_folder() results = db.query("select * from answers").fetchall() @@ -31,7 +31,7 @@ def test_load(): def test_load_description(): - db = AnalyticsDB( + db = SessionDB( storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" ).load_folder() file_description = db.load_description() @@ -39,7 +39,7 @@ def test_load_description(): def test_schema(): - db = AnalyticsDB( + db = SessionDB( storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" ).load_folder() schema = [] @@ -50,7 +50,7 @@ def test_schema(): def test_query_prompt(): - db = AnalyticsDB( + db = SessionDB( storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" ).load_folder() diff --git a/test/test_query.py b/test/test_query.py index 7064ebe..e65b9f4 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -2,7 +2,7 @@ from pathlib import Path import hellocomputer import pytest -from hellocomputer.analytics import AnalyticsDB +from hellocomputer.sessions import SessionDB from hellocomputer.config import settings from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block @@ -34,7 +34,7 @@ async def test_simple_data_query(): query = "write a query that finds the average score of all students in the current database" chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = AnalyticsDB( + db = SessionDB( storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent ).load_xls(TEST_XLS_PATH) diff --git a/test/test_user.py b/test/test_user.py index 6e325ae..8acf1f6 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -2,7 +2,7 @@ from pathlib import Path import hellocomputer from hellocomputer.db import StorageEngines -from hellocomputer.users import UserDB +from hellocomputer.users import UserDB, OwnershipDB TEST_STORAGE = StorageEngines.local TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" @@ -23,3 +23,12 @@ def test_user_exists(): assert user.user_exists("[email protected]") assert not user.user_exists("notpresent") + + +def test_assign_owner(): + assert ( + OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip( + "something.something@something", "1234", "test" + ) + == "1234" + )