This commit is contained in:
parent
910e91a391
commit
f4b9c30a17
|
@ -1,4 +1,5 @@
|
||||||
langchain
|
langchain
|
||||||
|
langgraph
|
||||||
langchain-community
|
langchain-community
|
||||||
langchain-fireworks
|
langchain-fireworks
|
||||||
langchain-openai
|
langchain-openai
|
||||||
|
@ -11,6 +12,7 @@ duckdb
|
||||||
duckdb-engine
|
duckdb-engine
|
||||||
polars
|
polars
|
||||||
pyarrow
|
pyarrow
|
||||||
|
pydantic
|
||||||
pyjwt[crypto]
|
pyjwt[crypto]
|
||||||
python-multipart
|
python-multipart
|
||||||
authlib
|
authlib
|
||||||
|
|
|
@ -1,20 +1,67 @@
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from pydantic import model_validator
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Self, Optional
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class StorageEngines(StrEnum):
|
||||||
|
local = "local"
|
||||||
|
gcs = "GCS"
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
storage_engine: StorageEngines = "local"
|
||||||
base_url: str = "http://localhost:8000"
|
base_url: str = "http://localhost:8000"
|
||||||
llm_api_key: str = "Awesome API"
|
llm_api_key: str = "Awesome API"
|
||||||
llm_base_url: str = "Awessome Endpoint"
|
llm_base_url: Optional[str] = None
|
||||||
gcs_access: str = "access"
|
gcs_access: Optional[str] = None
|
||||||
gcs_secret: str = "secret"
|
gcs_secret: Optional[str] = None
|
||||||
gcs_bucketname: str = "bucket"
|
gcs_bucketname: Optional[str] = None
|
||||||
|
path: Optional[Path] = None
|
||||||
auth: bool = True
|
auth: bool = True
|
||||||
auth0_client_id: str = ""
|
auth0_client_id: Optional[str] = None
|
||||||
auth0_client_secret: str = ""
|
auth0_client_secret: Optional[str] = None
|
||||||
auth0_domain: str = ""
|
auth0_domain: Optional[str] = None
|
||||||
app_secret_key: str = ""
|
app_secret_key: Optional[str] = None
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_cloud_storage(self) -> Self:
|
||||||
|
if self.storage_engine == StorageEngines.gcs:
|
||||||
|
if any(
|
||||||
|
(
|
||||||
|
self.gcs_access is None,
|
||||||
|
self.gcs_bucketname is None,
|
||||||
|
self.gcs_secret is None,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError("Cloud storage configuration not provided")
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_auth_config(self) -> Self:
|
||||||
|
if not self.auth:
|
||||||
|
if any(
|
||||||
|
(
|
||||||
|
self.auth0_client_id is None,
|
||||||
|
self.auth0_client_secret is None,
|
||||||
|
self.auth0_domain is None,
|
||||||
|
self.app_secret_key is None,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError("Auth is enabled but no auth config is providedc")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_local_storage(self) -> Self:
|
||||||
|
if self.storage_engine == StorageEngines.local:
|
||||||
|
if self.path is None:
|
||||||
|
raise ValueError("Local storage requires a path")
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
@ -1,67 +1,28 @@
|
||||||
from enum import StrEnum
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import duckdb
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
|
||||||
local = "Local"
|
|
||||||
gcs = "GCS"
|
|
||||||
|
|
||||||
|
|
||||||
class DDB:
|
class DDB:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage_engine: StorageEngines,
|
settings: Settings,
|
||||||
path: Path | None = None,
|
|
||||||
gcs_access: str | None = None,
|
|
||||||
gcs_secret: str | None = None,
|
|
||||||
bucket: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
self.db = duckdb.connect(":memory:")
|
self.storage_engine = settings.storage_engine
|
||||||
# Assume extension autoloading
|
self.engine = create_engine("duckdb:///:memory:")
|
||||||
# self.db.sql("load httpfs")
|
self.db = self.engine.raw_connection()
|
||||||
# self.db.sql("load spatial")
|
|
||||||
self.storage_engine = storage_engine
|
|
||||||
self.sheets = tuple()
|
|
||||||
self.loaded = False
|
|
||||||
|
|
||||||
if storage_engine == StorageEngines.gcs:
|
if self.storage_engine == StorageEngines.gcs:
|
||||||
if all(
|
|
||||||
(
|
|
||||||
gcs_access is not None,
|
|
||||||
gcs_secret is not None,
|
|
||||||
bucket is not None,
|
|
||||||
)
|
|
||||||
):
|
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
CREATE SECRET (
|
CREATE SECRET (
|
||||||
TYPE GCS,
|
TYPE GCS,
|
||||||
KEY_ID '{gcs_access}',
|
KEY_ID '{settings.gcs_access}',
|
||||||
SECRET '{gcs_secret}')
|
SECRET '{settings.gcs_secret}')
|
||||||
""")
|
""")
|
||||||
|
|
||||||
self.path_prefix = f"gs://{bucket}"
|
self.path_prefix = f"gs://{settings.gcs_bucketname}"
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"With GCS storage engine you need to provide "
|
|
||||||
"the gcs_access, gcs_secret, and bucket keyword arguments"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
|
||||||
if path is not None:
|
|
||||||
self.path_prefix = path
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"With local storage you need to provide the path keyword argument"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
elif settings.storage_engine == StorageEngines.local:
|
||||||
|
self.path_prefix = settings.path
|
||||||
|
|
||||||
def query(self, sql, *args, **kwargs):
|
def query(self, sql, *args, **kwargs):
|
||||||
return self.db.query(sql, *args, **kwargs)
|
return self.db.query(sql, *args, **kwargs)
|
||||||
|
|
||||||
@property
|
|
||||||
def engine(self):
|
|
||||||
return create_engine("duckdb:///:memory:")
|
|
||||||
|
|
|
@ -3,30 +3,22 @@ from pathlib import Path
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
from langchain_community.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
|
|
||||||
from . import DDB
|
from . import DDB
|
||||||
|
|
||||||
|
|
||||||
class SessionDB(DDB):
|
class SessionDB(DDB):
|
||||||
def __init__(
|
def __init__(self, settings: Settings, sid: str):
|
||||||
self,
|
super().__init__(settings=settings)
|
||||||
storage_engine: StorageEngines,
|
|
||||||
sid: str | None = None,
|
|
||||||
path: Path | None = None,
|
|
||||||
gcs_access: str | None = None,
|
|
||||||
gcs_secret: str | None = None,
|
|
||||||
bucket: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
# Override storage engine for sessions
|
# Override storage engine for sessions
|
||||||
if storage_engine == StorageEngines.gcs:
|
if settings.storage_engine == StorageEngines.gcs:
|
||||||
self.path_prefix = f"gs://{bucket}/sessions/{sid}"
|
self.path_prefix = f"gs://{settings.gcs_bucketname}/sessions/{sid}"
|
||||||
elif storage_engine == StorageEngines.local:
|
elif settings.storage_engine == StorageEngines.local:
|
||||||
self.path_prefix = path / "sessions" / sid
|
self.path_prefix = settings.path / "sessions" / sid
|
||||||
|
|
||||||
def load_xls(self, xls_path: Path) -> Self:
|
def load_xls(self, xls_path: Path) -> Self:
|
||||||
"""For some reason, the header is not loaded"""
|
"""For some reason, the header is not loaded"""
|
||||||
|
@ -175,3 +167,7 @@ class SessionDB(DDB):
|
||||||
"Return just the SQL statement",
|
"Return just the SQL statement",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llmsql(self):
|
||||||
|
return SQLDatabase(self.engine, ignore_tables=["metadata"])
|
||||||
|
|
|
@ -1,35 +1,30 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
import duckdb
|
import duckdb
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
from . import DDB, StorageEngines
|
from hellocomputer.db import DDB
|
||||||
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
|
|
||||||
|
|
||||||
class UserDB(DDB):
|
class UserDB(DDB):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage_engine: StorageEngines,
|
settings: Settings,
|
||||||
path: Path | None = None,
|
|
||||||
gcs_access: str | None = None,
|
|
||||||
gcs_secret: str | None = None,
|
|
||||||
bucket: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
super().__init__(settings)
|
||||||
|
|
||||||
if storage_engine == StorageEngines.gcs:
|
if settings.storage_engine == StorageEngines.gcs:
|
||||||
self.path_prefix = f"gs://{bucket}/users"
|
self.path_prefix = f"gs://{settings.gcs_bucketname}/users"
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
elif settings.storage_engine == StorageEngines.local:
|
||||||
self.path_prefix = path / "users"
|
self.path_prefix = settings.path / "users"
|
||||||
|
|
||||||
self.storage_engine = storage_engine
|
self.storage_engine = settings.storage_engine
|
||||||
|
|
||||||
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
|
||||||
df = pl.from_dict(user_data) # noqa
|
df = pl.from_dict(user_data) # noqa
|
||||||
|
@ -59,20 +54,15 @@ class UserDB(DDB):
|
||||||
class OwnershipDB(DDB):
|
class OwnershipDB(DDB):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
storage_engine: StorageEngines,
|
settings: Settings,
|
||||||
path: Path | None = None,
|
|
||||||
gcs_access: str | None = None,
|
|
||||||
gcs_secret: str | None = None,
|
|
||||||
bucket: str | None = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
|
super().__init__(settings)
|
||||||
|
|
||||||
if storage_engine == StorageEngines.gcs:
|
if settings.storage_engine == StorageEngines.gcs:
|
||||||
self.path_prefix = f"gs://{bucket}/owners"
|
self.path_prefix = f"gs://{settings.gcs_bucketname}/owners"
|
||||||
|
|
||||||
elif storage_engine == StorageEngines.local:
|
elif settings.storage_engine == StorageEngines.local:
|
||||||
self.path_prefix = path / "owners"
|
self.path_prefix = settings.path / "owners"
|
||||||
|
|
||||||
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
|
||||||
now = datetime.now().isoformat()
|
now = datetime.now().isoformat()
|
||||||
|
|
|
@ -3,7 +3,6 @@ from pathlib import Path
|
||||||
|
|
||||||
from anyio import open_file
|
from anyio import open_file
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_fireworks import Fireworks
|
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
|
|
||||||
|
@ -26,57 +25,8 @@ class Prompts:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def general(cls):
|
async def general(cls):
|
||||||
return await cls.getter("general_prompt")
|
return PromptTemplate.from_template(await cls.getter("general_prompt"))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def sql(cls):
|
async def sql(cls):
|
||||||
return await cls.getter("sql_prompt")
|
return PromptTemplate.from_template(await cls.getter("sql_prompt"))
|
||||||
|
|
||||||
|
|
||||||
class Chat:
|
|
||||||
@staticmethod
|
|
||||||
def raise_no_key(api_key):
|
|
||||||
if api_key:
|
|
||||||
return api_key
|
|
||||||
elif api_key is None:
|
|
||||||
raise ValueError(
|
|
||||||
"You need to provide a valid API in the api_key init argument"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("You need to provide a valid API key")
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: AvailableModels = AvailableModels.mixtral_8x7b,
|
|
||||||
api_key: str = "",
|
|
||||||
temperature: float = 0.5,
|
|
||||||
):
|
|
||||||
self.model = model
|
|
||||||
self.api_key = self.raise_no_key(api_key)
|
|
||||||
self.messages = []
|
|
||||||
self.responses = []
|
|
||||||
|
|
||||||
self.model: Fireworks = Fireworks(
|
|
||||||
model=model, temperature=temperature, api_key=self.api_key
|
|
||||||
)
|
|
||||||
|
|
||||||
async def eval(self, task):
|
|
||||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
|
||||||
|
|
||||||
response = await self.model.ainvoke(prompt.format(query=task))
|
|
||||||
self.responses.append(response)
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def sql_eval(self, question):
|
|
||||||
prompt = PromptTemplate.from_template(await Prompts.sql())
|
|
||||||
|
|
||||||
response = await self.model.ainvoke(prompt.format(query=question))
|
|
||||||
self.responses.append(response)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def last_response_content(self):
|
|
||||||
last_response = self.responses[-1]
|
|
||||||
return last_response
|
|
||||||
|
|
||||||
def last_response_metadata(self):
|
|
||||||
return self.responses[-1].response_metadata
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import StorageEngines
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import StorageEngines
|
||||||
from hellocomputer.db.users import UserDB
|
from hellocomputer.db.users import UserDB
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
|
@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
from ..db import StorageEngines
|
from ..config import StorageEngines
|
||||||
from ..db.sessions import SessionDB
|
from ..db.sessions import SessionDB
|
||||||
from ..db.users import OwnershipDB
|
from ..db.users import OwnershipDB
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import StorageEngines
|
||||||
from hellocomputer.db.users import OwnershipDB
|
from hellocomputer.db.users import OwnershipDB
|
||||||
|
|
||||||
from ..auth import get_user_email
|
from ..auth import get_user_email
|
||||||
|
|
26
src/hellocomputer/tools.py
Normal file
26
src/hellocomputer/tools.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Type
|
||||||
|
from langchain.tools import BaseTool
|
||||||
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryInput(BaseModel):
|
||||||
|
query: str = Field(description="Question to be translated to a SQL statement")
|
||||||
|
session_id: str = Field(description="Session ID necessary to fetch the data")
|
||||||
|
|
||||||
|
|
||||||
|
class DuckdbQueryTool(BaseTool):
|
||||||
|
name: str = "Calculator"
|
||||||
|
description: str = "Tool to evaluate mathemetical expressions"
|
||||||
|
args_schema: Type[BaseModel] = DuckdbQueryInput
|
||||||
|
|
||||||
|
def _run(self, query: str, session_id: str) -> str:
|
||||||
|
"""Run the query"""
|
||||||
|
db = SessionDB(settings, session_id)
|
||||||
|
return "Table"
|
||||||
|
|
||||||
|
async def _arun(self, query: str, session_id: str) -> str:
|
||||||
|
"""Use the tool asynchronously."""
|
||||||
|
db = SessionDB(settings, session_id)
|
||||||
|
return "Table"
|
|
@ -1,47 +1,44 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import StorageEngines, Settings
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
settings = Settings(
|
||||||
|
storage_engine=StorageEngines.local,
|
||||||
|
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||||
|
)
|
||||||
|
|
||||||
TEST_XLS_PATH = (
|
TEST_XLS_PATH = (
|
||||||
Path(hellocomputer.__file__).parents[2]
|
Path(hellocomputer.__file__).parents[2]
|
||||||
/ "test"
|
/ "test"
|
||||||
/ "data"
|
/ "data"
|
||||||
/ "TestExcelHelloComputer.xlsx"
|
/ "TestExcelHelloComputer.xlsx"
|
||||||
)
|
)
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
|
||||||
|
|
||||||
|
|
||||||
def test_0_dump():
|
def test_0_dump():
|
||||||
db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
|
db = SessionDB(settings, sid="test")
|
||||||
db.load_xls(TEST_XLS_PATH).dump()
|
db.load_xls(TEST_XLS_PATH).dump()
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
assert db.sheets == ("answers",)
|
||||||
assert (TEST_OUTPUT_FOLDER / "sessions" / "test" / "answers.csv").exists()
|
assert (settings.path / "sessions" / "test" / "answers.csv").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_load():
|
def test_load():
|
||||||
db = SessionDB(
|
db = SessionDB(settings, sid="test").load_folder()
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
|
||||||
).load_folder()
|
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
assert len(results) == 6
|
assert len(results) == 6
|
||||||
|
|
||||||
|
|
||||||
def test_load_description():
|
def test_load_description():
|
||||||
db = SessionDB(
|
db = SessionDB(settings, sid="test").load_folder()
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
|
||||||
).load_folder()
|
|
||||||
file_description = db.load_description()
|
file_description = db.load_description()
|
||||||
assert file_description.startswith("answers")
|
assert file_description.startswith("answers")
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
db = SessionDB(
|
db = SessionDB(settings, sid="test").load_folder()
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
|
||||||
).load_folder()
|
|
||||||
schema = []
|
schema = []
|
||||||
for sheet in db.sheets:
|
for sheet in db.sheets:
|
||||||
schema.append(db.table_schema(sheet))
|
schema.append(db.table_schema(sheet))
|
||||||
|
@ -50,9 +47,7 @@ def test_schema():
|
||||||
|
|
||||||
|
|
||||||
def test_query_prompt():
|
def test_query_prompt():
|
||||||
db = SessionDB(
|
db = SessionDB(settings, sid="test").load_folder()
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
|
||||||
).load_folder()
|
|
||||||
|
|
||||||
assert db.query_prompt("Find the average score of all students").startswith(
|
assert db.query_prompt("Find the average score of all students").startswith(
|
||||||
"The following sentence"
|
"The following sentence"
|
||||||
|
|
|
@ -5,11 +5,5 @@ from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_general_prompt():
|
async def test_get_general_prompt():
|
||||||
general: str = await Prompts.general()
|
general: PromptTemplate = await Prompts.general()
|
||||||
assert general.startswith("You're a helpful assistant")
|
assert general.format(query="whatever").startswith("You're a helpful assistant")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_general_templated():
|
|
||||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
|
||||||
assert "Do as I say" in prompt.format(query="Do as I say")
|
|
||||||
|
|
|
@ -3,65 +3,89 @@ from pathlib import Path
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
import polars as pl
|
import polars as pl
|
||||||
import pytest
|
import pytest
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import Settings, StorageEngines
|
||||||
from hellocomputer.db import StorageEngines
|
|
||||||
from hellocomputer.db.sessions import SessionDB
|
from hellocomputer.db.sessions import SessionDB
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import AvailableModels
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
settings = Settings(
|
||||||
|
storage_engine=StorageEngines.local,
|
||||||
|
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||||
|
)
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
|
||||||
TEST_XLS_PATH = (
|
TEST_XLS_PATH = (
|
||||||
Path(hellocomputer.__file__).parents[2]
|
Path(hellocomputer.__file__).parents[2]
|
||||||
/ "test"
|
/ "test"
|
||||||
/ "data"
|
/ "data"
|
||||||
/ "TestExcelHelloComputer.xlsx"
|
/ "TestExcelHelloComputer.xlsx"
|
||||||
)
|
)
|
||||||
|
|
||||||
SID = "test"
|
SID = "test"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||||
async def test_chat_simple():
|
async def test_chat_simple():
|
||||||
chat = Chat(api_key=settings.llm_api_key, temperature=0)
|
llm = ChatOpenAI(
|
||||||
chat = await chat.eval("Say literlly 'Hello'")
|
base_url=settings.llm_base_url,
|
||||||
assert "Hello" in chat.last_response_content()
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.mixtral_8x7b,
|
||||||
|
temperature=0.5,
|
||||||
|
)
|
||||||
|
prompt = ChatPromptTemplate.from_template(
|
||||||
|
"""Say literally {word}, a single word. Don't be verbose,
|
||||||
|
I'll be disappointed if you say more than a single word"""
|
||||||
|
)
|
||||||
|
chain = prompt | llm
|
||||||
|
response = await chain.ainvoke({"word": "Hello"})
|
||||||
|
|
||||||
|
assert response.content.lower().startswith("hello")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||||
async def test_simple_data_query():
|
async def test_query_context():
|
||||||
query = "write a query that finds the average score of all students in the current database"
|
db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql
|
||||||
|
|
||||||
chat = Chat(
|
llm = ChatOpenAI(
|
||||||
|
base_url=settings.llm_base_url,
|
||||||
api_key=settings.llm_api_key,
|
api_key=settings.llm_api_key,
|
||||||
|
model=AvailableModels.mixtral_8x7b,
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
)
|
)
|
||||||
db = SessionDB(
|
|
||||||
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
|
|
||||||
).load_xls(TEST_XLS_PATH)
|
|
||||||
|
|
||||||
chat = await chat.sql_eval(db.query_prompt(query))
|
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
||||||
query = extract_code_block(chat.last_response_content())
|
context = toolkit.get_context()
|
||||||
assert query.startswith("SELECT")
|
assert "table_info" in context
|
||||||
|
assert "table_names" in context
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
#
|
||||||
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
# chat = await chat.sql_eval(db.query_prompt(query))
|
||||||
async def test_data_query():
|
# query = extract_code_block(chat.last_response_content())
|
||||||
q = "Find the average score of all the sudents"
|
# assert query.startswith("SELECT")
|
||||||
|
#
|
||||||
llm = Chat(
|
#
|
||||||
api_key=settings.llm_api_key,
|
# @pytest.mark.asyncio
|
||||||
temperature=0.5,
|
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
|
||||||
)
|
# async def test_data_query():
|
||||||
db = SessionDB(
|
# q = "Find the average score of all the sudents"
|
||||||
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
#
|
||||||
).load_folder()
|
# llm = Chat(
|
||||||
|
# api_key=settings.llm_api_key,
|
||||||
chat = await llm.sql_eval(db.query_prompt(q))
|
# temperature=0.5,
|
||||||
query = extract_code_block(chat.last_response_content())
|
# )
|
||||||
result: pl.DataFrame = db.query(query).pl()
|
# db = SessionDB(
|
||||||
|
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
|
||||||
assert result.shape[0] == 1
|
# ).load_folder()
|
||||||
assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
#
|
||||||
|
# chat = await llm.sql_eval(db.query_prompt(q))
|
||||||
|
# query = extract_code_block(chat.last_response_content())
|
||||||
|
# result: pl.DataFrame = db.query(query).pl()
|
||||||
|
#
|
||||||
|
# assert result.shape[0] == 1
|
||||||
|
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
|
||||||
|
#
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.config import StorageEngines, Settings
|
||||||
from hellocomputer.db.users import OwnershipDB, UserDB
|
from hellocomputer.db.users import OwnershipDB, UserDB
|
||||||
|
|
||||||
TEST_STORAGE = StorageEngines.local
|
settings = Settings(
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
storage_engine=StorageEngines.local,
|
||||||
|
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_create_user():
|
def test_create_user():
|
||||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
user = UserDB(settings)
|
||||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||||
user_data = user.dump_user_record(user_data, record_id="test")
|
user_data = user.dump_user_record(user_data, record_id="test")
|
||||||
|
|
||||||
|
@ -17,7 +19,7 @@ def test_create_user():
|
||||||
|
|
||||||
|
|
||||||
def test_user_exists():
|
def test_user_exists():
|
||||||
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
user = UserDB(settings)
|
||||||
user_data = {"name": "John Doe", "email": "[email protected]"}
|
user_data = {"name": "John Doe", "email": "[email protected]"}
|
||||||
user.dump_user_record(user_data, record_id="test")
|
user.dump_user_record(user_data, record_id="test")
|
||||||
|
|
||||||
|
@ -27,7 +29,7 @@ def test_user_exists():
|
||||||
|
|
||||||
def test_assign_owner():
|
def test_assign_owner():
|
||||||
assert (
|
assert (
|
||||||
OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip(
|
OwnershipDB(settings).set_ownersip(
|
||||||
"something.something@something", "testsession", "test"
|
"something.something@something", "testsession", "test"
|
||||||
)
|
)
|
||||||
== "testsession"
|
== "testsession"
|
||||||
|
@ -35,6 +37,6 @@ def test_assign_owner():
|
||||||
|
|
||||||
|
|
||||||
def test_get_sessions():
|
def test_get_sessions():
|
||||||
assert OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).sessions(
|
assert OwnershipDB(settings).sessions("something.something@something") == [
|
||||||
"something.something@something"
|
"testsession"
|
||||||
) == ["testsession"]
|
]
|
||||||
|
|
Loading…
Reference in a new issue