Pluggable authentication. Still need to fix gcs
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed
This commit is contained in:
parent
7743d93f1d
commit
e8755e627c
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -158,8 +158,12 @@ cython_debug/
|
|||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
.idea/
|
||||
|
||||
*DS_Store
|
||||
|
||||
test/data/output/*
|
||||
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
*.ipynb
|
1
notebooks/README.md
Normal file
1
notebooks/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Placeholder to store some notebooks. Gitignored
|
16
src/hellocomputer/auth.py
Normal file
16
src/hellocomputer/auth.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from starlette.requests import Request
|
||||
from .config import settings
|
||||
|
||||
|
||||
def get_user(request: Request) -> dict:
|
||||
if settings.auth:
|
||||
return request.session.get("user")
|
||||
else:
|
||||
return {"email": "test@test.com"}
|
||||
|
||||
|
||||
def get_user_email(request: Request) -> str:
|
||||
if settings.auth:
|
||||
return request.session.get("user").get("email")
|
||||
else:
|
||||
return "test@test.com"
|
|
@ -48,8 +48,9 @@ class DDB:
|
|||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(text("LOAD httpfs"))
|
||||
|
||||
self.path_prefix = f"gcs://{bucket}"
|
||||
self.path_prefix = f"gs://{bucket}"
|
||||
else:
|
||||
raise ValueError(
|
||||
"With GCS storage engine you need to provide "
|
||||
|
|
|
@ -24,7 +24,7 @@ class SessionDB(DDB):
|
|||
self.sid = sid
|
||||
# Override storage engine for sessions
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
||||
self.path_prefix = f"gs://{bucket}/sessions/{sid}"
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "sessions" / sid
|
||||
|
||||
|
|
|
@ -24,11 +24,13 @@ class UserDB(DDB):
|
|||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gcs://{bucket}/users"
|
||||
self.path_prefix = f"gs://{bucket}/users"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "users"
|
||||
|
||||
self.storage_engine = storage_engine
|
||||
|
||||
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
||||
df = pl.from_dict(user_data) # noqa
|
||||
record_id = uuid4() if record_id is None else record_id
|
||||
|
@ -36,9 +38,12 @@ class UserDB(DDB):
|
|||
|
||||
try:
|
||||
self.db.sql(query)
|
||||
except duckdb.duckdb.IOException:
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.sql(query)
|
||||
except duckdb.duckdb.IOException as e:
|
||||
if self.storage_engine == StorageEngines.local:
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.sql(query)
|
||||
else:
|
||||
raise e
|
||||
|
||||
return user_data
|
||||
|
||||
|
@ -64,7 +69,7 @@ class OwnershipDB(DDB):
|
|||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gcs://{bucket}/owners"
|
||||
self.path_prefix = f"gs://{bucket}/owners"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "owners"
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
@ -9,6 +8,7 @@ from starlette.requests import Request
|
|||
|
||||
import hellocomputer
|
||||
|
||||
from .auth import get_user
|
||||
from .config import settings
|
||||
from .routers import analysis, auth, files, health, sessions
|
||||
|
||||
|
@ -20,7 +20,7 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key)
|
|||
|
||||
@app.get("/")
|
||||
async def homepage(request: Request):
|
||||
user = request.session.get("user")
|
||||
user = get_user(request)
|
||||
if user:
|
||||
return RedirectResponse("/app")
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from enum import StrEnum
|
||||
|
||||
from langchain_fireworks import Fireworks
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
|
||||
class AvailableModels(StrEnum):
|
||||
|
|
|
@ -9,6 +9,7 @@ from hellocomputer.db import StorageEngines
|
|||
from hellocomputer.db.users import OwnershipDB
|
||||
|
||||
from ..config import settings
|
||||
from ..auth import get_user_email
|
||||
|
||||
# Scheme for the Authorization header
|
||||
|
||||
|
@ -30,7 +31,7 @@ async def get_greeting() -> str:
|
|||
|
||||
@router.get("/sessions")
|
||||
async def get_sessions(request: Request) -> List[str]:
|
||||
user_email = request.session.get("user").get("email")
|
||||
user_email = get_user_email(request)
|
||||
ownership = OwnershipDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
import pytest
|
||||
import polars as pl
|
||||
import pytest
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.models import Chat
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
|
|
Loading…
Reference in a new issue