Persist users now
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

Guillem Borrell 2024-06-11 17:20:32 +02:00
parent 56dc012e23
commit 6d6ec72336
13 changed files with 310 additions and 75 deletions

View file

@@ -6,6 +6,7 @@ pydantic-settings
s3fs
aiofiles
duckdb
polars
pyjwt[crypto]
python-multipart
authlib
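
The only change to requirements.in is the new polars dependency; the UserDB added later in this commit uses it to turn the OAuth userinfo dict into a one-row frame that DuckDB can copy out as ndjson. A minimal sketch of that conversion (values are illustrative, and the list-wrapped form is used here for clarity; dump_user_record passes the raw dict):

import polars as pl

# One-row frame holding a user record; dump_user_record builds the same kind of
# frame from the OAuth userinfo payload. Name and email are illustrative.
df = pl.from_dict({"name": ["John Doe"], "email": ["john@doe.com"]})
print(df.shape)  # (1, 2)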

View file

@@ -1,108 +1,225 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in
aiobotocore==2.13.0
# via s3fs
aiofiles==23.2.1
aiohttp==3.9.5
# via
# aiobotocore
# langchain
# langchain-community
# s3fs
aioitertools==0.11.0
# via aiobotocore
aiosignal==1.3.1
annotated-types==0.6.0
anyio==4.3.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
anyio==4.4.0
# via
# httpx
# openai
# starlette
# watchfiles
attrs==23.2.0
# via aiohttp
authlib==1.3.1
babel==2.15.0
bootstrap4==0.1.0
botocore==1.34.106
certifi==2024.2.2
# via aiobotocore
certifi==2024.6.2
# via
# httpcore
# httpx
# requests
cffi==1.16.0
# via cryptography
charset-normalizer==3.3.2
# via requests
click==8.1.7
colorama==0.4.6
cryptography==42.0.7
dataclasses-json==0.6.6
# via
# typer
# uvicorn
cryptography==42.0.8
# via
# authlib
# pyjwt
dataclasses-json==0.6.7
# via langchain-community
distro==1.9.0
# via openai
dnspython==2.6.1
docutils==0.21.2
duckdb==0.10.2
# via email-validator
duckdb==1.0.0
email-validator==2.1.1
# via fastapi
fastapi==0.111.0
fastapi-cli==0.0.3
flit==3.9.0
flit-core==3.9.0
fastapi-cli==0.0.4
# via fastapi
frozenlist==1.4.1
fsspec==2024.5.0
ghp-import==2.1.0
# via
# aiohttp
# aiosignal
fsspec==2024.6.0
# via s3fs
h11==0.14.0
# via
# httpcore
# uvicorn
httpcore==1.0.5
# via httpx
httptools==0.6.1
# via uvicorn
httpx==0.27.0
# via
# fastapi
# openai
idna==3.7
iniconfig==2.0.0
# via
# anyio
# email-validator
# httpx
# requests
# yarl
itsdangerous==2.2.0
jinja2==3.1.4
# via fastapi
jmespath==1.0.1
# via botocore
jsonpatch==1.33
jsonpointer==2.4
langchain==0.2.0
langchain-community==0.2.0
langchain-core==0.2.0
langchain-text-splitters==0.2.0
langsmith==0.1.59
markdown==3.6
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.3
# via langchain-community
langchain-community==0.2.4
langchain-core==0.2.5
# via
# langchain
# langchain-community
# langchain-text-splitters
langchain-text-splitters==0.2.1
# via langchain
langsmith==0.1.76
# via
# langchain
# langchain-community
# langchain-core
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
marshmallow==3.21.2
# via jinja2
marshmallow==3.21.3
# via dataclasses-json
mdurl==0.1.2
mergedeep==1.3.4
mkdocs==1.6.0
mkdocs-get-deps==0.2.0
mkdocs-material==9.5.23
mkdocs-material-extensions==1.3.1
# via markdown-it-py
multidict==6.0.5
# via
# aiohttp
# yarl
mypy-extensions==1.0.0
# via typing-inspect
numpy==1.26.4
openai==1.30.1
orjson==3.10.3
# via
# langchain
# langchain-community
openai==1.33.0
orjson==3.10.4
# via
# fastapi
# langsmith
packaging==23.2
paginate==0.5.6
pathspec==0.12.1
platformdirs==4.2.2
pluggy==1.5.0
polars==0.20.26
pyarrow==16.1.0
# via
# langchain-core
# marshmallow
polars==0.20.31
pycparser==2.22
pydantic==2.7.1
pydantic-core==2.18.2
pydantic-settings==2.2.1
# via cffi
pydantic==2.7.3
# via
# fastapi
# langchain
# langchain-core
# langsmith
# openai
# pydantic-settings
pydantic-core==2.18.4
# via pydantic
pydantic-settings==2.3.2
pygments==2.18.0
# via rich
pyjwt==2.8.0
pymdown-extensions==10.8.1
pytest==8.2.1
pytest-asyncio==0.23.7
python-dateutil==2.9.0.post0
# via botocore
python-dotenv==1.0.1
# via
# pydantic-settings
# uvicorn
python-multipart==0.0.9
# via fastapi
pyyaml==6.0.1
pyyaml-env-tag==0.1
regex==2024.5.15
requests==2.31.0
# via
# langchain
# langchain-community
# langchain-core
# uvicorn
requests==2.32.3
# via
# langchain
# langchain-community
# langsmith
rich==13.7.1
ruff==0.4.4
s3fs==2024.5.0
# via typer
s3fs==2024.6.0
shellingham==1.5.4
# via typer
six==1.16.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# httpx
# openai
sqlalchemy==2.0.30
# via
# langchain
# langchain-community
starlette==0.37.2
# via fastapi
tenacity==8.3.0
tomli-w==1.0.0
# via
# langchain
# langchain-community
# langchain-core
tqdm==4.66.4
# via openai
typer==0.12.3
typing-extensions==4.11.0
# via fastapi-cli
typing-extensions==4.12.2
# via
# fastapi
# openai
# pydantic
# pydantic-core
# sqlalchemy
# typer
# typing-inspect
typing-inspect==0.9.0
# via dataclasses-json
ujson==5.10.0
# via fastapi
urllib3==2.2.1
uvicorn==0.29.0
# via
# botocore
# requests
uvicorn==0.30.1
# via fastapi
uvloop==0.19.0
watchdog==4.0.0
watchfiles==0.21.0
# via uvicorn
watchfiles==0.22.0
# via uvicorn
websockets==12.0
# via uvicorn
wrapt==1.16.0
# via aiobotocore
yarl==1.9.4
# via aiohttp

View file

@@ -1,11 +1,33 @@
import os
from pathlib import Path
import duckdb
from typing_extensions import Self
from hellocomputer.db import StorageEngines
from .db import DDB
class AnalyticsDB(DDB):
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
self.sid = sid
# Override the path prefix so each session gets its own folder
if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
elif storage_engine == StorageEngines.local:
self.path_prefix = path / "sessions" / sid
def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded"""
self.db.sql(f"""
@@ -47,7 +69,12 @@ class AnalyticsDB(DDB):
if not self.loaded:
raise ValueError("Data should be loaded first")
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
try:
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
except duckdb.duckdb.IOException:
# Destination folder does not exist yet: create it and retry the copy
os.makedirs(self.path_prefix)
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
for sheet in self.sheets:
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")

View file

@@ -1,7 +1,8 @@
from enum import StrEnum
import duckdb
from pathlib import Path
import duckdb
class StorageEngines(StrEnum):
local = "Local"
@@ -12,7 +13,6 @@ class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
@@ -33,7 +33,6 @@
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
)
):
self.db.sql(f"""
@@ -42,11 +41,11 @@
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
self.path_prefix = f"gcs://{bucket}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
"the gcs_access, gcs_secret, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local:
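
The base class now only registers the GCS credentials and leaves path_prefix at the bucket root; session- and user-specific prefixes are set by the subclasses. A rough sketch of the secret registration it performs (the hunk above only shows the tail of the statement; credentials are placeholders):

import duckdb

con = duckdb.connect()
# DuckDB reads and writes gcs:// paths through httpfs once a GCS secret exists.
# KEY_ID and SECRET are placeholders for the HMAC credentials from settings.
con.sql("""
    CREATE SECRET (
        TYPE GCS,
        KEY_ID 'hmac-access-key',
        SECRET 'hmac-secret'
    )
""")
# From here on any gcs://<bucket>/... path works, so subclasses only pick their
# own folder under the bucket (sessions/<sid> for analytics, users/ for UserDB).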

View file

@@ -10,7 +10,7 @@ from starlette.requests import Request
import hellocomputer
from .config import settings
from .routers import analysis, auth, files, sessions, health
from .routers import analysis, auth, files, health, sessions
static_path = Path(hellocomputer.__file__).parent / "static"

View file

@@ -1,8 +1,8 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block
from ..config import settings

View file

@@ -2,8 +2,9 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
from fastapi import APIRouter
from fastapi.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request
from ..config import settings
from hellocomputer.config import settings
from hellocomputer.users import UserDB
from hellocomputer.db import StorageEngines
router = APIRouter()
@@ -33,7 +34,15 @@ async def callback(request: Request):
return HTMLResponse(f"<h1>{error.error}</h1>")
user = token.get("userinfo")
if user:
request.session["user"] = dict(user)
user_info = dict(user)
request.session["user"] = user_info
user_db = UserDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
)
user_db.dump_user_record(user_info)
return RedirectResponse(url="/app")

View file

@@ -4,9 +4,9 @@ import aiofiles
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..db import StorageEngines
from ..analytics import AnalyticsDB
from ..config import settings
from ..db import StorageEngines
router = APIRouter()

View file

@@ -1,2 +1,49 @@
class UserManagement:
pass
import json
import os
from pathlib import Path
from uuid import UUID, uuid4
import duckdb
import polars as pl
from .db import DDB, StorageEngines
class UserDB(DDB):
def __init__(
self,
storage_engine: StorageEngines,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
if storage_engine == StorageEngines.gcs:
self.path_prefix = "gcs://{bucket}/users"
elif storage_engine == StorageEngines.local:
self.path_prefix = path / "users"
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
df = pl.from_dict(user_data) # noqa
record_id = uuid4() if record_id is None else record_id
query = f"COPY df TO '{self.path_prefix}/{record_id}.ndjson' (FORMAT JSON)"
try:
self.db.sql(query)
except duckdb.duckdb.IOException:
os.makedirs(self.path_prefix)
self.db.sql(query)
return user_data
def user_exists(self, email: str) -> bool:
query = f"SELECT * FROM '{self.path_prefix}/*.ndjson' WHERE email = '{email}'"
return self.db.sql(query).pl().shape[0] > 0
@staticmethod
def email(record: str) -> str:
return json.loads(record)["email"]
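
Each record is written as a single-row ndjson file under the users/ prefix, and user_exists simply glob-queries those files through DuckDB. A standalone sketch of that lookup (folder and address are illustrative):

import duckdb

# DuckDB queries newline-delimited JSON files directly via a glob pattern,
# which is what user_exists relies on. Folder and email are illustrative.
matches = duckdb.sql(
    "SELECT * FROM 'test/output/users/*.ndjson' WHERE email = 'john@doe.com'"
)
print(matches.pl().shape[0] > 0)  # True once a matching record has been dumped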

View file

@@ -1 +1,3 @@
*.csv
*.json
*.ndjson

View file

@@ -1,8 +1,8 @@
from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.db import StorageEngines
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
@@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",)
@@ -23,19 +23,25 @@ def test_0_dump():
def test_load():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
results = db.query("select * from answers").fetchall()
assert len(results) == 6
def test_load_description():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
file_description = db.load_description()
assert file_description.startswith("answers")
def test_schema():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
schema = []
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
@@ -44,7 +50,9 @@ def test_schema():
def test_query_prompt():
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = AnalyticsDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"

View file

@@ -2,9 +2,9 @@ from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.db import StorageEngines
from hellocomputer.analytics import AnalyticsDB
from hellocomputer.config import settings
from hellocomputer.db import StorageEngines
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat

test/test_user.py (new file, +25 lines)
View file

@@ -0,0 +1,25 @@
from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.users import UserDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_create_user():
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
user_data = {"name": "John Doe", "email": "[email protected]"}
user_data = user.dump_user_record(user_data, record_id="test")
assert user_data["name"] == "John Doe"
def test_user_exists():
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
user_data = {"name": "John Doe", "email": "[email protected]"}
user.dump_user_record(user_data, record_id="test")
assert user.user_exists("[email protected]")
assert not user.user_exists("notpresent")