This commit is contained in:
parent
910e91a391
commit
f4b9c30a17
|
@ -1,4 +1,5 @@
|
|||
langchain
|
||||
langgraph
|
||||
langchain-community
|
||||
langchain-fireworks
|
||||
langchain-openai
|
||||
|
@ -11,6 +12,7 @@ duckdb
|
|||
duckdb-engine
|
||||
polars
|
||||
pyarrow
|
||||
pydantic
|
||||
pyjwt[crypto]
|
||||
python-multipart
|
||||
authlib
|
||||
|
|
|
@ -1,20 +1,67 @@
|
|||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from pydantic import model_validator
|
||||
from pathlib import Path
|
||||
from typing import Self, Optional
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
local = "local"
|
||||
gcs = "GCS"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
storage_engine: StorageEngines = "local"
|
||||
base_url: str = "http://localhost:8000"
|
||||
llm_api_key: str = "Awesome API"
|
||||
llm_base_url: str = "Awessome Endpoint"
|
||||
gcs_access: str = "access"
|
||||
gcs_secret: str = "secret"
|
||||
gcs_bucketname: str = "bucket"
|
||||
llm_base_url: Optional[str] = None
|
||||
gcs_access: Optional[str] = None
|
||||
gcs_secret: Optional[str] = None
|
||||
gcs_bucketname: Optional[str] = None
|
||||
path: Optional[Path] = None
|
||||
auth: bool = True
|
||||
auth0_client_id: str = ""
|
||||
auth0_client_secret: str = ""
|
||||
auth0_domain: str = ""
|
||||
app_secret_key: str = ""
|
||||
auth0_client_id: Optional[str] = None
|
||||
auth0_client_secret: Optional[str] = None
|
||||
auth0_domain: Optional[str] = None
|
||||
app_secret_key: Optional[str] = None
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_cloud_storage(self) -> Self:
|
||||
if self.storage_engine == StorageEngines.gcs:
|
||||
if any(
|
||||
(
|
||||
self.gcs_access is None,
|
||||
self.gcs_bucketname is None,
|
||||
self.gcs_secret is None,
|
||||
)
|
||||
):
|
||||
raise ValueError("Cloud storage configuration not provided")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_auth_config(self) -> Self:
|
||||
if not self.auth:
|
||||
if any(
|
||||
(
|
||||
self.auth0_client_id is None,
|
||||
self.auth0_client_secret is None,
|
||||
self.auth0_domain is None,
|
||||
self.app_secret_key is None,
|
||||
)
|
||||
):
|
||||
raise ValueError("Auth is enabled but no auth config is providedc")
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_local_storage(self) -> Self:
|
||||
if self.storage_engine == StorageEngines.local:
|
||||
if self.path is None:
|
||||
raise ValueError("Local storage requires a path")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
|
|
@ -1,67 +1,28 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
import duckdb
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
local = "Local"
|
||||
gcs = "GCS"
|
||||
|
||||
|
||||
class DDB:
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
settings: Settings,
|
||||
):
|
||||
self.db = duckdb.connect(":memory:")
|
||||
# Assume extension autoloading
|
||||
# self.db.sql("load httpfs")
|
||||
# self.db.sql("load spatial")
|
||||
self.storage_engine = storage_engine
|
||||
self.sheets = tuple()
|
||||
self.loaded = False
|
||||
self.storage_engine = settings.storage_engine
|
||||
self.engine = create_engine("duckdb:///:memory:")
|
||||
self.db = self.engine.raw_connection()
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
if all(
|
||||
(
|
||||
gcs_access is not None,
|
||||
gcs_secret is not None,
|
||||
bucket is not None,
|
||||
)
|
||||
):
|
||||
self.db.sql(f"""
|
||||
if self.storage_engine == StorageEngines.gcs:
|
||||
self.db.sql(f"""
|
||||
CREATE SECRET (
|
||||
TYPE GCS,
|
||||
KEY_ID '{gcs_access}',
|
||||
SECRET '{gcs_secret}')
|
||||
KEY_ID '{settings.gcs_access}',
|
||||
SECRET '{settings.gcs_secret}')
|
||||
""")
|
||||
|
||||
self.path_prefix = f"gs://{bucket}"
|
||||
else:
|
||||
raise ValueError(
|
||||
"With GCS storage engine you need to provide "
|
||||
"the gcs_access, gcs_secret, and bucket keyword arguments"
|
||||
)
|
||||
self.path_prefix = f"gs://{settings.gcs_bucketname}"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
if path is not None:
|
||||
self.path_prefix = path
|
||||
else:
|
||||
raise ValueError(
|
||||
"With local storage you need to provide the path keyword argument"
|
||||
)
|
||||
|
||||
elif settings.storage_engine == StorageEngines.local:
|
||||
self.path_prefix = settings.path
|
||||
|
||||
def query(self, sql, *args, **kwargs):
|
||||
return self.db.query(sql, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
return create_engine("duckdb:///:memory:")
|
||||
|
|
|
@ -3,30 +3,22 @@ from pathlib import Path
|
|||
|
||||
import duckdb
|
||||
from typing_extensions import Self
|
||||
from langchain_community.utilities.sql_database import SQLDatabase
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
|
||||
from . import DDB
|
||||
|
||||
|
||||
class SessionDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
sid: str | None = None,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
def __init__(self, settings: Settings, sid: str):
|
||||
super().__init__(settings=settings)
|
||||
self.sid = sid
|
||||
# Override storage engine for sessions
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{bucket}/sessions/{sid}"
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "sessions" / sid
|
||||
if settings.storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{settings.gcs_bucketname}/sessions/{sid}"
|
||||
elif settings.storage_engine == StorageEngines.local:
|
||||
self.path_prefix = settings.path / "sessions" / sid
|
||||
|
||||
def load_xls(self, xls_path: Path) -> Self:
|
||||
"""For some reason, the header is not loaded"""
|
||||
|
@ -175,3 +167,7 @@ class SessionDB(DDB):
|
|||
"Return just the SQL statement",
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def llmsql(self):
|
||||
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
||||
|
|
|
@ -1,35 +1,30 @@
|
|||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import duckdb
|
||||
import polars as pl
|
||||
|
||||
from . import DDB, StorageEngines
|
||||
from hellocomputer.db import DDB
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
|
||||
|
||||
class UserDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
settings: Settings,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
super().__init__(settings)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{bucket}/users"
|
||||
if settings.storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{settings.gcs_bucketname}/users"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "users"
|
||||
elif settings.storage_engine == StorageEngines.local:
|
||||
self.path_prefix = settings.path / "users"
|
||||
|
||||
self.storage_engine = storage_engine
|
||||
self.storage_engine = settings.storage_engine
|
||||
|
||||
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
||||
df = pl.from_dict(user_data) # noqa
|
||||
|
@ -59,20 +54,15 @@ class UserDB(DDB):
|
|||
class OwnershipDB(DDB):
|
||||
def __init__(
|
||||
self,
|
||||
storage_engine: StorageEngines,
|
||||
path: Path | None = None,
|
||||
gcs_access: str | None = None,
|
||||
gcs_secret: str | None = None,
|
||||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
settings: Settings,
|
||||
):
|
||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
||||
super().__init__(settings)
|
||||
|
||||
if storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{bucket}/owners"
|
||||
if settings.storage_engine == StorageEngines.gcs:
|
||||
self.path_prefix = f"gs://{settings.gcs_bucketname}/owners"
|
||||
|
||||
elif storage_engine == StorageEngines.local:
|
||||
self.path_prefix = path / "owners"
|
||||
elif settings.storage_engine == StorageEngines.local:
|
||||
self.path_prefix = settings.path / "owners"
|
||||
|
||||
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
||||
now = datetime.now().isoformat()
|
||||
|
|
|
@ -3,7 +3,6 @@ from pathlib import Path
|
|||
|
||||
from anyio import open_file
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
import hellocomputer
|
||||
|
||||
|
@ -26,57 +25,8 @@ class Prompts:
|
|||
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return await cls.getter("general_prompt")
|
||||
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||
|
||||
@classmethod
|
||||
async def sql(cls):
|
||||
return await cls.getter("sql_prompt")
|
||||
|
||||
|
||||
class Chat:
|
||||
@staticmethod
|
||||
def raise_no_key(api_key):
|
||||
if api_key:
|
||||
return api_key
|
||||
elif api_key is None:
|
||||
raise ValueError(
|
||||
"You need to provide a valid API in the api_key init argument"
|
||||
)
|
||||
else:
|
||||
raise ValueError("You need to provide a valid API key")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: AvailableModels = AvailableModels.mixtral_8x7b,
|
||||
api_key: str = "",
|
||||
temperature: float = 0.5,
|
||||
):
|
||||
self.model = model
|
||||
self.api_key = self.raise_no_key(api_key)
|
||||
self.messages = []
|
||||
self.responses = []
|
||||
|
||||
self.model: Fireworks = Fireworks(
|
||||
model=model, temperature=temperature, api_key=self.api_key
|
||||
)
|
||||
|
||||
async def eval(self, task):
|
||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||
|
||||
response = await self.model.ainvoke(prompt.format(query=task))
|
||||
self.responses.append(response)
|
||||
return self
|
||||
|
||||
async def sql_eval(self, question):
|
||||
prompt = PromptTemplate.from_template(await Prompts.sql())
|
||||
|
||||
response = await self.model.ainvoke(prompt.format(query=question))
|
||||
self.responses.append(response)
|
||||
return self
|
||||
|
||||
def last_response_content(self):
|
||||
last_response = self.responses[-1]
|
||||
return last_response
|
||||
|
||||
def last_response_metadata(self):
|
||||
return self.responses[-1].response_metadata
|
||||
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from fastapi import APIRouter
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
|
|||
from starlette.requests import Request
|
||||
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.db.users import UserDB
|
||||
|
||||
router = APIRouter()
|
||||
|
|
|
@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse
|
|||
from starlette.requests import Request
|
||||
|
||||
from ..config import settings
|
||||
from ..db import StorageEngines
|
||||
from ..config import StorageEngines
|
||||
from ..db.sessions import SessionDB
|
||||
from ..db.users import OwnershipDB
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from fastapi import APIRouter
|
|||
from fastapi.responses import PlainTextResponse
|
||||
from starlette.requests import Request
|
||||
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import StorageEngines
|
||||
from hellocomputer.db.users import OwnershipDB
|
||||
|
||||
from ..auth import get_user_email
|
||||
|
|
26
src/hellocomputer/tools.py
Normal file
26
src/hellocomputer/tools.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import Type
|
||||
from langchain.tools import BaseTool
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.config import settings
|
||||
|
||||
|
||||
class DuckdbQueryInput(BaseModel):
|
||||
query: str = Field(description="Question to be translated to a SQL statement")
|
||||
session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||
|
||||
|
||||
class DuckdbQueryTool(BaseTool):
|
||||
name: str = "Calculator"
|
||||
description: str = "Tool to evaluate mathemetical expressions"
|
||||
args_schema: Type[BaseModel] = DuckdbQueryInput
|
||||
|
||||
def _run(self, query: str, session_id: str) -> str:
|
||||
"""Run the query"""
|
||||
db = SessionDB(settings, session_id)
|
||||
return "Table"
|
||||
|
||||
async def _arun(self, query: str, session_id: str) -> str:
|
||||
"""Use the tool asynchronously."""
|
||||
db = SessionDB(settings, session_id)
|
||||
return "Table"
|
|
@ -1,47 +1,44 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import StorageEngines, Settings
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
settings = Settings(
|
||||
storage_engine=StorageEngines.local,
|
||||
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||
)
|
||||
|
||||
TEST_XLS_PATH = (
|
||||
Path(hellocomputer.__file__).parents[2]
|
||||
/ "test"
|
||||
/ "data"
|
||||
/ "TestExcelHelloComputer.xlsx"
|
||||
)
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
|
||||
|
||||
def test_0_dump():
|
||||
db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
||||
db = SessionDB(settings, sid="test")
|
||||
db.load_xls(TEST_XLS_PATH).dump()
|
||||
|
||||
assert db.sheets == ("answers",)
|
||||
assert (TEST_OUTPUT_FOLDER / "sessions" / "test" / "answers.csv").exists()
|
||||
assert (settings.path / "sessions" / "test" / "answers.csv").exists()
|
||||
|
||||
|
||||
def test_load():
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
db = SessionDB(settings, sid="test").load_folder()
|
||||
results = db.query("select * from answers").fetchall()
|
||||
assert len(results) == 6
|
||||
|
||||
|
||||
def test_load_description():
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
db = SessionDB(settings, sid="test").load_folder()
|
||||
file_description = db.load_description()
|
||||
assert file_description.startswith("answers")
|
||||
|
||||
|
||||
def test_schema():
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
db = SessionDB(settings, sid="test").load_folder()
|
||||
schema = []
|
||||
for sheet in db.sheets:
|
||||
schema.append(db.table_schema(sheet))
|
||||
|
@ -50,9 +47,7 @@ def test_schema():
|
|||
|
||||
|
||||
def test_query_prompt():
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
db = SessionDB(settings, sid="test").load_folder()
|
||||
|
||||
assert db.query_prompt("Find the average score of all students").startswith(
|
||||
"The following sentence"
|
||||
|
|
|
@ -5,11 +5,5 @@ from langchain.prompts import PromptTemplate
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_general_prompt():
|
||||
general: str = await Prompts.general()
|
||||
assert general.startswith("You're a helpful assistant")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_general_templated():
|
||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||
assert "Do as I say" in prompt.format(query="Do as I say")
|
||||
general: PromptTemplate = await Prompts.general()
|
||||
assert general.format(query="whatever").startswith("You're a helpful assistant")
|
||||
|
|
|
@ -3,65 +3,89 @@ from pathlib import Path
|
|||
import hellocomputer
|
||||
import polars as pl
|
||||
import pytest
|
||||
from hellocomputer.config import settings
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import Settings, StorageEngines
|
||||
from hellocomputer.db.sessions import SessionDB
|
||||
from hellocomputer.extraction import extract_code_block
|
||||
from hellocomputer.models import Chat
|
||||
from hellocomputer.models import AvailableModels
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
settings = Settings(
|
||||
storage_engine=StorageEngines.local,
|
||||
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||
)
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
TEST_XLS_PATH = (
|
||||
Path(hellocomputer.__file__).parents[2]
|
||||
/ "test"
|
||||
/ "data"
|
||||
/ "TestExcelHelloComputer.xlsx"
|
||||
)
|
||||
|
||||
SID = "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_chat_simple():
|
||||
chat = Chat(api_key=settings.llm_api_key, temperature=0)
|
||||
chat = await chat.eval("Say literlly 'Hello'")
|
||||
assert "Hello" in chat.last_response_content()
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.mixtral_8x7b,
|
||||
temperature=0.5,
|
||||
)
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
"""Say literally {word}, a single word. Don't be verbose,
|
||||
I'll be disappointed if you say more than a single word"""
|
||||
)
|
||||
chain = prompt | llm
|
||||
response = await chain.ainvoke({"word": "Hello"})
|
||||
|
||||
assert response.content.lower().startswith("hello")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_simple_data_query():
|
||||
query = "write a query that finds the average score of all students in the current database"
|
||||
async def test_query_context():
|
||||
db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql
|
||||
|
||||
chat = Chat(
|
||||
llm = ChatOpenAI(
|
||||
base_url=settings.llm_base_url,
|
||||
api_key=settings.llm_api_key,
|
||||
model=AvailableModels.mixtral_8x7b,
|
||||
temperature=0.5,
|
||||
)
|
||||
db = SessionDB(
|
||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
||||
).load_xls(TEST_XLS_PATH)
|
||||
|
||||
chat = await chat.sql_eval(db.query_prompt(query))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
assert query.startswith("SELECT")
|
||||
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
||||
context = toolkit.get_context()
|
||||
assert "table_info" in context
|
||||
assert "table_names" in context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
async def test_data_query():
|
||||
q = "Find the average score of all the sudents"
|
||||
|
||||
llm = Chat(
|
||||
api_key=settings.llm_api_key,
|
||||
temperature=0.5,
|
||||
)
|
||||
db = SessionDB(
|
||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
).load_folder()
|
||||
|
||||
chat = await llm.sql_eval(db.query_prompt(q))
|
||||
query = extract_code_block(chat.last_response_content())
|
||||
result: pl.DataFrame = db.query(query).pl()
|
||||
|
||||
assert result.shape[0] == 1
|
||||
assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
||||
#
|
||||
# chat = await chat.sql_eval(db.query_prompt(query))
|
||||
# query = extract_code_block(chat.last_response_content())
|
||||
# assert query.startswith("SELECT")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||
# async def test_data_query():
|
||||
# q = "Find the average score of all the sudents"
|
||||
#
|
||||
# llm = Chat(
|
||||
# api_key=settings.llm_api_key,
|
||||
# temperature=0.5,
|
||||
# )
|
||||
# db = SessionDB(
|
||||
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||
# ).load_folder()
|
||||
#
|
||||
# chat = await llm.sql_eval(db.query_prompt(q))
|
||||
# query = extract_code_block(chat.last_response_content())
|
||||
# result: pl.DataFrame = db.query(query).pl()
|
||||
#
|
||||
# assert result.shape[0] == 1
|
||||
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
||||
#
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
from pathlib import Path
|
||||
|
||||
import hellocomputer
|
||||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.config import StorageEngines, Settings
|
||||
from hellocomputer.db.users import OwnershipDB, UserDB
|
||||
|
||||
TEST_STORAGE = StorageEngines.local
|
||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||
settings = Settings(
|
||||
storage_engine=StorageEngines.local,
|
||||
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||
)
|
||||
|
||||
|
||||
def test_create_user():
|
||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||
user = UserDB(settings)
|
||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||
user_data = user.dump_user_record(user_data, record_id="test")
|
||||
|
||||
|
@ -17,7 +19,7 @@ def test_create_user():
|
|||
|
||||
|
||||
def test_user_exists():
|
||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||
user = UserDB(settings)
|
||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||
user.dump_user_record(user_data, record_id="test")
|
||||
|
||||
|
@ -27,7 +29,7 @@ def test_user_exists():
|
|||
|
||||
def test_assign_owner():
|
||||
assert (
|
||||
OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip(
|
||||
OwnershipDB(settings).set_ownersip(
|
||||
"something.something@something", "testsession", "test"
|
||||
)
|
||||
== "testsession"
|
||||
|
@ -35,6 +37,6 @@ def test_assign_owner():
|
|||
|
||||
|
||||
def test_get_sessions():
|
||||
assert OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).sessions(
|
||||
"something.something@something"
|
||||
) == ["testsession"]
|
||||
assert OwnershipDB(settings).sessions("something.something@something") == [
|
||||
"testsession"
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue