This commit is contained in:
parent
56dc012e23
commit
6d6ec72336
|
@ -6,6 +6,7 @@ pydantic-settings
|
||||||
s3fs
|
s3fs
|
||||||
aiofiles
|
aiofiles
|
||||||
duckdb
|
duckdb
|
||||||
|
polars
|
||||||
pyjwt[crypto]
|
pyjwt[crypto]
|
||||||
python-multipart
|
python-multipart
|
||||||
authlib
|
authlib
|
||||||
|
|
223
requirements.txt
223
requirements.txt
|
@ -1,108 +1,225 @@
|
||||||
|
# This file was autogenerated by uv via the following command:
|
||||||
|
# uv pip compile requirements.in
|
||||||
aiobotocore==2.13.0
|
aiobotocore==2.13.0
|
||||||
|
# via s3fs
|
||||||
aiofiles==23.2.1
|
aiofiles==23.2.1
|
||||||
aiohttp==3.9.5
|
aiohttp==3.9.5
|
||||||
|
# via
|
||||||
|
# aiobotocore
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
# s3fs
|
||||||
aioitertools==0.11.0
|
aioitertools==0.11.0
|
||||||
|
# via aiobotocore
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
annotated-types==0.6.0
|
# via aiohttp
|
||||||
anyio==4.3.0
|
annotated-types==0.7.0
|
||||||
|
# via pydantic
|
||||||
|
anyio==4.4.0
|
||||||
|
# via
|
||||||
|
# httpx
|
||||||
|
# openai
|
||||||
|
# starlette
|
||||||
|
# watchfiles
|
||||||
attrs==23.2.0
|
attrs==23.2.0
|
||||||
|
# via aiohttp
|
||||||
authlib==1.3.1
|
authlib==1.3.1
|
||||||
babel==2.15.0
|
|
||||||
bootstrap4==0.1.0
|
|
||||||
botocore==1.34.106
|
botocore==1.34.106
|
||||||
certifi==2024.2.2
|
# via aiobotocore
|
||||||
|
certifi==2024.6.2
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
cffi==1.16.0
|
cffi==1.16.0
|
||||||
|
# via cryptography
|
||||||
charset-normalizer==3.3.2
|
charset-normalizer==3.3.2
|
||||||
|
# via requests
|
||||||
click==8.1.7
|
click==8.1.7
|
||||||
colorama==0.4.6
|
# via
|
||||||
cryptography==42.0.7
|
# typer
|
||||||
dataclasses-json==0.6.6
|
# uvicorn
|
||||||
|
cryptography==42.0.8
|
||||||
|
# via
|
||||||
|
# authlib
|
||||||
|
# pyjwt
|
||||||
|
dataclasses-json==0.6.7
|
||||||
|
# via langchain-community
|
||||||
distro==1.9.0
|
distro==1.9.0
|
||||||
|
# via openai
|
||||||
dnspython==2.6.1
|
dnspython==2.6.1
|
||||||
docutils==0.21.2
|
# via email-validator
|
||||||
duckdb==0.10.2
|
duckdb==1.0.0
|
||||||
email-validator==2.1.1
|
email-validator==2.1.1
|
||||||
|
# via fastapi
|
||||||
fastapi==0.111.0
|
fastapi==0.111.0
|
||||||
fastapi-cli==0.0.3
|
fastapi-cli==0.0.4
|
||||||
flit==3.9.0
|
# via fastapi
|
||||||
flit-core==3.9.0
|
|
||||||
frozenlist==1.4.1
|
frozenlist==1.4.1
|
||||||
fsspec==2024.5.0
|
# via
|
||||||
ghp-import==2.1.0
|
# aiohttp
|
||||||
|
# aiosignal
|
||||||
|
fsspec==2024.6.0
|
||||||
|
# via s3fs
|
||||||
h11==0.14.0
|
h11==0.14.0
|
||||||
|
# via
|
||||||
|
# httpcore
|
||||||
|
# uvicorn
|
||||||
httpcore==1.0.5
|
httpcore==1.0.5
|
||||||
|
# via httpx
|
||||||
httptools==0.6.1
|
httptools==0.6.1
|
||||||
|
# via uvicorn
|
||||||
httpx==0.27.0
|
httpx==0.27.0
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# openai
|
||||||
idna==3.7
|
idna==3.7
|
||||||
iniconfig==2.0.0
|
# via
|
||||||
|
# anyio
|
||||||
|
# email-validator
|
||||||
|
# httpx
|
||||||
|
# requests
|
||||||
|
# yarl
|
||||||
itsdangerous==2.2.0
|
itsdangerous==2.2.0
|
||||||
jinja2==3.1.4
|
jinja2==3.1.4
|
||||||
|
# via fastapi
|
||||||
jmespath==1.0.1
|
jmespath==1.0.1
|
||||||
|
# via botocore
|
||||||
jsonpatch==1.33
|
jsonpatch==1.33
|
||||||
jsonpointer==2.4
|
# via langchain-core
|
||||||
langchain==0.2.0
|
jsonpointer==3.0.0
|
||||||
langchain-community==0.2.0
|
# via jsonpatch
|
||||||
langchain-core==0.2.0
|
langchain==0.2.3
|
||||||
langchain-text-splitters==0.2.0
|
# via langchain-community
|
||||||
langsmith==0.1.59
|
langchain-community==0.2.4
|
||||||
markdown==3.6
|
langchain-core==0.2.5
|
||||||
|
# via
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
# langchain-text-splitters
|
||||||
|
langchain-text-splitters==0.2.1
|
||||||
|
# via langchain
|
||||||
|
langsmith==0.1.76
|
||||||
|
# via
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
# langchain-core
|
||||||
markdown-it-py==3.0.0
|
markdown-it-py==3.0.0
|
||||||
|
# via rich
|
||||||
markupsafe==2.1.5
|
markupsafe==2.1.5
|
||||||
marshmallow==3.21.2
|
# via jinja2
|
||||||
|
marshmallow==3.21.3
|
||||||
|
# via dataclasses-json
|
||||||
mdurl==0.1.2
|
mdurl==0.1.2
|
||||||
mergedeep==1.3.4
|
# via markdown-it-py
|
||||||
mkdocs==1.6.0
|
|
||||||
mkdocs-get-deps==0.2.0
|
|
||||||
mkdocs-material==9.5.23
|
|
||||||
mkdocs-material-extensions==1.3.1
|
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
|
# via
|
||||||
|
# aiohttp
|
||||||
|
# yarl
|
||||||
mypy-extensions==1.0.0
|
mypy-extensions==1.0.0
|
||||||
|
# via typing-inspect
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
openai==1.30.1
|
# via
|
||||||
orjson==3.10.3
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
openai==1.33.0
|
||||||
|
orjson==3.10.4
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# langsmith
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
paginate==0.5.6
|
# via
|
||||||
pathspec==0.12.1
|
# langchain-core
|
||||||
platformdirs==4.2.2
|
# marshmallow
|
||||||
pluggy==1.5.0
|
polars==0.20.31
|
||||||
polars==0.20.26
|
|
||||||
pyarrow==16.1.0
|
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
pydantic==2.7.1
|
# via cffi
|
||||||
pydantic-core==2.18.2
|
pydantic==2.7.3
|
||||||
pydantic-settings==2.2.1
|
# via
|
||||||
|
# fastapi
|
||||||
|
# langchain
|
||||||
|
# langchain-core
|
||||||
|
# langsmith
|
||||||
|
# openai
|
||||||
|
# pydantic-settings
|
||||||
|
pydantic-core==2.18.4
|
||||||
|
# via pydantic
|
||||||
|
pydantic-settings==2.3.2
|
||||||
pygments==2.18.0
|
pygments==2.18.0
|
||||||
|
# via rich
|
||||||
pyjwt==2.8.0
|
pyjwt==2.8.0
|
||||||
pymdown-extensions==10.8.1
|
|
||||||
pytest==8.2.1
|
|
||||||
pytest-asyncio==0.23.7
|
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
|
# via botocore
|
||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
|
# via
|
||||||
|
# pydantic-settings
|
||||||
|
# uvicorn
|
||||||
python-multipart==0.0.9
|
python-multipart==0.0.9
|
||||||
|
# via fastapi
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
pyyaml-env-tag==0.1
|
# via
|
||||||
regex==2024.5.15
|
# langchain
|
||||||
requests==2.31.0
|
# langchain-community
|
||||||
|
# langchain-core
|
||||||
|
# uvicorn
|
||||||
|
requests==2.32.3
|
||||||
|
# via
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
# langsmith
|
||||||
rich==13.7.1
|
rich==13.7.1
|
||||||
ruff==0.4.4
|
# via typer
|
||||||
s3fs==2024.5.0
|
s3fs==2024.6.0
|
||||||
shellingham==1.5.4
|
shellingham==1.5.4
|
||||||
|
# via typer
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
# via python-dateutil
|
||||||
sniffio==1.3.1
|
sniffio==1.3.1
|
||||||
|
# via
|
||||||
|
# anyio
|
||||||
|
# httpx
|
||||||
|
# openai
|
||||||
sqlalchemy==2.0.30
|
sqlalchemy==2.0.30
|
||||||
|
# via
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
starlette==0.37.2
|
starlette==0.37.2
|
||||||
|
# via fastapi
|
||||||
tenacity==8.3.0
|
tenacity==8.3.0
|
||||||
tomli-w==1.0.0
|
# via
|
||||||
|
# langchain
|
||||||
|
# langchain-community
|
||||||
|
# langchain-core
|
||||||
tqdm==4.66.4
|
tqdm==4.66.4
|
||||||
|
# via openai
|
||||||
typer==0.12.3
|
typer==0.12.3
|
||||||
typing-extensions==4.11.0
|
# via fastapi-cli
|
||||||
|
typing-extensions==4.12.2
|
||||||
|
# via
|
||||||
|
# fastapi
|
||||||
|
# openai
|
||||||
|
# pydantic
|
||||||
|
# pydantic-core
|
||||||
|
# sqlalchemy
|
||||||
|
# typer
|
||||||
|
# typing-inspect
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
|
# via dataclasses-json
|
||||||
ujson==5.10.0
|
ujson==5.10.0
|
||||||
|
# via fastapi
|
||||||
urllib3==2.2.1
|
urllib3==2.2.1
|
||||||
uvicorn==0.29.0
|
# via
|
||||||
|
# botocore
|
||||||
|
# requests
|
||||||
|
uvicorn==0.30.1
|
||||||
|
# via fastapi
|
||||||
uvloop==0.19.0
|
uvloop==0.19.0
|
||||||
watchdog==4.0.0
|
# via uvicorn
|
||||||
watchfiles==0.21.0
|
watchfiles==0.22.0
|
||||||
|
# via uvicorn
|
||||||
websockets==12.0
|
websockets==12.0
|
||||||
|
# via uvicorn
|
||||||
wrapt==1.16.0
|
wrapt==1.16.0
|
||||||
|
# via aiobotocore
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
# via aiohttp
|
||||||
|
|
|
@ -1,11 +1,33 @@
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import duckdb
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
|
|
||||||
from .db import DDB
|
from .db import DDB
|
||||||
|
|
||||||
|
|
||||||
class AnalyticsDB(DDB):
|
class AnalyticsDB(DDB):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
sid: str | None = None,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||||
|
self.sid = sid
|
||||||
|
# Override storage engine for sessions
|
||||||
|
if storage_engine == StorageEngines.gcs:
|
||||||
|
self.path_prefix = "gcs://{bucket}/sessions/{sid}"
|
||||||
|
elif storage_engine == StorageEngines.local:
|
||||||
|
self.path_prefix = path / "sessions" / sid
|
||||||
|
|
||||||
def load_xls(self, xls_path: Path) -> Self:
|
def load_xls(self, xls_path: Path) -> Self:
|
||||||
"""For some reason, the header is not loaded"""
|
"""For some reason, the header is not loaded"""
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
|
@ -47,7 +69,12 @@ class AnalyticsDB(DDB):
|
||||||
if not self.loaded:
|
if not self.loaded:
|
||||||
raise ValueError("Data should be loaded first")
|
raise ValueError("Data should be loaded first")
|
||||||
|
|
||||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
try:
|
||||||
|
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||||
|
except duckdb.duckdb.IOException:
|
||||||
|
# Create the folder
|
||||||
|
os.makedirs(self.path_prefix)
|
||||||
|
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||||
|
|
||||||
for sheet in self.sheets:
|
for sheet in self.sheets:
|
||||||
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import duckdb
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import duckdb
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
class StorageEngines(StrEnum):
|
||||||
local = "Local"
|
local = "Local"
|
||||||
|
@ -12,7 +13,6 @@ class DDB:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage_engine: StorageEngines,
|
storage_engine: StorageEngines,
|
||||||
sid: str | None = None,
|
|
||||||
path: Path | None = None,
|
path: Path | None = None,
|
||||||
gcs_access: str | None = None,
|
gcs_access: str | None = None,
|
||||||
gcs_secret: str | None = None,
|
gcs_secret: str | None = None,
|
||||||
|
@ -33,7 +33,6 @@ class DDB:
|
||||||
gcs_access is not None,
|
gcs_access is not None,
|
||||||
gcs_secret is not None,
|
gcs_secret is not None,
|
||||||
bucket is not None,
|
bucket is not None,
|
||||||
sid is not None,
|
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
|
@ -42,11 +41,11 @@ class DDB:
|
||||||
KEY_ID '{gcs_access}',
|
KEY_ID '{gcs_access}',
|
||||||
SECRET '{gcs_secret}')
|
SECRET '{gcs_secret}')
|
||||||
""")
|
""")
|
||||||
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
self.path_prefix = f"gcs://{bucket}"
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"With GCS storage engine you need to provide "
|
"With GCS storage engine you need to provide "
|
||||||
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
"the gcs_access, gcs_secret, and bucket keyword arguments"
|
||||||
)
|
)
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
elif storage_engine == StorageEngines.local:
|
||||||
|
|
|
@ -10,7 +10,7 @@ from starlette.requests import Request
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
from .routers import analysis, auth, files, sessions, health
|
from .routers import analysis, auth, files, health, sessions
|
||||||
|
|
||||||
static_path = Path(hellocomputer.__file__).parent / "static"
|
static_path = Path(hellocomputer.__file__).parent / "static"
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from hellocomputer.db import StorageEngines
|
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
|
|
@ -2,8 +2,9 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
from hellocomputer.config import settings
|
||||||
from ..config import settings
|
from hellocomputer.users import UserDB
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
@ -33,7 +34,15 @@ async def callback(request: Request):
|
||||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||||
user = token.get("userinfo")
|
user = token.get("userinfo")
|
||||||
if user:
|
if user:
|
||||||
request.session["user"] = dict(user)
|
user_info = dict(user)
|
||||||
|
request.session["user"] = user_info
|
||||||
|
user_db = UserDB(
|
||||||
|
StorageEngines.gcs,
|
||||||
|
gcs_access=settings.gcs_access,
|
||||||
|
gcs_secret=settings.gcs_secret,
|
||||||
|
bucket=settings.gcs_bucketname,
|
||||||
|
)
|
||||||
|
user_db.dump_user_record(user_info)
|
||||||
|
|
||||||
return RedirectResponse(url="/app")
|
return RedirectResponse(url="/app")
|
||||||
|
|
||||||
|
|
|
@ -4,9 +4,9 @@ import aiofiles
|
||||||
from fastapi import APIRouter, File, UploadFile
|
from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from ..db import StorageEngines
|
|
||||||
from ..analytics import AnalyticsDB
|
from ..analytics import AnalyticsDB
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
|
from ..db import StorageEngines
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
|
@ -1,2 +1,49 @@
|
||||||
class UserManagement:
|
import json
|
||||||
pass
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import duckdb
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from .db import DDB, StorageEngines
|
||||||
|
|
||||||
|
|
||||||
|
class UserDB(DDB):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||||
|
|
||||||
|
if storage_engine == StorageEngines.gcs:
|
||||||
|
self.path_prefix = "gcs://{bucket}/users"
|
||||||
|
|
||||||
|
elif storage_engine == StorageEngines.local:
|
||||||
|
self.path_prefix = path / "users"
|
||||||
|
|
||||||
|
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
||||||
|
df = pl.from_dict(user_data) # noqa
|
||||||
|
record_id = uuid4() if record_id is None else record_id
|
||||||
|
query = f"COPY df TO '{self.path_prefix}/{record_id}.ndjson' (FORMAT JSON)"
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.db.sql(query)
|
||||||
|
except duckdb.duckdb.IOException:
|
||||||
|
os.makedirs(self.path_prefix)
|
||||||
|
self.db.sql(query)
|
||||||
|
|
||||||
|
return user_data
|
||||||
|
|
||||||
|
def user_exists(self, email: str) -> bool:
|
||||||
|
query = f"SELECT * FROM '{self.path_prefix}/*.ndjson' WHERE email = '{email}'"
|
||||||
|
return self.db.sql(query).pl().shape[0] > 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def email(record: str) -> str:
|
||||||
|
return json.loads(record)["email"]
|
||||||
|
|
4
test/output/.gitignore
vendored
4
test/output/.gitignore
vendored
|
@ -1 +1,3 @@
|
||||||
*.csv
|
*.csv
|
||||||
|
*.json
|
||||||
|
*.ndjson
|
|
@ -1,8 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.db import StorageEngines
|
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
TEST_STORAGE = StorageEngines.local
|
||||||
TEST_XLS_PATH = (
|
TEST_XLS_PATH = (
|
||||||
|
@ -15,7 +15,7 @@ TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
|
||||||
|
|
||||||
def test_0_dump():
|
def test_0_dump():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||||
db.load_xls(TEST_XLS_PATH).dump()
|
db.load_xls(TEST_XLS_PATH).dump()
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
assert db.sheets == ("answers",)
|
||||||
|
@ -23,19 +23,25 @@ def test_0_dump():
|
||||||
|
|
||||||
|
|
||||||
def test_load():
|
def test_load():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(
|
||||||
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
|
).load_folder()
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
assert len(results) == 6
|
assert len(results) == 6
|
||||||
|
|
||||||
|
|
||||||
def test_load_description():
|
def test_load_description():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(
|
||||||
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
|
).load_folder()
|
||||||
file_description = db.load_description()
|
file_description = db.load_description()
|
||||||
assert file_description.startswith("answers")
|
assert file_description.startswith("answers")
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(
|
||||||
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
|
).load_folder()
|
||||||
schema = []
|
schema = []
|
||||||
for sheet in db.sheets:
|
for sheet in db.sheets:
|
||||||
schema.append(db.table_schema(sheet))
|
schema.append(db.table_schema(sheet))
|
||||||
|
@ -44,7 +50,9 @@ def test_schema():
|
||||||
|
|
||||||
|
|
||||||
def test_query_prompt():
|
def test_query_prompt():
|
||||||
db = AnalyticsDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
db = AnalyticsDB(
|
||||||
|
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
|
).load_folder()
|
||||||
|
|
||||||
assert db.query_prompt("Find the average score of all students").startswith(
|
assert db.query_prompt("Find the average score of all students").startswith(
|
||||||
"The following sentence"
|
"The following sentence"
|
||||||
|
|
|
@ -2,9 +2,9 @@ from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.db import StorageEngines
|
|
||||||
from hellocomputer.analytics import AnalyticsDB
|
from hellocomputer.analytics import AnalyticsDB
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import Chat
|
||||||
|
|
||||||
|
|
25
test/test_user.py
Normal file
25
test/test_user.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import hellocomputer
|
||||||
|
from hellocomputer.db import StorageEngines
|
||||||
|
from hellocomputer.users import UserDB
|
||||||
|
|
||||||
|
TEST_STORAGE = StorageEngines.local
|
||||||
|
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user():
|
||||||
|
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||||
|
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||||
|
user_data = user.dump_user_record(user_data, record_id="test")
|
||||||
|
|
||||||
|
assert user_data["name"] == "John Doe"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_exists():
|
||||||
|
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||||
|
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||||
|
user.dump_user_record(user_data, record_id="test")
|
||||||
|
|
||||||
|
assert user.user_exists("[email protected]")
|
||||||
|
assert not user.user_exists("notpresent")
|
Loading…
Reference in a new issue