From 56ec151b70e7e062700c09bbf212fa60e511043d Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Mon, 10 Jun 2024 08:32:51 +0200 Subject: [PATCH] Now with proper authentication --- requirements.in | 4 +- src/hellocomputer/config.py | 5 +++ src/hellocomputer/main.py | 62 ++++++++++++++++++++++++++- src/hellocomputer/routers/sessions.py | 9 ++-- src/hellocomputer/security.py | 3 -- src/hellocomputer/static/index.html | 17 ++++++-- src/hellocomputer/static/login.html | 43 +++++++++++++++++++ src/hellocomputer/static/script.js | 8 ++-- 8 files changed, 134 insertions(+), 17 deletions(-) create mode 100644 src/hellocomputer/static/login.html diff --git a/requirements.in b/requirements.in index 0100067..6795e1d 100644 --- a/requirements.in +++ b/requirements.in @@ -7,4 +7,6 @@ s3fs aiofiles duckdb pyjwt[crypto] -python-multipart \ No newline at end of file +python-multipart +authlib +itsdangerous \ No newline at end of file diff --git a/src/hellocomputer/config.py b/src/hellocomputer/config.py index 6e86baa..aabaf44 100644 --- a/src/hellocomputer/config.py +++ b/src/hellocomputer/config.py @@ -2,11 +2,16 @@ from pydantic_settings import BaseSettings, SettingsConfigDict class Settings(BaseSettings): + base_url: str = "http://localhost:8000" anyscale_api_key: str = "Awesome API" gcs_access: str = "access" gcs_secret: str = "secret" gcs_bucketname: str = "bucket" auth: bool = True + auth0_client_id: str = "" + auth0_client_secret: str = "" + auth0_domain: str = "" + app_secret_key: str = "" model_config = SettingsConfigDict(env_file=".env") diff --git a/src/hellocomputer/main.py b/src/hellocomputer/main.py index 127dfc7..ec17aab 100644 --- a/src/hellocomputer/main.py +++ b/src/hellocomputer/main.py @@ -1,16 +1,76 @@ from pathlib import Path from fastapi import FastAPI, status +from fastapi.responses import RedirectResponse, HTMLResponse from fastapi.staticfiles import StaticFiles +from starlette.middleware.sessions import SessionMiddleware +from starlette.requests import Request +from authlib.integrations.starlette_client import OAuth, OAuthError from pydantic import BaseModel +import json + import hellocomputer from .routers import analysis, files, sessions +from .config import settings static_path = Path(hellocomputer.__file__).parent / "static" +oauth = OAuth() +oauth.register( + "auth0", + client_id=settings.auth0_client_id, + client_secret=settings.auth0_client_secret, + client_kwargs={"scope": "openid profile email", "verify": False}, + server_metadata_url=f"https://{settings.auth0_domain}/.well-known/openid-configuration", +) app = FastAPI() +app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key) + + +@app.get("/") +async def homepage(request: Request): + user = request.session.get("user") + if user: + print(json.dumps(user)) + return RedirectResponse("/app") + + with open(static_path / "login.html") as f: + return HTMLResponse(f.read()) + + +@app.route("/login") +async def login(request: Request): + return await oauth.auth0.authorize_redirect( + request, + redirect_uri="http://localhost:8000/callback", + ) + + +@app.route("/callback", methods=["GET", "POST"]) +async def callback(request: Request): + try: + token = await oauth.auth0.authorize_access_token(request) + except OAuthError as error: + return HTMLResponse(f"

{error.error}

") + user = token.get("userinfo") + if user: + request.session["user"] = dict(user) + + return RedirectResponse(url="/app") + + +@app.route("/logout") +async def logout(request: Request): + request.session.pop("user", None) + return RedirectResponse(url="/") + + +@app.route("/user") +async def user(request: Request): + user = request.session.get("user") + return user class HealthCheck(BaseModel): @@ -44,7 +104,7 @@ app.include_router(sessions.router) app.include_router(files.router) app.include_router(analysis.router) app.mount( - "/", + "/app", StaticFiles(directory=static_path, html=True), name="static", ) diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index 9d68dd1..3fc99ae 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -1,10 +1,9 @@ -from typing import Annotated from uuid import uuid4 -from fastapi import APIRouter, Depends +from fastapi import APIRouter +from starlette.requests import Request from fastapi.responses import PlainTextResponse -from ..security import oauth2_scheme # Scheme for the Authorization header @@ -12,7 +11,9 @@ router = APIRouter() @router.get("/new_session") -async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str: +async def get_new_session(request: Request) -> str: + user = request.session.get("user") + print(user) return str(uuid4()) diff --git a/src/hellocomputer/security.py b/src/hellocomputer/security.py index 5fa38a6..e69de29 100644 --- a/src/hellocomputer/security.py +++ b/src/hellocomputer/security.py @@ -1,3 +0,0 @@ -from fastapi.security import OAuth2PasswordBearer - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") diff --git a/src/hellocomputer/static/index.html b/src/hellocomputer/static/index.html index 7ddff25..68c33d8 100644 --- a/src/hellocomputer/static/index.html +++ b/src/hellocomputer/static/index.html @@ -8,6 +8,7 @@ + @@ -21,9 +22,17 @@ Hello, computer!

- How to - File templates - About + How to + + File templates + + About + + Config + + Logout @@ -105,7 +114,7 @@
- +
diff --git a/src/hellocomputer/static/login.html b/src/hellocomputer/static/login.html new file mode 100644 index 0000000..31a54a6 --- /dev/null +++ b/src/hellocomputer/static/login.html @@ -0,0 +1,43 @@ + + + + + + + Login Page + + + + + + +
+ +
+ + + + + \ No newline at end of file diff --git a/src/hellocomputer/static/script.js b/src/hellocomputer/static/script.js index 01a941a..66fbb3a 100644 --- a/src/hellocomputer/static/script.js +++ b/src/hellocomputer/static/script.js @@ -37,16 +37,16 @@ async function fetchResponse(message, newMessage) { const data = await response.text(); // Hide spinner and display result - newMessage.innerHTML = '
' + data + '
'; + newMessage.innerHTML = '
' + data + '
'; } catch (error) { - newMessage.innerHTML = '' + 'Error: ' + error.message; + newMessage.innerHTML = '' + 'Error: ' + error.message; } } function addAIMessage(messageContent) { const newMessage = document.createElement('div'); newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded'); - newMessage.innerHTML = '
'; + newMessage.innerHTML = '
'; chatMessages.prepend(newMessage); // Add new message at the top fetchResponse(messageContent, newMessage); } @@ -54,7 +54,7 @@ function addAIMessage(messageContent) { function addAIManualMessage(m) { const newMessage = document.createElement('div'); newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded'); - newMessage.innerHTML = '
' + m + '
'; + newMessage.innerHTML = '
' + m + '
'; chatMessages.prepend(newMessage); // Add new message at the top }