This commit is contained in:
parent
56dc012e23
commit
6d6ec72336
|
@ -6,6 +6,7 @@ pydantic-settings
|
|||
s3fs
|
||||
aiofiles
|
||||
duckdb
|
||||
polars
|
||||
pyjwt[crypto]
|
||||
python-multipart
|
||||
authlib
|
||||
|
|
223
requirements.txt
223
requirements.txt
|
@ -1,108 +1,225 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.in
|
||||
aiobotocore==2.13.0
|
||||
# via s3fs
|
||||
aiofiles==23.2.1
|
||||
aiohttp==3.9.5
|
||||
# via
|
||||
# aiobotocore
|
||||
# langchain
|
||||
# langchain-community
|
||||
# s3fs
|
||||
aioitertools==0.11.0
|
||||
# via aiobotocore
|
||||
aiosignal==1.3.1
|
||||
annotated-types==0.6.0
|
||||
anyio==4.3.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.4.0
|
||||
# via
|
||||
# httpx
|
||||
# openai
|
||||
# starlette
|
||||
# watchfiles
|
||||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
authlib==1.3.1
|
||||
babel==2.15.0
|
||||
bootstrap4==0.1.0
|
||||
botocore==1.34.106
|
||||
certifi==2024.2.2
|
||||
# via aiobotocore
|
||||
certifi==2024.6.2
|
||||
# via
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==1.16.0
|
||||
# via cryptography
|
||||
charset-normalizer==3.3.2
|
||||
# via requests
|
||||
click==8.1.7
|
||||
colorama==0.4.6
|
||||
cryptography==42.0.7
|
||||
dataclasses-json==0.6.6
|
||||
# via
|
||||
# typer
|
||||
# uvicorn
|
||||
cryptography==42.0.8
|
||||
# via
|
||||
# authlib
|
||||
# pyjwt
|
||||
dataclasses-json==0.6.7
|
||||
# via langchain-community
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
dnspython==2.6.1
|
||||
docutils==0.21.2
|
||||
duckdb==0.10.2
|
||||
# via email-validator
|
||||
duckdb==1.0.0
|
||||
email-validator==2.1.1
|
||||
# via fastapi
|
||||
fastapi==0.111.0
|
||||
fastapi-cli==0.0.3
|
||||
flit==3.9.0
|
||||
flit-core==3.9.0
|
||||
fastapi-cli==0.0.4
|
||||
# via fastapi
|
||||
frozenlist==1.4.1
|
||||
fsspec==2024.5.0
|
||||
ghp-import==2.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.6.0
|
||||
# via s3fs
|
||||
h11==0.14.0
|
||||
# via
|
||||
# httpcore
|
||||
# uvicorn
|
||||
httpcore==1.0.5
|
||||
# via httpx
|
||||
httptools==0.6.1
|
||||
# via uvicorn
|
||||
httpx==0.27.0
|
||||
# via
|
||||
# fastapi
|
||||
# openai
|
||||
idna==3.7
|
||||
iniconfig==2.0.0
|
||||
# via
|
||||
# anyio
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
itsdangerous==2.2.0
|
||||
jinja2==3.1.4
|
||||
# via fastapi
|
||||
jmespath==1.0.1
|
||||
# via botocore
|
||||
jsonpatch==1.33
|
||||
jsonpointer==2.4
|
||||
langchain==0.2.0
|
||||
langchain-community==0.2.0
|
||||
langchain-core==0.2.0
|
||||
langchain-text-splitters==0.2.0
|
||||
langsmith==0.1.59
|
||||
markdown==3.6
|
||||
# via langchain-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
langchain==0.2.3
|
||||
# via langchain-community
|
||||
langchain-community==0.2.4
|
||||
langchain-core==0.2.5
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-text-splitters
|
||||
langchain-text-splitters==0.2.1
|
||||
# via langchain
|
||||
langsmith==0.1.76
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-core
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
markupsafe==2.1.5
|
||||
marshmallow==3.21.2
|
||||
# via jinja2
|
||||
marshmallow==3.21.3
|
||||
# via dataclasses-json
|
||||
mdurl==0.1.2
|
||||
mergedeep==1.3.4
|
||||
mkdocs==1.6.0
|
||||
mkdocs-get-deps==0.2.0
|
||||
mkdocs-material==9.5.23
|
||||
mkdocs-material-extensions==1.3.1
|
||||
# via markdown-it-py
|
||||
multidict==6.0.5
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
mypy-extensions==1.0.0
|
||||
# via typing-inspect
|
||||
numpy==1.26.4
|
||||
openai==1.30.1
|
||||
orjson==3.10.3
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
openai==1.33.0
|
||||
orjson==3.10.4
|
||||
# via
|
||||
# fastapi
|
||||
# langsmith
|
||||
packaging==23.2
|
||||
paginate==0.5.6
|
||||
pathspec==0.12.1
|
||||
platformdirs==4.2.2
|
||||
pluggy==1.5.0
|
||||
polars==0.20.26
|
||||
pyarrow==16.1.0
|
||||
# via
|
||||
# langchain-core
|
||||
# marshmallow
|
||||
polars==0.20.31
|
||||
pycparser==2.22
|
||||
pydantic==2.7.1
|
||||
pydantic-core==2.18.2
|
||||
pydantic-settings==2.2.1
|
||||
# via cffi
|
||||
pydantic==2.7.3
|
||||
# via
|
||||
# fastapi
|
||||
# langchain
|
||||
# langchain-core
|
||||
# langsmith
|
||||
# openai
|
||||
# pydantic-settings
|
||||
pydantic-core==2.18.4
|
||||
# via pydantic
|
||||
pydantic-settings==2.3.2
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyjwt==2.8.0
|
||||
pymdown-extensions==10.8.1
|
||||
pytest==8.2.1
|
||||
pytest-asyncio==0.23.7
|
||||
python-dateutil==2.9.0.post0
|
||||
# via botocore
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# pydantic-settings
|
||||
# uvicorn
|
||||
python-multipart==0.0.9
|
||||
# via fastapi
|
||||
pyyaml==6.0.1
|
||||
pyyaml-env-tag==0.1
|
||||
regex==2024.5.15
|
||||
requests==2.31.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-core
|
||||
# uvicorn
|
||||
requests==2.32.3
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langsmith
|
||||
rich==13.7.1
|
||||
ruff==0.4.4
|
||||
s3fs==2024.5.0
|
||||
# via typer
|
||||
s3fs==2024.6.0
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
six==1.16.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# openai
|
||||
sqlalchemy==2.0.30
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
starlette==0.37.2
|
||||
# via fastapi
|
||||
tenacity==8.3.0
|
||||
tomli-w==1.0.0
|
||||
# via
|
||||
# langchain
|
||||
# langchain-community
|
||||
# langchain-core
|
||||
tqdm==4.66.4
|
||||
# via openai
|
||||
typer==0.12.3
|
||||
typing-extensions==4.11.0
|
||||
# via fastapi-cli
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# fastapi
|
||||
# openai
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# typer
|
||||
# typing-inspect
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
ujson==5.10.0
|
||||
# via fastapi
|
||||
urllib3==2.2.1
|
||||
uvicorn==0.29.0
|
||||
# via
|
||||
# botocore
|
||||
# requests
|
||||
uvicorn==0.30.1
|
||||
# via fastapi
|
||||
uvloop==0.19.0
|
||||
watchdog==4.0.0
|
||||
watchfiles==0.21.0
|
||||
# via uvicorn
|
||||
watchfiles==0.22.0
|
||||
# via uvicorn
|
||||
websockets==12.0
|
||||
# via uvicorn
|
||||
wrapt==1.16.0
|
||||
# via aiobotocore
|
||||
yarl==1.9.4
|
||||
# via aiohttp
|
||||
|
|
|
@ -1,11 +1,33 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
from typing_extensions import Self
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
|
||||
from .db import DDB
|
||||
|
||||
|
||||
class AnalyticsDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
sid: str | None = None,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
self.sid = sid
|
||||
# Override storage engine for sessions
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = "gcs://{bucket}/sessions/{sid}"
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "sessions" / sid
|
||||
|
||||
def load_xls(self, xls_path: Path) -> Self:
|
||||
"""For some reason, the header is not loaded"""
|
||||
self.db.sql(f"""
|
||||
|
@ -47,7 +69,12 @@ class AnalyticsDB(DDB):
|
|||
if not self.loaded:
|
||||
raise ValueError("Data should be loaded first")
|
||||
|
||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||
try:
|
||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||
except duckdb.duckdb.IOException:
|
||||
# Create the folder
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||
|
||||
for sheet in self.sheets:
|
||||
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from enum import StrEnum
|
||||
import duckdb
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
local = "Local"
|
||||
|
@ -12,7 +13,6 @@ class DDB:
|
|||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
sid: str | None = None,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
|
@ -33,7 +33,6 @@ class DDB:
|
|||
gcs_access is not None,
|
||||
gcs_secret is not None,
|
||||
bucket is not None,
|
||||
sid is not None,
|
||||
)
|
||||
):
|
||||
self.db.sql(f"""
|
||||
|
@ -42,11 +41,11 @@ class DDB:
|
|||
KEY_ID '{gcs_access}',
|
||||
SECRET '{gcs_secret}')
|
||||
""")
|
||||
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
||||
self.path_prefix = f"gcs://{bucket}"
|
||||
else:
|
||||
raise ValueError(
|
||||
"With GCS storage engine you need to provide "
|
||||
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
||||
"the gcs_access, gcs_secret, and bucket keyword arguments"
|
||||
)
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
|
|
|
@ -10,7 +10,7 @@ from starlette.requests import Request
|
|||
import hellocomputer
|
||||
|
||||
from .config import settings
|
||||
from .routers import analysis, auth, files, sessions, health
|
||||
from .routers import analysis, auth, files, health, sessions
|
||||
|
||||
static_path = Path(hellocomputer.__file__).parent / "static"
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
||||
from ..config import settings
|
||||
|
|
|
@ -2,8 +2,9 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from ..config import settings
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.users import UserDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
@ -33,7 +34,15 @@ async def callback(request: Request):
|
|||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||
user = token.get("userinfo")
|
||||
if user:
|
||||
request.session["user"] = dict(user)
|
||||
user_info = dict(user)
|
||||
request.session["user"] = user_info
|
||||
user_db = UserDB(
|
||||
StorageEngines.gcs,
|
||||
gcs_access=settings.gcs_access,
|
||||
gcs_secret=settings.gcs_secret,
|
||||
bucket=settings.gcs_bucketname,
|
||||
)
|
||||
user_db.dump_user_record(user_info)
|
||||
|
||||
return RedirectResponse(url="/app")
|
||||
|
||||
|
|
|
@ -4,9 +4,9 @@ import aiofiles
|
|||
from fastapi import APIRouter, File, UploadFile
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ..db import StorageEngines
|
||||
from ..analytics import AnalyticsDB
|
||||
from ..config import settings
|
||||
from ..db import StorageEngines
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
@ -1,2 +1,49 @@
|
|||
class UserManagement:
|
||||
pass
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import duckdb
|
||||
import polars as pl
|
||||
|
||||
from .db import DDB, StorageEngines
|
||||
|
||||
|
||||
class UserDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = "gcs://{bucket}/users"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "users"
|
||||
|
||||
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
||||
df = pl.from_dict(user_data) # noqa
|
||||
record_id = uuid4() if record_id is None else record_id
|
||||
query = f"COPY df TO '{self.path_prefix}/{record_id}.ndjson' (FORMAT JSON)"
|
||||
|
||||
try:
|
||||
self.db.sql(query)
|
||||
except duckdb.duckdb.IOException:
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.sql(query)
|
||||
|
||||
return user_data
|
||||
|
||||
def user_exists(self, email: str) -> bool:
|
||||
query = f"SELECT * FROM '{self.path_prefix}/*.ndjson' WHERE email = '{email}'"
|
||||
return self.db.sql(query).pl().shape[0] > 0
|
||||
|
||||
@staticmethod
|
||||
def email(record: str) -> str:
|
||||
return json.loads(record)["email"]
|
||||
|
|
2
test/output/.gitignore
vendored
2
test/output/.gitignore
vendored
|
@ -1 +1,3 @@
|
|||
*.csv
|
||||
*.json
|
||||
*.ndjson
|
|
@ -1,8 +1,8 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.db import StorageEngines
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_XLS_PATH = (
|
||||
|
@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
|||
|
||||
|
||||
def test_0_dump():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||
db.load_xls(TEST_XLS_PATH).dump()
|
||||
|
||||
assert db.sheets == ("answers",)
|
||||
|
@ -23,19 +23,25 @@ def test_0_dump():
|
|||
|
||||
|
||||
def test_load():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||
db = AnalyticsDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
results = db.query("select * from answers").fetchall()
|
||||
assert len(results) == 6
|
||||
|
||||
|
||||
def test_load_description():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||
db = AnalyticsDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
file_description = db.load_description()
|
||||
assert file_description.startswith("answers")
|
||||
|
||||
|
||||
def test_schema():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||
db = AnalyticsDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
schema = []
|
||||
for sheet in db.sheets:
|
||||
schema.append(db.table_schema(sheet))
|
||||
|
@ -44,7 +50,9 @@ def test_schema():
|
|||
|
||||
|
||||
def test_query_prompt():
|
||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||
db = AnalyticsDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
|
||||
assert db.query_prompt("Find the average score of all students").startswith(
|
||||
"The following sentence"
|
||||
|
|
|
@ -2,9 +2,9 @@ from pathlib import Path
|
|||
|
||||
import hellocomputer
|
||||
import pytest
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.analytics import AnalyticsDB
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.models import Chat
|
||||
|
||||
|
|
25
test/test_user.py
Normal file
25
test/test_user.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.users import UserDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
|
||||
|
||||
def test_create_user():
|
||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||
user_data = user.dump_user_record(user_data, record_id="test")
|
||||
|
||||
assert user_data["name"] == "John Doe"
|
||||
|
||||
|
||||
def test_user_exists():
|
||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||
user.dump_user_record(user_data, record_id="test")
|
||||
|
||||
assert user.user_exists("[email protected]")
|
||||
assert not user.user_exists("notpresent")
|
Loading…
Reference in a new issue