Pluggable authentication. Still need to fix gcs
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-13 22:46:34 +02:00
parent 7743d93f1d
commit e8755e627c
10 changed files with 43 additions and 15 deletions

6
.gitignore vendored
View file

@ -158,8 +158,12 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear # and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ .idea/
*DS_Store *DS_Store
test/data/output/* test/data/output/*
.pytest_cache
.ruff_cache
*.ipynb

1
notebooks/README.md Normal file
View file

@ -0,0 +1 @@
Placeholder to store some notebooks. Gitignored

16
src/hellocomputer/auth.py Normal file
View file

@ -0,0 +1,16 @@
from starlette.requests import Request
from .config import settings
def get_user(request: Request) -> dict:
if settings.auth:
return request.session.get("user")
else:
return {"email": "test@test.com"}
def get_user_email(request: Request) -> str:
if settings.auth:
return request.session.get("user").get("email")
else:
return "test@test.com"

View file

@ -48,8 +48,9 @@ class DDB:
""" """
) )
) )
conn.execute(text("LOAD httpfs"))
self.path_prefix = f"gcs://{bucket}" self.path_prefix = f"gs://{bucket}"
else: else:
raise ValueError( raise ValueError(
"With GCS storage engine you need to provide " "With GCS storage engine you need to provide "

View file

@ -24,7 +24,7 @@ class SessionDB(DDB):
self.sid = sid self.sid = sid
# Override storage engine for sessions # Override storage engine for sessions
if storage_engine == StorageEngines.gcs: if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gcs://{bucket}/sessions/{sid}" self.path_prefix = f"gs://{bucket}/sessions/{sid}"
elif storage_engine == StorageEngines.local: elif storage_engine == StorageEngines.local:
self.path_prefix = path / "sessions" / sid self.path_prefix = path / "sessions" / sid

View file

@ -24,11 +24,13 @@ class UserDB(DDB):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
if storage_engine == StorageEngines.gcs: if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gcs://{bucket}/users" self.path_prefix = f"gs://{bucket}/users"
elif storage_engine == StorageEngines.local: elif storage_engine == StorageEngines.local:
self.path_prefix = path / "users" self.path_prefix = path / "users"
self.storage_engine = storage_engine
def dump_user_record(self, user_data: dict, record_id: UUID | None = None): def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
df = pl.from_dict(user_data) # noqa df = pl.from_dict(user_data) # noqa
record_id = uuid4() if record_id is None else record_id record_id = uuid4() if record_id is None else record_id
@ -36,9 +38,12 @@ class UserDB(DDB):
try: try:
self.db.sql(query) self.db.sql(query)
except duckdb.duckdb.IOException: except duckdb.duckdb.IOException as e:
os.makedirs(self.path_prefix) if self.storage_engine == StorageEngines.local:
self.db.sql(query) os.makedirs(self.path_prefix)
self.db.sql(query)
else:
raise e
return user_data return user_data
@ -64,7 +69,7 @@ class OwnershipDB(DDB):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
if storage_engine == StorageEngines.gcs: if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gcs://{bucket}/owners" self.path_prefix = f"gs://{bucket}/owners"
elif storage_engine == StorageEngines.local: elif storage_engine == StorageEngines.local:
self.path_prefix = path / "owners" self.path_prefix = path / "owners"

View file

@ -1,4 +1,3 @@
import json
from pathlib import Path from pathlib import Path
from fastapi import FastAPI from fastapi import FastAPI
@ -9,6 +8,7 @@ from starlette.requests import Request
import hellocomputer import hellocomputer
from .auth import get_user
from .config import settings from .config import settings
from .routers import analysis, auth, files, health, sessions from .routers import analysis, auth, files, health, sessions
@ -20,7 +20,7 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key)
@app.get("/") @app.get("/")
async def homepage(request: Request): async def homepage(request: Request):
user = request.session.get("user") user = get_user(request)
if user: if user:
return RedirectResponse("/app") return RedirectResponse("/app")

View file

@ -1,7 +1,7 @@
from enum import StrEnum from enum import StrEnum
from langchain_fireworks import Fireworks
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_fireworks import Fireworks
class AvailableModels(StrEnum): class AvailableModels(StrEnum):

View file

@ -9,6 +9,7 @@ from hellocomputer.db import StorageEngines
from hellocomputer.db.users import OwnershipDB from hellocomputer.db.users import OwnershipDB
from ..config import settings from ..config import settings
from ..auth import get_user_email
# Scheme for the Authorization header # Scheme for the Authorization header
@ -30,7 +31,7 @@ async def get_greeting() -> str:
@router.get("/sessions") @router.get("/sessions")
async def get_sessions(request: Request) -> List[str]: async def get_sessions(request: Request) -> List[str]:
user_email = request.session.get("user").get("email") user_email = get_user_email(request)
ownership = OwnershipDB( ownership = OwnershipDB(
StorageEngines.gcs, StorageEngines.gcs,
gcs_access=settings.gcs_access, gcs_access=settings.gcs_access,

View file

@ -1,13 +1,13 @@
from pathlib import Path from pathlib import Path
import hellocomputer import hellocomputer
import pytest
import polars as pl import polars as pl
import pytest
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"