Record ownership
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-06-11 18:54:10 +02:00
parent 73ee66db44
commit 1cc59d3707
8 changed files with 77 additions and 17 deletions

View file

@ -1,7 +1,7 @@
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from hellocomputer.analytics import AnalyticsDB from hellocomputer.sessions import SessionDB
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
@ -14,7 +14,7 @@ router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = AnalyticsDB( db = SessionDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret, gcs_secret=settings.gcs_secret,

View file

@ -4,7 +4,7 @@ import aiofiles
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..analytics import AnalyticsDB from ..sessions import SessionDB
from ..config import settings from ..config import settings
from ..db import StorageEngines from ..db import StorageEngines
@ -28,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush() await f.flush()
( (
AnalyticsDB( SessionDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret, gcs_secret=settings.gcs_secret,

View file

@ -3,6 +3,9 @@ from uuid import uuid4
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from starlette.requests import Request from starlette.requests import Request
from hellocomputer.users import OwnershipDB
from hellocomputer.db import StorageEngines
from ..config import settings
# Scheme for the Authorization header # Scheme for the Authorization header
@ -11,9 +14,16 @@ router = APIRouter()
@router.get("/new_session")
async def get_new_session(request: Request) -> str:
    """Create a new session id and record the logged-in user as its owner.

    The user's email is read from the ``"user"`` entry of the request
    session, a fresh UUID4 string is generated as the session id, and the
    (email, session id) pair is stored through ``OwnershipDB``.

    Returns:
        The newly created session id.

    Raises:
        HTTPException: 401 when the request session carries no user or the
            user entry has no email, so no ownership can be recorded.
    """
    # Local import: only ``APIRouter`` is imported from fastapi at file level.
    from fastapi import HTTPException

    user = request.session.get("user")
    user_email = user.get("email") if user else None
    if not user_email:
        # Without this guard a missing user surfaces as an AttributeError
        # (HTTP 500) instead of a meaningful authentication error.
        raise HTTPException(status_code=401, detail="Not authenticated")

    ownership = OwnershipDB(
        StorageEngines.gcs,
        gcs_access=settings.gcs_access,
        gcs_secret=settings.gcs_secret,
        bucket=settings.gcs_bucketname,
    )
    sid = str(uuid4())

    # set_ownersip returns the sid it was given once the record is written.
    return ownership.set_ownersip(user_email, sid)
@router.get("/greetings", response_class=PlainTextResponse) @router.get("/greetings", response_class=PlainTextResponse)

View file

@ -9,7 +9,7 @@ from hellocomputer.db import StorageEngines
from .db import DDB from .db import DDB
class AnalyticsDB(DDB): class SessionDB(DDB):
def __init__( def __init__(
self, self,
storage_engine: StorageEngines, storage_engine: StorageEngines,

View file

@ -5,6 +5,7 @@ from uuid import UUID, uuid4
import duckdb import duckdb
import polars as pl import polars as pl
from datetime import datetime
from .db import DDB, StorageEngines from .db import DDB, StorageEngines
@ -47,3 +48,43 @@ class UserDB(DDB):
@staticmethod @staticmethod
def email(record: str) -> str: def email(record: str) -> str:
return json.loads(record)["email"] return json.loads(record)["email"]
class OwnershipDB(DDB):
    """Record which user owns which session, one record file per assignment.

    Records are written below ``gcs://<bucket>/owners`` for the GCS storage
    engine, or below ``<path>/owners`` for the local storage engine.
    """

    def __init__(
        self,
        storage_engine: StorageEngines,
        path: Path | None = None,
        gcs_access: str | None = None,
        gcs_secret: str | None = None,
        bucket: str | None = None,
        **kwargs,
    ):
        """Initialise the underlying database and pick the record location.

        Args:
            storage_engine: Where the ownership records are stored.
            path: Base folder, used by the local storage engine.
            gcs_access: GCS access key, forwarded to the base class.
            gcs_secret: GCS secret, forwarded to the base class.
            bucket: GCS bucket name, used by the GCS storage engine.
            **kwargs: Forwarded unchanged to the base class.
        """
        super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)

        if storage_engine == StorageEngines.gcs:
            self.path_prefix = f"gcs://{bucket}/owners"
        elif storage_engine == StorageEngines.local:
            self.path_prefix = path / "owners"

    @staticmethod
    def _sql_literal(value) -> str:
        """Render *value* as a single-quoted SQL string literal.

        Embedded single quotes are doubled so the value cannot terminate the
        literal early and alter the surrounding statement.
        """
        return "'" + str(value).replace("'", "''") + "'"

    def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
        """Persist that *user_email* owns session *sid*.

        Args:
            user_email: Email of the owning user.
            sid: Session identifier being assigned.
            record_id: Name of the record file; a random UUID4 when omitted.

        Returns:
            The *sid* that was passed in.

        Raises:
            duckdb.IOException: When the record cannot be written and the
                failure cannot be fixed by creating the local folder.
        """
        now = datetime.now().isoformat()
        record_id = uuid4() if record_id is None else record_id

        # NOTE(review): the file keeps its historical ``.csv`` suffix although
        # the content is written as JSON; readers may depend on that name.
        target = f"{self.path_prefix}/{record_id}.csv"

        literal = self._sql_literal
        query = f"""
            COPY
                (
                    SELECT
                        {literal(user_email)} as email,
                        {literal(sid)} as sid,
                        {literal(now)} as timestamp
                )
            TO {literal(target)} (FORMAT JSON)"""

        try:
            self.db.sql(query)
        except duckdb.IOException:
            # Only a local target can be repaired by creating its folder; for
            # a remote prefix makedirs would just leave a stray local
            # directory behind, so let the original error propagate.
            if not isinstance(self.path_prefix, os.PathLike):
                raise
            os.makedirs(self.path_prefix, exist_ok=True)
            self.db.sql(query)

        return sid

    # Correctly spelled alias; the misspelled name stays for existing callers.
    set_ownership = set_ownersip

View file

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.analytics import AnalyticsDB from hellocomputer.sessions import SessionDB
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump(): def test_0_dump():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test") db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
db.load_xls(TEST_XLS_PATH).dump() db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",) assert db.sheets == ("answers",)
@ -23,7 +23,7 @@ def test_0_dump():
def test_load(): def test_load():
db = AnalyticsDB( db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder() ).load_folder()
results = db.query("select * from answers").fetchall() results = db.query("select * from answers").fetchall()
@ -31,7 +31,7 @@ def test_load():
def test_load_description(): def test_load_description():
db = AnalyticsDB( db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder() ).load_folder()
file_description = db.load_description() file_description = db.load_description()
@ -39,7 +39,7 @@ def test_load_description():
def test_schema(): def test_schema():
db = AnalyticsDB( db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder() ).load_folder()
schema = [] schema = []
@ -50,7 +50,7 @@ def test_schema():
def test_query_prompt(): def test_query_prompt():
db = AnalyticsDB( db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder() ).load_folder()

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer import hellocomputer
import pytest import pytest
from hellocomputer.analytics import AnalyticsDB from hellocomputer.sessions import SessionDB
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
@ -34,7 +34,7 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database" query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = AnalyticsDB( db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
).load_xls(TEST_XLS_PATH) ).load_xls(TEST_XLS_PATH)

View file

@ -2,7 +2,7 @@ from pathlib import Path
import hellocomputer import hellocomputer
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.users import UserDB from hellocomputer.users import UserDB, OwnershipDB
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
@ -23,3 +23,12 @@ def test_user_exists():
assert user.user_exists("[email protected]") assert user.user_exists("[email protected]")
assert not user.user_exists("notpresent") assert not user.user_exists("notpresent")
def test_assign_owner():
    """Assigning a session to a user hands back the same session id."""
    owners = OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)

    # A fixed record id ("test") is passed as the third argument.
    returned_sid = owners.set_ownersip(
        "something.something@something", "1234", "test"
    )

    assert returned_sid == "1234"