diff --git a/.gitignore b/.gitignore index 1a3c8a3..bbffe28 100644 --- a/.gitignore +++ b/.gitignore @@ -158,8 +158,12 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ *DS_Store -test/data/output/* \ No newline at end of file +test/data/output/* + +.pytest_cache +.ruff_cache +*.ipynb \ No newline at end of file diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 0000000..a73e4e4 --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1 @@ +Placeholder to store some notebooks. Gitignored \ No newline at end of file diff --git a/src/hellocomputer/auth.py b/src/hellocomputer/auth.py new file mode 100644 index 0000000..c97564b --- /dev/null +++ b/src/hellocomputer/auth.py @@ -0,0 +1,16 @@ +from starlette.requests import Request +from .config import settings + + +def get_user(request: Request) -> dict: + if settings.auth: + return request.session.get("user") + else: + return {"email": "test@test.com"} + + +def get_user_email(request: Request) -> str: + if settings.auth: + return request.session.get("user").get("email") + else: + return "test@test.com" diff --git a/src/hellocomputer/db/__init__.py b/src/hellocomputer/db/__init__.py index be93fa7..113e4e3 100644 --- a/src/hellocomputer/db/__init__.py +++ b/src/hellocomputer/db/__init__.py @@ -48,8 +48,9 @@ class DDB: """ ) ) + conn.execute(text("LOAD httpfs")) - self.path_prefix = f"gcs://{bucket}" + self.path_prefix = f"gs://{bucket}" else: raise ValueError( "With GCS storage engine you need to provide " diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index 95de577..a728334 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -24,7 +24,7 @@ class SessionDB(DDB): self.sid = sid # Override storage engine for sessions if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gcs://{bucket}/sessions/{sid}" + self.path_prefix = f"gs://{bucket}/sessions/{sid}" elif storage_engine == StorageEngines.local: self.path_prefix = path / "sessions" / sid diff --git a/src/hellocomputer/db/users.py b/src/hellocomputer/db/users.py index 5a166ad..088e6ed 100644 --- a/src/hellocomputer/db/users.py +++ b/src/hellocomputer/db/users.py @@ -24,11 +24,13 @@ class UserDB(DDB): super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gcs://{bucket}/users" + self.path_prefix = f"gs://{bucket}/users" elif storage_engine == StorageEngines.local: self.path_prefix = path / "users" + self.storage_engine = storage_engine + def dump_user_record(self, user_data: dict, record_id: UUID | None = None): df = pl.from_dict(user_data) # noqa record_id = uuid4() if record_id is None else record_id @@ -36,9 +38,12 @@ class UserDB(DDB): try: self.db.sql(query) - except duckdb.duckdb.IOException: - os.makedirs(self.path_prefix) - self.db.sql(query) + except duckdb.duckdb.IOException as e: + if self.storage_engine == StorageEngines.local: + os.makedirs(self.path_prefix) + self.db.sql(query) + else: + raise e return user_data @@ -64,7 +69,7 @@ class OwnershipDB(DDB): super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gcs://{bucket}/owners" + self.path_prefix = f"gs://{bucket}/owners" elif storage_engine == StorageEngines.local: self.path_prefix = path / "owners" diff --git a/src/hellocomputer/main.py b/src/hellocomputer/main.py index 0fb807a..f4ab99a 100644 --- a/src/hellocomputer/main.py +++ b/src/hellocomputer/main.py @@ -1,4 +1,3 @@ -import json from pathlib import Path from fastapi import FastAPI @@ -9,6 +8,7 @@ from starlette.requests import Request import hellocomputer +from .auth import get_user from .config import settings from .routers import analysis, auth, files, health, sessions @@ -20,7 +20,7 @@ app.add_middleware(SessionMiddleware, secret_key=settings.app_secret_key) @app.get("/") async def homepage(request: Request): - user = request.session.get("user") + user = get_user(request) if user: return RedirectResponse("/app") diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index 09f6c9a..511b3b8 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -1,7 +1,7 @@ from enum import StrEnum -from langchain_fireworks import Fireworks from langchain_core.prompts import PromptTemplate +from langchain_fireworks import Fireworks class AvailableModels(StrEnum): diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index b67af48..c1fcfd9 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -9,6 +9,7 @@ from hellocomputer.db import StorageEngines from hellocomputer.db.users import OwnershipDB from ..config import settings +from ..auth import get_user_email # Scheme for the Authorization header @@ -30,7 +31,7 @@ async def get_greeting() -> str: @router.get("/sessions") async def get_sessions(request: Request) -> List[str]: - user_email = request.session.get("user").get("email") + user_email = get_user_email(request) ownership = OwnershipDB( StorageEngines.gcs, gcs_access=settings.gcs_access, diff --git a/test/test_query.py b/test/test_query.py index 07d42b3..759ae69 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,13 +1,13 @@ from pathlib import Path import hellocomputer -import pytest import polars as pl +import pytest from hellocomputer.config import settings from hellocomputer.db import StorageEngines +from hellocomputer.db.sessions import SessionDB from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat -from hellocomputer.db.sessions import SessionDB TEST_STORAGE = StorageEngines.local TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"