diff --git a/requirements.in b/requirements.in index 6795e1d..031ffd6 100644 --- a/requirements.in +++ b/requirements.in @@ -6,6 +6,7 @@ pydantic-settings s3fs aiofiles duckdb +polars pyjwt[crypto] python-multipart authlib diff --git a/requirements.txt b/requirements.txt index 005ded6..f6c087e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,108 +1,225 @@ +# This file was autogenerated by uv via the following command: +# uv pip compile requirements.in aiobotocore==2.13.0 + # via s3fs aiofiles==23.2.1 aiohttp==3.9.5 + # via + # aiobotocore + # langchain + # langchain-community + # s3fs aioitertools==0.11.0 + # via aiobotocore aiosignal==1.3.1 -annotated-types==0.6.0 -anyio==4.3.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +anyio==4.4.0 + # via + # httpx + # openai + # starlette + # watchfiles attrs==23.2.0 + # via aiohttp authlib==1.3.1 -babel==2.15.0 -bootstrap4==0.1.0 botocore==1.34.106 -certifi==2024.2.2 + # via aiobotocore +certifi==2024.6.2 + # via + # httpcore + # httpx + # requests cffi==1.16.0 + # via cryptography charset-normalizer==3.3.2 + # via requests click==8.1.7 -colorama==0.4.6 -cryptography==42.0.7 -dataclasses-json==0.6.6 + # via + # typer + # uvicorn +cryptography==42.0.8 + # via + # authlib + # pyjwt +dataclasses-json==0.6.7 + # via langchain-community distro==1.9.0 + # via openai dnspython==2.6.1 -docutils==0.21.2 -duckdb==0.10.2 + # via email-validator +duckdb==1.0.0 email-validator==2.1.1 + # via fastapi fastapi==0.111.0 -fastapi-cli==0.0.3 -flit==3.9.0 -flit-core==3.9.0 +fastapi-cli==0.0.4 + # via fastapi frozenlist==1.4.1 -fsspec==2024.5.0 -ghp-import==2.1.0 + # via + # aiohttp + # aiosignal +fsspec==2024.6.0 + # via s3fs h11==0.14.0 + # via + # httpcore + # uvicorn httpcore==1.0.5 + # via httpx httptools==0.6.1 + # via uvicorn httpx==0.27.0 + # via + # fastapi + # openai idna==3.7 -iniconfig==2.0.0 + # via + # anyio + # email-validator + # httpx + # requests + # yarl itsdangerous==2.2.0 jinja2==3.1.4 + # via fastapi jmespath==1.0.1 + # via botocore jsonpatch==1.33 -jsonpointer==2.4 -langchain==0.2.0 -langchain-community==0.2.0 -langchain-core==0.2.0 -langchain-text-splitters==0.2.0 -langsmith==0.1.59 -markdown==3.6 + # via langchain-core +jsonpointer==3.0.0 + # via jsonpatch +langchain==0.2.3 + # via langchain-community +langchain-community==0.2.4 +langchain-core==0.2.5 + # via + # langchain + # langchain-community + # langchain-text-splitters +langchain-text-splitters==0.2.1 + # via langchain +langsmith==0.1.76 + # via + # langchain + # langchain-community + # langchain-core markdown-it-py==3.0.0 + # via rich markupsafe==2.1.5 -marshmallow==3.21.2 + # via jinja2 +marshmallow==3.21.3 + # via dataclasses-json mdurl==0.1.2 -mergedeep==1.3.4 -mkdocs==1.6.0 -mkdocs-get-deps==0.2.0 -mkdocs-material==9.5.23 -mkdocs-material-extensions==1.3.1 + # via markdown-it-py multidict==6.0.5 + # via + # aiohttp + # yarl mypy-extensions==1.0.0 + # via typing-inspect numpy==1.26.4 -openai==1.30.1 -orjson==3.10.3 + # via + # langchain + # langchain-community +openai==1.33.0 +orjson==3.10.4 + # via + # fastapi + # langsmith packaging==23.2 -paginate==0.5.6 -pathspec==0.12.1 -platformdirs==4.2.2 -pluggy==1.5.0 -polars==0.20.26 -pyarrow==16.1.0 + # via + # langchain-core + # marshmallow +polars==0.20.31 pycparser==2.22 -pydantic==2.7.1 -pydantic-core==2.18.2 -pydantic-settings==2.2.1 + # via cffi +pydantic==2.7.3 + # via + # fastapi + # langchain + # langchain-core + # langsmith + # openai + # pydantic-settings +pydantic-core==2.18.4 + # via pydantic +pydantic-settings==2.3.2 pygments==2.18.0 + # via rich pyjwt==2.8.0 -pymdown-extensions==10.8.1 -pytest==8.2.1 -pytest-asyncio==0.23.7 python-dateutil==2.9.0.post0 + # via botocore python-dotenv==1.0.1 + # via + # pydantic-settings + # uvicorn python-multipart==0.0.9 + # via fastapi pyyaml==6.0.1 -pyyaml-env-tag==0.1 -regex==2024.5.15 -requests==2.31.0 + # via + # langchain + # langchain-community + # langchain-core + # uvicorn +requests==2.32.3 + # via + # langchain + # langchain-community + # langsmith rich==13.7.1 -ruff==0.4.4 -s3fs==2024.5.0 + # via typer +s3fs==2024.6.0 shellingham==1.5.4 + # via typer six==1.16.0 + # via python-dateutil sniffio==1.3.1 + # via + # anyio + # httpx + # openai sqlalchemy==2.0.30 + # via + # langchain + # langchain-community starlette==0.37.2 + # via fastapi tenacity==8.3.0 -tomli-w==1.0.0 + # via + # langchain + # langchain-community + # langchain-core tqdm==4.66.4 + # via openai typer==0.12.3 -typing-extensions==4.11.0 + # via fastapi-cli +typing-extensions==4.12.2 + # via + # fastapi + # openai + # pydantic + # pydantic-core + # sqlalchemy + # typer + # typing-inspect typing-inspect==0.9.0 + # via dataclasses-json ujson==5.10.0 + # via fastapi urllib3==2.2.1 -uvicorn==0.29.0 + # via + # botocore + # requests +uvicorn==0.30.1 + # via fastapi uvloop==0.19.0 -watchdog==4.0.0 -watchfiles==0.21.0 + # via uvicorn +watchfiles==0.22.0 + # via uvicorn websockets==12.0 + # via uvicorn wrapt==1.16.0 + # via aiobotocore yarl==1.9.4 + # via aiohttp diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/analytics.py index 0062b19..d09bf20 100644 --- a/src/hellocomputer/analytics.py +++ b/src/hellocomputer/analytics.py @@ -1,11 +1,33 @@ import os from pathlib import Path +import duckdb from typing_extensions import Self + +from hellocomputer.db import StorageEngines + from .db import DDB class AnalyticsDB(DDB): + def __init__( + self, + storage_engine: StorageEngines, + sid: str | None = None, + path: Path | None = None, + gcs_access: str | None = None, + gcs_secret: str | None = None, + bucket: str | None = None, + **kwargs, + ): + super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + self.sid = sid + # Override storage engine for sessions + if storage_engine == StorageEngines.gcs: + self.path_prefix = "gcs://{bucket}/sessions/{sid}" + elif storage_engine == StorageEngines.local: + self.path_prefix = path / "sessions" / sid + def load_xls(self, xls_path: Path) -> Self: """For some reason, the header is not loaded""" self.db.sql(f""" @@ -47,7 +69,12 @@ class AnalyticsDB(DDB): if not self.loaded: raise ValueError("Data should be loaded first") - self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") + try: + self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") + except duckdb.duckdb.IOException: + # Create the folder + os.makedirs(self.path_prefix) + self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") for sheet in self.sheets: self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") diff --git a/src/hellocomputer/db.py b/src/hellocomputer/db.py index 49bc9c3..3630616 100644 --- a/src/hellocomputer/db.py +++ b/src/hellocomputer/db.py @@ -1,7 +1,8 @@ from enum import StrEnum -import duckdb from pathlib import Path +import duckdb + class StorageEngines(StrEnum): local = "Local" @@ -12,7 +13,6 @@ class DDB: def __init__( self, storage_engine: StorageEngines, - sid: str | None = None, path: Path | None = None, gcs_access: str | None = None, gcs_secret: str | None = None, @@ -33,7 +33,6 @@ class DDB: gcs_access is not None, gcs_secret is not None, bucket is not None, - sid is not None, ) ): self.db.sql(f""" @@ -42,11 +41,11 @@ class DDB: KEY_ID '{gcs_access}', SECRET '{gcs_secret}') """) - self.path_prefix = f"gcs://{bucket}/sessions/{sid}" + self.path_prefix = f"gcs://{bucket}" else: raise ValueError( "With GCS storage engine you need to provide " - "the gcs_access, gcs_secret, sid, and bucket keyword arguments" + "the gcs_access, gcs_secret, and bucket keyword arguments" ) elif storage_engine == StorageEngines.local: diff --git a/src/hellocomputer/main.py b/src/hellocomputer/main.py index 8c9b436..4d6bf59 100644 --- a/src/hellocomputer/main.py +++ b/src/hellocomputer/main.py @@ -10,7 +10,7 @@ from starlette.requests import Request import hellocomputer from .config import settings -from .routers import analysis, auth, files, sessions, health +from .routers import analysis, auth, files, health, sessions static_path = Path(hellocomputer.__file__).parent / "static" diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 0e4e41b..8d07096 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,8 +1,8 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse -from hellocomputer.db import StorageEngines from hellocomputer.analytics import AnalyticsDB +from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block from ..config import settings diff --git a/src/hellocomputer/routers/auth.py b/src/hellocomputer/routers/auth.py index f9cbc8b..aec0b68 100644 --- a/src/hellocomputer/routers/auth.py +++ b/src/hellocomputer/routers/auth.py @@ -2,8 +2,9 @@ from authlib.integrations.starlette_client import OAuth, OAuthError from fastapi import APIRouter from fastapi.responses import HTMLResponse, RedirectResponse from starlette.requests import Request - -from ..config import settings +from hellocomputer.config import settings +from hellocomputer.users import UserDB +from hellocomputer.db import StorageEngines router = APIRouter() @@ -33,7 +34,15 @@ async def callback(request: Request): return HTMLResponse(f"

