Kind of refactored everything
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-25 00:10:09 +02:00
parent 910e91a391
commit f4b9c30a17
15 changed files with 213 additions and 226 deletions

View file

@ -1,4 +1,5 @@
langchain
langgraph
langchain-community
langchain-fireworks
langchain-openai
@ -11,6 +12,7 @@ duckdb
duckdb-engine
polars
pyarrow
pydantic
pyjwt[crypto]
python-multipart
authlib

View file

@ -1,20 +1,67 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import model_validator
from pathlib import Path
from typing import Self, Optional
from enum import StrEnum
class StorageEngines(StrEnum):
local = "local"
gcs = "GCS"
class Settings(BaseSettings):
storage_engine: StorageEngines = "local"
base_url: str = "http://localhost:8000"
llm_api_key: str = "Awesome API"
llm_base_url: str = "Awessome Endpoint"
gcs_access: str = "access"
gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"
llm_base_url: Optional[str] = None
gcs_access: Optional[str] = None
gcs_secret: Optional[str] = None
gcs_bucketname: Optional[str] = None
path: Optional[Path] = None
auth: bool = True
auth0_client_id: str = ""
auth0_client_secret: str = ""
auth0_domain: str = ""
app_secret_key: str = ""
auth0_client_id: Optional[str] = None
auth0_client_secret: Optional[str] = None
auth0_domain: Optional[str] = None
app_secret_key: Optional[str] = None
model_config = SettingsConfigDict(env_file=".env")
@model_validator(mode="after")
def check_cloud_storage(self) -> Self:
if self.storage_engine == StorageEngines.gcs:
if any(
(
self.gcs_access is None,
self.gcs_bucketname is None,
self.gcs_secret is None,
)
):
raise ValueError("Cloud storage configuration not provided")
return self
@model_validator(mode="after")
def check_auth_config(self) -> Self:
if not self.auth:
if any(
(
self.auth0_client_id is None,
self.auth0_client_secret is None,
self.auth0_domain is None,
self.app_secret_key is None,
)
):
raise ValueError("Auth is enabled but no auth config is providedc")
return self
@model_validator(mode="after")
def check_local_storage(self) -> Self:
if self.storage_engine == StorageEngines.local:
if self.path is None:
raise ValueError("Local storage requires a path")
return self
settings = Settings()

View file

@ -1,67 +1,28 @@
from enum import StrEnum
from pathlib import Path
import duckdb
from hellocomputer.config import Settings, StorageEngines
from sqlalchemy import create_engine
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
settings: Settings,
):
self.db = duckdb.connect(":memory:")
# Assume extension autoloading
# self.db.sql("load httpfs")
# self.db.sql("load spatial")
self.storage_engine = storage_engine
self.sheets = tuple()
self.loaded = False
self.storage_engine = settings.storage_engine
self.engine = create_engine("duckdb:///:memory:")
self.db = self.engine.raw_connection()
if storage_engine == StorageEngines.gcs:
if all(
(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
)
):
self.db.sql(f"""
if self.storage_engine == StorageEngines.gcs:
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
KEY_ID '{settings.gcs_access}',
SECRET '{settings.gcs_secret}')
""")
self.path_prefix = f"gs://{bucket}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, and bucket keyword arguments"
)
self.path_prefix = f"gs://{settings.gcs_bucketname}"
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
@property
def engine(self):
return create_engine("duckdb:///:memory:")

View file

@ -3,30 +3,22 @@ from pathlib import Path
import duckdb
from typing_extensions import Self
from langchain_community.utilities.sql_database import SQLDatabase
from hellocomputer.db import StorageEngines
from hellocomputer.config import Settings, StorageEngines
from . import DDB
class SessionDB(DDB):
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
def __init__(self, settings: Settings, sid: str):
super().__init__(settings=settings)
self.sid = sid
# Override storage engine for sessions
if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{bucket}/sessions/{sid}"
elif storage_engine == StorageEngines.local:
self.path_prefix = path / "sessions" / sid
if settings.storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{settings.gcs_bucketname}/sessions/{sid}"
elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path / "sessions" / sid
def load_xls(self, xls_path: Path) -> Self:
"""For some reason, the header is not loaded"""
@ -175,3 +167,7 @@ class SessionDB(DDB):
"Return just the SQL statement",
]
)
@property
def llmsql(self):
return SQLDatabase(self.engine, ignore_tables=["metadata"])

View file

