This commit is contained in:
parent
73ee66db44
commit
1cc59d3707
|
@ -1,7 +1,7 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.sessions import SessionDB
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ router = APIRouter()
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
async def query(sid: str = "", q: str = "") -> str:
|
async def query(sid: str = "", q: str = "") -> str:
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
StorageEngines.gcs,
|
StorageEngines.gcs,
|
||||||
gcs_access=settings.gcs_access,
|
gcs_access=settings.gcs_access,
|
||||||
gcs_secret=settings.gcs_secret,
|
gcs_secret=settings.gcs_secret,
|
||||||
|
|
|
@ -4,7 +4,7 @@ import aiofiles
|
||||||
from fastapi import APIRouter, File, UploadFile
|
from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..analytics import AnalyticsDB
|
from ..sessions import SessionDB
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
from ..db import StorageEngines
|
from ..db import StorageEngines
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
||||||
await f.flush()
|
await f.flush()
|
||||||
|
|
||||||
(
|
(
|
||||||
AnalyticsDB(
|
SessionDB(
|
||||||
StorageEngines.gcs,
|
StorageEngines.gcs,
|
||||||
gcs_access=settings.gcs_access,
|
gcs_access=settings.gcs_access,
|
||||||
gcs_secret=settings.gcs_secret,
|
gcs_secret=settings.gcs_secret,
|
||||||
|
|
|
@ -3,6 +3,9 @@ from uuid import uuid4
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from hellocomputer.users import OwnershipDB
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
|
from ..config import settings
|
||||||
|
|
||||||
# Scheme for the Authorization header
|
# Scheme for the Authorization header
|
||||||
|
|
||||||
|
@ -11,9 +14,16 @@ router = APIRouter()
|
||||||
|
|
||||||
@router.get("/new_session")
|
@router.get("/new_session")
|
||||||
async def get_new_session(request: Request) -> str:
|
async def get_new_session(request: Request) -> str:
|
||||||
user = request.session.get("user")
|
user_email = request.session.get("user").get("email")
|
||||||
print(user)
|
ownership = OwnershipDB(
|
||||||
return str(uuid4())
|
StorageEngines.gcs,
|
||||||
|
gcs_access=settings.gcs_access,
|
||||||
|
gcs_secret=settings.gcs_secret,
|
||||||
|
bucket=settings.gcs_bucketname,
|
||||||
|
)
|
||||||
|
sid = str(uuid4())
|
||||||
|
|
||||||
|
return ownership.set_ownersip(user_email, sid)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/greetings", response_class=PlainTextResponse)
|
@router.get("/greetings", response_class=PlainTextResponse)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from hellocomputer.db import StorageEngines
|
||||||
from .db import DDB
|
from .db import DDB
|
||||||
|
|
||||||
|
|
||||||
class AnalyticsDB(DDB):
|
class SessionDB(DDB):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage_engine: StorageEngines,
|
storage_engine: StorageEngines,
|
|
@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from .db import DDB, StorageEngines
|
from .db import DDB, StorageEngines
|
||||||
|
|
||||||
|
@ -47,3 +48,43 @@ class UserDB(DDB):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def email(record: str) -> str:
|
def email(record: str) -> str:
|
||||||
return json.loads(record)["email"]
|
return json.loads(record)["email"]
|
||||||
|
|
||||||
|
|
||||||
|
class OwnershipDB(DDB):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||||
|
|
||||||
|
if storage_engine == StorageEngines.gcs:
|
||||||
|
self.path_prefix = f"gcs://{bucket}/owners"
|
||||||
|
|
||||||
|
elif storage_engine == StorageEngines.local:
|
||||||
|
self.path_prefix = path / "owners"
|
||||||
|
|
||||||
|
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
||||||
|
now = datetime.now().isoformat()
|
||||||
|
record_id = uuid4() if record_id is None else record_id
|
||||||
|
query = f"""
|
||||||
|
COPY
|
||||||
|
(
|
||||||
|
SELECT
|
||||||
|
'{user_email}' as email,
|
||||||
|
'{sid}' as sid,
|
||||||
|
'{now}' as timestamp
|
||||||
|
)
|
||||||
|
TO '{self.path_prefix}/{record_id}.csv' (FORMAT JSON)"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.db.sql(query)
|
||||||
|
except duckdb.duckdb.IOException:
|
||||||
|
os.makedirs(self.path_prefix)
|
||||||
|
self.db.sql(query)
|
||||||
|
|
||||||
|
return sid
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.sessions import SessionDB
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
TEST_STORAGE = StorageEngines.local
|
||||||
|
@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
|
||||||
|
|
||||||
def test_0_dump():
|
def test_0_dump():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||||
db.load_xls(TEST_XLS_PATH).dump()
|
db.load_xls(TEST_XLS_PATH).dump()
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
assert db.sheets == ("answers",)
|
||||||
|
@ -23,7 +23,7 @@ def test_0_dump():
|
||||||
|
|
||||||
|
|
||||||
def test_load():
|
def test_load():
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
).load_folder()
|
).load_folder()
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
|
@ -31,7 +31,7 @@ def test_load():
|
||||||
|
|
||||||
|
|
||||||
def test_load_description():
|
def test_load_description():
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
).load_folder()
|
).load_folder()
|
||||||
file_description = db.load_description()
|
file_description = db.load_description()
|
||||||
|
@ -39,7 +39,7 @@ def test_load_description():
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
).load_folder()
|
).load_folder()
|
||||||
schema = []
|
schema = []
|
||||||
|
@ -50,7 +50,7 @@ def test_schema():
|
||||||
|
|
||||||
|
|
||||||
def test_query_prompt():
|
def test_query_prompt():
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
).load_folder()
|
).load_folder()
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.sessions import SessionDB
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
@ -34,7 +34,7 @@ async def test_simple_data_query():
|
||||||
query = "write a query that finds the average score of all students in the current database"
|
query = "write a query that finds the average score of all students in the current database"
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = AnalyticsDB(
|
db = SessionDB(
|
||||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent
|
||||||
).load_xls(TEST_XLS_PATH)
|
).load_xls(TEST_XLS_PATH)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.users import UserDB
|
from hellocomputer.users import UserDB, OwnershipDB
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
TEST_STORAGE = StorageEngines.local
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
@ -23,3 +23,12 @@ def test_user_exists():
|
||||||
|
|
||||||
assert user.user_exists("[email protected]")
|
assert user.user_exists("[email protected]")
|
||||||
assert not user.user_exists("notpresent")
|
assert not user.user_exists("notpresent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_assign_owner():
|
||||||
|
assert (
|
||||||
|
OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip(
|
||||||
|
"something.something@something", "1234", "test"
|
||||||
|
)
|
||||||
|
== "1234"
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in a new issue