{error.error}

") user = token.get("userinfo") if user: - request.session["user"] = dict(user) + user_info = dict(user) + request.session["user"] = user_info + user_db = UserDB( + StorageEngines.gcs, + gcs_access=settings.gcs_access, + gcs_secret=settings.gcs_secret, + bucket=settings.gcs_bucketname, + ) + user_db.dump_user_record(user_info) return RedirectResponse(url="/app") diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index 5766e0d..fcac0b4 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -4,9 +4,9 @@ import aiofiles from fastapi import APIRouter, File, UploadFile from fastapi.responses import JSONResponse -from ..db import StorageEngines from ..analytics import AnalyticsDB from ..config import settings +from ..db import StorageEngines router = APIRouter() diff --git a/src/hellocomputer/users.py b/src/hellocomputer/users.py index 96038f0..555f603 100644 --- a/src/hellocomputer/users.py +++ b/src/hellocomputer/users.py @@ -1,2 +1,49 @@ -class UserManagement: - pass +import json +import os +from pathlib import Path +from uuid import UUID, uuid4 + +import duckdb +import polars as pl + +from .db import DDB, StorageEngines + + +class UserDB(DDB): + def __init__( + self, + storage_engine: StorageEngines, + path: Path | None = None, + gcs_access: str | None = None, + gcs_secret: str | None = None, + bucket: str | None = None, + **kwargs, + ): + super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + + if storage_engine == StorageEngines.gcs: + self.path_prefix = "gcs://{bucket}/users" + + elif storage_engine == StorageEngines.local: + self.path_prefix = path / "users" + + def dump_user_record(self, user_data: dict, record_id: UUID | None = None): + df = pl.from_dict(user_data) # noqa + record_id = uuid4() if record_id is None else record_id + query = f"COPY df TO '{self.path_prefix}/{record_id}.ndjson' (FORMAT JSON)" + + try: + self.db.sql(query) + except duckdb.duckdb.IOException: + os.makedirs(self.path_prefix) + self.db.sql(query) + + return user_data + + def user_exists(self, email: str) -> bool: + query = f"SELECT * FROM '{self.path_prefix}/*.ndjson' WHERE email = '{email}'" + return self.db.sql(query).pl().shape[0] > 0 + + @staticmethod + def email(record: str) -> str: + return json.loads(record)["email"] diff --git a/test/output/.gitignore b/test/output/.gitignore index 16f2dc5..0c1ef24 100644 --- a/test/output/.gitignore +++ b/test/output/.gitignore @@ -1 +1,3 @@ -*.csv \ No newline at end of file +*.csv +*.json +*.ndjson \ No newline at end of file diff --git a/test/test_data.py b/test/test_data.py index 337807d..926b6da 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,8 +1,8 @@ from pathlib import Path import hellocomputer -from hellocomputer.db import StorageEngines from hellocomputer.analytics import AnalyticsDB +from hellocomputer.db import StorageEngines TEST_STORAGE = StorageEngines.local TEST_XLS_PATH = ( @@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" def test_0_dump(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test") db.load_xls(TEST_XLS_PATH).dump() assert db.sheets == ("answers",) @@ -23,19 +23,25 @@ def test_0_dump(): def test_load(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() results = db.query("select * from answers").fetchall() assert len(results) == 6 def test_load_description(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() file_description = db.load_description() assert file_description.startswith("answers") def test_schema(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() schema = [] for sheet in db.sheets: schema.append(db.table_schema(sheet)) @@ -44,7 +50,9 @@ def test_schema(): def test_query_prompt(): - db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + db = AnalyticsDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() assert db.query_prompt("Find the average score of all students").startswith( "The following sentence" diff --git a/test/test_query.py b/test/test_query.py index 35e0d34..7064ebe 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -2,9 +2,9 @@ from pathlib import Path import hellocomputer import pytest -from hellocomputer.db import StorageEngines from hellocomputer.analytics import AnalyticsDB from hellocomputer.config import settings +from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat diff --git a/test/test_user.py b/test/test_user.py new file mode 100644 index 0000000..6e325ae --- /dev/null +++ b/test/test_user.py @@ -0,0 +1,25 @@ +from pathlib import Path + +import hellocomputer +from hellocomputer.db import StorageEngines +from hellocomputer.users import UserDB + +TEST_STORAGE = StorageEngines.local +TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" + + +def test_create_user(): + user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + user_data = {"name": "John Doe", "email": "[email protected]"} + user_data = user.dump_user_record(user_data, record_id="test") + + assert user_data["name"] == "John Doe" + + +def test_user_exists(): + user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + user_data = {"name": "John Doe", "email": "[email protected]"} + user.dump_user_record(user_data, record_id="test") + + assert user.user_exists("[email protected]") + assert not user.user_exists("notpresent")