@ -1,35 +1,30 @@
import json
import os
from datetime import datetime
from pathlib import Path
from typing import List
from uuid import UUID, uuid4
import duckdb
import polars as pl
from . import DDB, StorageEngines
from hellocomputer.db import DDB
from hellocomputer.config import Settings, StorageEngines
class UserDB(DDB):
def __init__(
self,
storage_engine: StorageEngines,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
settings: Settings,
):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
super().__init__(settings)
if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{bucket}/users"
if settings.storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{settings.gcs_bucketname}/users"
elif storage_engine == StorageEngines.local:
self.path_prefix = path / "users"
elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path / "users"
self.storage_engine = storage_engine
self.storage_engine = settings.storage_engine
def dump_user_record(self, user_data: dict, record_id: UUID | None = None):
df = pl.from_dict(user_data) # noqa
@ -59,20 +54,15 @@ class UserDB(DDB):
class OwnershipDB(DDB):
def __init__(
self,
storage_engine: StorageEngines,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
settings: Settings,
):
super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs)
super().__init__(settings)
if storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{bucket}/owners"
if settings.storage_engine == StorageEngines.gcs:
self.path_prefix = f"gs://{settings.gcs_bucketname}/owners"
elif storage_engine == StorageEngines.local:
self.path_prefix = path / "owners"
elif settings.storage_engine == StorageEngines.local:
self.path_prefix = settings.path / "owners"
def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None):
now = datetime.now().isoformat()

View file

@ -3,7 +3,6 @@ from pathlib import Path
from anyio import open_file
from langchain_core.prompts import PromptTemplate
from langchain_fireworks import Fireworks
import hellocomputer
@ -26,57 +25,8 @@ class Prompts:
@classmethod
async def general(cls):
return await cls.getter("general_prompt")
return PromptTemplate.from_template(await cls.getter("general_prompt"))
@classmethod
async def sql(cls):
return await cls.getter("sql_prompt")
class Chat:
@staticmethod
def raise_no_key(api_key):
if api_key:
return api_key
elif api_key is None:
raise ValueError(
"You need to provide a valid API in the api_key init argument"
)
else:
raise ValueError("You need to provide a valid API key")
def __init__(
self,
model: AvailableModels = AvailableModels.mixtral_8x7b,
api_key: str = "",
temperature: float = 0.5,
):
self.model = model
self.api_key = self.raise_no_key(api_key)
self.messages = []
self.responses = []
self.model: Fireworks = Fireworks(
model=model, temperature=temperature, api_key=self.api_key
)
async def eval(self, task):
prompt = PromptTemplate.from_template(await Prompts.general())
response = await self.model.ainvoke(prompt.format(query=task))
self.responses.append(response)
return self
async def sql_eval(self, question):
prompt = PromptTemplate.from_template(await Prompts.sql())
response = await self.model.ainvoke(prompt.format(query=question))
self.responses.append(response)
return self
def last_response_content(self):
last_response = self.responses[-1]
return last_response
def last_response_metadata(self):
return self.responses[-1].response_metadata
return PromptTemplate.from_template(await cls.getter("sql_prompt"))

View file

@ -1,7 +1,7 @@
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.db import StorageEngines
from hellocomputer.config import StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block

View file

@ -4,7 +4,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from starlette.requests import Request
from hellocomputer.config import settings
from hellocomputer.db import StorageEngines
from hellocomputer.config import StorageEngines
from hellocomputer.db.users import UserDB
router = APIRouter()

View file

@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse
from starlette.requests import Request
from ..config import settings
from ..db import StorageEngines
from ..config import StorageEngines
from ..db.sessions import SessionDB
from ..db.users import OwnershipDB

View file

@ -5,7 +5,7 @@ from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from starlette.requests import Request
from hellocomputer.db import StorageEngines
from hellocomputer.config import StorageEngines
from hellocomputer.db.users import OwnershipDB
from ..auth import get_user_email

View file

@ -0,0 +1,26 @@
from pydantic import BaseModel, Field
from typing import Type
from langchain.tools import BaseTool
from hellocomputer.db.sessions import SessionDB
from hellocomputer.config import settings
class DuckdbQueryInput(BaseModel):
query: str = Field(description="Question to be translated to a SQL statement")
session_id: str = Field(description="Session ID necessary to fetch the data")
class DuckdbQueryTool(BaseTool):
name: str = "Calculator"
description: str = "Tool to evaluate mathemetical expressions"
args_schema: Type[BaseModel] = DuckdbQueryInput
def _run(self, query: str, session_id: str) -> str:
"""Run the query"""
db = SessionDB(settings, session_id)
return "Table"
async def _arun(self, query: str, session_id: str) -> str:
"""Use the tool asynchronously."""
db = SessionDB(settings, session_id)
return "Table"

View file

@ -1,47 +1,44 @@
from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.config import StorageEngines, Settings
from hellocomputer.db.sessions import SessionDB
TEST_STORAGE = StorageEngines.local
settings = Settings(
storage_engine=StorageEngines.local,
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
)
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump():
db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test")
db = SessionDB(settings, sid="test")
db.load_xls(TEST_XLS_PATH).dump()
assert db.sheets == ("answers",)
assert (TEST_OUTPUT_FOLDER / "sessions" / "test" / "answers.csv").exists()
assert (settings.path / "sessions" / "test" / "answers.csv").exists()
def test_load():
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
db = SessionDB(settings, sid="test").load_folder()
results = db.query("select * from answers").fetchall()
assert len(results) == 6
def test_load_description():
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
db = SessionDB(settings, sid="test").load_folder()
file_description = db.load_description()
assert file_description.startswith("answers")
def test_schema():
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
db = SessionDB(settings, sid="test").load_folder()
schema = []
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
@ -50,9 +47,7 @@ def test_schema():
def test_query_prompt():
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
db = SessionDB(settings, sid="test").load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"

View file

@ -5,11 +5,5 @@ from langchain.prompts import PromptTemplate
@pytest.mark.asyncio
async def test_get_general_prompt():
general: str = await Prompts.general()
assert general.startswith("You're a helpful assistant")
@pytest.mark.asyncio
async def test_general_templated():
prompt = PromptTemplate.from_template(await Prompts.general())
assert "Do as I say" in prompt.format(query="Do as I say")
general: PromptTemplate = await Prompts.general()
assert general.format(query="whatever").startswith("You're a helpful assistant")

View file

@ -3,65 +3,89 @@ from pathlib import Path
import hellocomputer
import polars as pl
import pytest
from hellocomputer.config import settings
from hellocomputer.db import StorageEngines
from hellocomputer.config import Settings, StorageEngines
from hellocomputer.db.sessions import SessionDB
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
from hellocomputer.models import AvailableModels
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
settings = Settings(
storage_engine=StorageEngines.local,
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
)
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
SID = "test"
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_chat_simple():
chat = Chat(api_key=settings.llm_api_key, temperature=0)
chat = await chat.eval("Say literlly 'Hello'")
assert "Hello" in chat.last_response_content()
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.mixtral_8x7b,
temperature=0.5,
)
prompt = ChatPromptTemplate.from_template(
"""Say literally {word}, a single word. Don't be verbose,
I'll be disappointed if you say more than a single word"""
)
chain = prompt | llm
response = await chain.ainvoke({"word": "Hello"})
assert response.content.lower().startswith("hello")
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
async def test_query_context():
db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql
chat = Chat(
llm = ChatOpenAI(
base_url=settings.llm_base_url,
api_key=settings.llm_api_key,
model=AvailableModels.mixtral_8x7b,
temperature=0.5,
)
db = SessionDB(
storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID
).load_xls(TEST_XLS_PATH)
chat = await chat.sql_eval(db.query_prompt(query))
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
context = toolkit.get_context()
assert "table_info" in context
assert "table_names" in context
@pytest.mark.asyncio
@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
async def test_data_query():
q = "Find the average score of all the sudents"
llm = Chat(
api_key=settings.llm_api_key,
temperature=0.5,
)
db = SessionDB(
storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
).load_folder()
chat = await llm.sql_eval(db.query_prompt(q))
query = extract_code_block(chat.last_response_content())
result: pl.DataFrame = db.query(query).pl()
assert result.shape[0] == 1
assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
#
# chat = await chat.sql_eval(db.query_prompt(query))
# query = extract_code_block(chat.last_response_content())
# assert query.startswith("SELECT")
#
#
# @pytest.mark.asyncio
# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set")
# async def test_data_query():
# q = "Find the average score of all the sudents"
#
# llm = Chat(
# api_key=settings.llm_api_key,
# temperature=0.5,
# )
# db = SessionDB(
# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test"
# ).load_folder()
#
# chat = await llm.sql_eval(db.query_prompt(q))
# query = extract_code_block(chat.last_response_content())
# result: pl.DataFrame = db.query(query).pl()
#
# assert result.shape[0] == 1
# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5
#

View file

@ -1,15 +1,17 @@
from pathlib import Path
import hellocomputer
from hellocomputer.db import StorageEngines
from hellocomputer.config import StorageEngines, Settings
from hellocomputer.db.users import OwnershipDB, UserDB
TEST_STORAGE = StorageEngines.local
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
settings = Settings(
storage_engine=StorageEngines.local,
path=Path(hellocomputer.__file__).parents[2] / "test" / "output",
)
def test_create_user():
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
user = UserDB(settings)
user_data = {"name": "John Doe", "email": "[email protected]"}
user_data = user.dump_user_record(user_data, record_id="test")
@ -17,7 +19,7 @@ def test_create_user():
def test_user_exists():
user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
user = UserDB(settings)
user_data = {"name": "John Doe", "email": "[email protected]"}
user.dump_user_record(user_data, record_id="test")
@ -27,7 +29,7 @@ def test_user_exists():
def test_assign_owner():
assert (
OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip(
OwnershipDB(settings).set_ownersip(
"something.something@something", "testsession", "test"
)
== "testsession"
@ -35,6 +37,6 @@ def test_assign_owner():
def test_get_sessions():
assert OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).sessions(
"something.something@something"
) == ["testsession"]
assert OwnershipDB(settings).sessions("something.something@something") == [
"testsession"
]