From f4b9c30a17415df3de38909822ba6c0a4e3485a5 Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Thu, 25 Jul 2024 00:10:09 +0200 Subject: [PATCH] Kind of refactored everything --- requirements.in | 2 + src/hellocomputer/config.py | 63 +++++++++++++++--- src/hellocomputer/db/__init__.py | 63 ++++-------------- src/hellocomputer/db/sessions.py | 28 ++++---- src/hellocomputer/db/users.py | 40 +++++------ src/hellocomputer/models.py | 54 +-------------- src/hellocomputer/routers/analysis.py | 2 +- src/hellocomputer/routers/auth.py | 2 +- src/hellocomputer/routers/files.py | 2 +- src/hellocomputer/routers/sessions.py | 2 +- src/hellocomputer/tools.py | 26 ++++++++ test/test_data.py | 29 ++++---- test/test_prompts.py | 10 +-- test/test_query.py | 96 +++++++++++++++++---------- test/test_user.py | 20 +++--- 15 files changed, 213 insertions(+), 226 deletions(-) create mode 100644 src/hellocomputer/tools.py diff --git a/requirements.in b/requirements.in index 4e3cad0..8914e60 100644 --- a/requirements.in +++ b/requirements.in @@ -1,4 +1,5 @@ langchain +langgraph langchain-community langchain-fireworks langchain-openai @@ -11,6 +12,7 @@ duckdb duckdb-engine polars pyarrow +pydantic pyjwt[crypto] python-multipart authlib diff --git a/src/hellocomputer/config.py b/src/hellocomputer/config.py index d2cf14e..0fda686 100644 --- a/src/hellocomputer/config.py +++ b/src/hellocomputer/config.py @@ -1,20 +1,67 @@ from pydantic_settings import BaseSettings, SettingsConfigDict +from pydantic import model_validator +from pathlib import Path +from typing import Self, Optional +from enum import StrEnum + + +class StorageEngines(StrEnum): + local = "local" + gcs = "GCS" class Settings(BaseSettings): + storage_engine: StorageEngines = "local" base_url: str = "http://localhost:8000" llm_api_key: str = "Awesome API" - llm_base_url: str = "Awessome Endpoint" - gcs_access: str = "access" - gcs_secret: str = "secret" - gcs_bucketname: str = "bucket" + llm_base_url: Optional[str] = None + gcs_access: Optional[str] = None + gcs_secret: Optional[str] = None + gcs_bucketname: Optional[str] = None + path: Optional[Path] = None auth: bool = True - auth0_client_id: str = "" - auth0_client_secret: str = "" - auth0_domain: str = "" - app_secret_key: str = "" + auth0_client_id: Optional[str] = None + auth0_client_secret: Optional[str] = None + auth0_domain: Optional[str] = None + app_secret_key: Optional[str] = None model_config = SettingsConfigDict(env_file=".env") + @model_validator(mode="after") + def check_cloud_storage(self) -> Self: + if self.storage_engine == StorageEngines.gcs: + if any( + ( + self.gcs_access is None, + self.gcs_bucketname is None, + self.gcs_secret is None, + ) + ): + raise ValueError("Cloud storage configuration not provided") + return self + + @model_validator(mode="after") + def check_auth_config(self) -> Self: + if not self.auth: + if any( + ( + self.auth0_client_id is None, + self.auth0_client_secret is None, + self.auth0_domain is None, + self.app_secret_key is None, + ) + ): + raise ValueError("Auth is enabled but no auth config is providedc") + + return self + + @model_validator(mode="after") + def check_local_storage(self) -> Self: + if self.storage_engine == StorageEngines.local: + if self.path is None: + raise ValueError("Local storage requires a path") + + return self + settings = Settings() diff --git a/src/hellocomputer/db/__init__.py b/src/hellocomputer/db/__init__.py index e35dd3e..5042077 100644 --- a/src/hellocomputer/db/__init__.py +++ b/src/hellocomputer/db/__init__.py @@ -1,67 +1,28 @@ -from enum import StrEnum -from pathlib import Path - -import duckdb +from hellocomputer.config import Settings, StorageEngines from sqlalchemy import create_engine -class StorageEngines(StrEnum): - local = "Local" - gcs = "GCS" - - class DDB: def __init__( self, - storage_engine: StorageEngines, - path: Path | None = None, - gcs_access: str | None = None, - gcs_secret: str | None = None, - bucket: str | None = None, - **kwargs, + settings: Settings, ): - self.db = duckdb.connect(":memory:") - # Assume extension autoloading - # self.db.sql("load httpfs") - # self.db.sql("load spatial") - self.storage_engine = storage_engine - self.sheets = tuple() - self.loaded = False + self.storage_engine = settings.storage_engine + self.engine = create_engine("duckdb:///:memory:") + self.db = self.engine.raw_connection() - if storage_engine == StorageEngines.gcs: - if all( - ( - gcs_access is not None, - gcs_secret is not None, - bucket is not None, - ) - ): - self.db.sql(f""" + if self.storage_engine == StorageEngines.gcs: + self.db.sql(f""" CREATE SECRET ( TYPE GCS, - KEY_ID '{gcs_access}', - SECRET '{gcs_secret}') + KEY_ID '{settings.gcs_access}', + SECRET '{settings.gcs_secret}') """) - self.path_prefix = f"gs://{bucket}" - else: - raise ValueError( - "With GCS storage engine you need to provide " - "the gcs_access, gcs_secret, and bucket keyword arguments" - ) + self.path_prefix = f"gs://{settings.gcs_bucketname}" - elif storage_engine == StorageEngines.local: - if path is not None: - self.path_prefix = path - else: - raise ValueError( - "With local storage you need to provide the path keyword argument" - ) - + elif settings.storage_engine == StorageEngines.local: + self.path_prefix = settings.path def query(self, sql, *args, **kwargs): return self.db.query(sql, *args, **kwargs) - - @property - def engine(self): - return create_engine("duckdb:///:memory:") diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index c98b472..eb6559b 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -3,30 +3,22 @@ from pathlib import Path import duckdb from typing_extensions import Self +from langchain_community.utilities.sql_database import SQLDatabase -from hellocomputer.db import StorageEngines +from hellocomputer.config import Settings, StorageEngines from . import DDB class SessionDB(DDB): - def __init__( - self, - storage_engine: StorageEngines, - sid: str | None = None, - path: Path | None = None, - gcs_access: str | None = None, - gcs_secret: str | None = None, - bucket: str | None = None, - **kwargs, - ): - super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + def __init__(self, settings: Settings, sid: str): + super().__init__(settings=settings) self.sid = sid # Override storage engine for sessions - if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gs://{bucket}/sessions/{sid}" - elif storage_engine == StorageEngines.local: - self.path_prefix = path / "sessions" / sid + if settings.storage_engine == StorageEngines.gcs: + self.path_prefix = f"gs://{settings.gcs_bucketname}/sessions/{sid}" + elif settings.storage_engine == StorageEngines.local: + self.path_prefix = settings.path / "sessions" / sid def load_xls(self, xls_path: Path) -> Self: """For some reason, the header is not loaded""" @@ -175,3 +167,7 @@ class SessionDB(DDB): "Return just the SQL statement", ] ) + + @property + def llmsql(self): + return SQLDatabase(self.engine, ignore_tables=["metadata"]) diff --git a/src/hellocomputer/db/users.py b/src/hellocomputer/db/users.py index 088e6ed..b4f84dc 100644 --- a/src/hellocomputer/db/users.py +++ b/src/hellocomputer/db/users.py @@ -1,35 +1,30 @@ import json import os from datetime import datetime -from pathlib import Path from typing import List from uuid import UUID, uuid4 import duckdb import polars as pl -from . import DDB, StorageEngines +from hellocomputer.db import DDB +from hellocomputer.config import Settings, StorageEngines class UserDB(DDB): def __init__( self, - storage_engine: StorageEngines, - path: Path | None = None, - gcs_access: str | None = None, - gcs_secret: str | None = None, - bucket: str | None = None, - **kwargs, + settings: Settings, ): - super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + super().__init__(settings) - if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gs://{bucket}/users" + if settings.storage_engine == StorageEngines.gcs: + self.path_prefix = f"gs://{settings.gcs_bucketname}/users" - elif storage_engine == StorageEngines.local: - self.path_prefix = path / "users" + elif settings.storage_engine == StorageEngines.local: + self.path_prefix = settings.path / "users" - self.storage_engine = storage_engine + self.storage_engine = settings.storage_engine def dump_user_record(self, user_data: dict, record_id: UUID | None = None): df = pl.from_dict(user_data) # noqa @@ -59,20 +54,15 @@ class UserDB(DDB): class OwnershipDB(DDB): def __init__( self, - storage_engine: StorageEngines, - path: Path | None = None, - gcs_access: str | None = None, - gcs_secret: str | None = None, - bucket: str | None = None, - **kwargs, + settings: Settings, ): - super().__init__(storage_engine, path, gcs_access, gcs_secret, bucket, **kwargs) + super().__init__(settings) - if storage_engine == StorageEngines.gcs: - self.path_prefix = f"gs://{bucket}/owners" + if settings.storage_engine == StorageEngines.gcs: + self.path_prefix = f"gs://{settings.gcs_bucketname}/owners" - elif storage_engine == StorageEngines.local: - self.path_prefix = path / "owners" + elif settings.storage_engine == StorageEngines.local: + self.path_prefix = settings.path / "owners" def set_ownersip(self, user_email: str, sid: str, record_id: UUID | None = None): now = datetime.now().isoformat() diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index 5840901..09c3982 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -3,7 +3,6 @@ from pathlib import Path from anyio import open_file from langchain_core.prompts import PromptTemplate -from langchain_fireworks import Fireworks import hellocomputer @@ -26,57 +25,8 @@ class Prompts: @classmethod async def general(cls): - return await cls.getter("general_prompt") + return PromptTemplate.from_template(await cls.getter("general_prompt")) @classmethod async def sql(cls): - return await cls.getter("sql_prompt") - - -class Chat: - @staticmethod - def raise_no_key(api_key): - if api_key: - return api_key - elif api_key is None: - raise ValueError( - "You need to provide a valid API in the api_key init argument" - ) - else: - raise ValueError("You need to provide a valid API key") - - def __init__( - self, - model: AvailableModels = AvailableModels.mixtral_8x7b, - api_key: str = "", - temperature: float = 0.5, - ): - self.model = model - self.api_key = self.raise_no_key(api_key) - self.messages = [] - self.responses = [] - - self.model: Fireworks = Fireworks( - model=model, temperature=temperature, api_key=self.api_key - ) - - async def eval(self, task): - prompt = PromptTemplate.from_template(await Prompts.general()) - - response = await self.model.ainvoke(prompt.format(query=task)) - self.responses.append(response) - return self - - async def sql_eval(self, question): - prompt = PromptTemplate.from_template(await Prompts.sql()) - - response = await self.model.ainvoke(prompt.format(query=question)) - self.responses.append(response) - return self - - def last_response_content(self): - last_response = self.responses[-1] - return last_response - - def last_response_metadata(self): - return self.responses[-1].response_metadata + return PromptTemplate.from_template(await cls.getter("sql_prompt")) diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 856717d..7ac6708 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,7 +1,7 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse -from hellocomputer.db import StorageEngines +from hellocomputer.config import StorageEngines from hellocomputer.db.sessions import SessionDB from hellocomputer.extraction import extract_code_block diff --git a/src/hellocomputer/routers/auth.py b/src/hellocomputer/routers/auth.py index 58bd85a..c9c986b 100644 --- a/src/hellocomputer/routers/auth.py +++ b/src/hellocomputer/routers/auth.py @@ -4,7 +4,7 @@ from fastapi.responses import HTMLResponse, RedirectResponse from starlette.requests import Request from hellocomputer.config import settings -from hellocomputer.db import StorageEngines +from hellocomputer.config import StorageEngines from hellocomputer.db.users import UserDB router = APIRouter() diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index d3bf9cb..6ae2f12 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse from starlette.requests import Request from ..config import settings -from ..db import StorageEngines +from ..config import StorageEngines from ..db.sessions import SessionDB from ..db.users import OwnershipDB diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index c5e12ca..5567f11 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -5,7 +5,7 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse from starlette.requests import Request -from hellocomputer.db import StorageEngines +from hellocomputer.config import StorageEngines from hellocomputer.db.users import OwnershipDB from ..auth import get_user_email diff --git a/src/hellocomputer/tools.py b/src/hellocomputer/tools.py new file mode 100644 index 0000000..fad16e2 --- /dev/null +++ b/src/hellocomputer/tools.py @@ -0,0 +1,26 @@ +from pydantic import BaseModel, Field +from typing import Type +from langchain.tools import BaseTool +from hellocomputer.db.sessions import SessionDB +from hellocomputer.config import settings + + +class DuckdbQueryInput(BaseModel): + query: str = Field(description="Question to be translated to a SQL statement") + session_id: str = Field(description="Session ID necessary to fetch the data") + + +class DuckdbQueryTool(BaseTool): + name: str = "Calculator" + description: str = "Tool to evaluate mathemetical expressions" + args_schema: Type[BaseModel] = DuckdbQueryInput + + def _run(self, query: str, session_id: str) -> str: + """Run the query""" + db = SessionDB(settings, session_id) + return "Table" + + async def _arun(self, query: str, session_id: str) -> str: + """Use the tool asynchronously.""" + db = SessionDB(settings, session_id) + return "Table" diff --git a/test/test_data.py b/test/test_data.py index 1d82b1e..bda3c72 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,47 +1,44 @@ from pathlib import Path import hellocomputer -from hellocomputer.db import StorageEngines +from hellocomputer.config import StorageEngines, Settings from hellocomputer.db.sessions import SessionDB -TEST_STORAGE = StorageEngines.local +settings = Settings( + storage_engine=StorageEngines.local, + path=Path(hellocomputer.__file__).parents[2] / "test" / "output", +) + TEST_XLS_PATH = ( Path(hellocomputer.__file__).parents[2] / "test" / "data" / "TestExcelHelloComputer.xlsx" ) -TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" def test_0_dump(): - db = SessionDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test") + db = SessionDB(settings, sid="test") db.load_xls(TEST_XLS_PATH).dump() assert db.sheets == ("answers",) - assert (TEST_OUTPUT_FOLDER / "sessions" / "test" / "answers.csv").exists() + assert (settings.path / "sessions" / "test" / "answers.csv").exists() def test_load(): - db = SessionDB( - storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" - ).load_folder() + db = SessionDB(settings, sid="test").load_folder() results = db.query("select * from answers").fetchall() assert len(results) == 6 def test_load_description(): - db = SessionDB( - storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" - ).load_folder() + db = SessionDB(settings, sid="test").load_folder() file_description = db.load_description() assert file_description.startswith("answers") def test_schema(): - db = SessionDB( - storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" - ).load_folder() + db = SessionDB(settings, sid="test").load_folder() schema = [] for sheet in db.sheets: schema.append(db.table_schema(sheet)) @@ -50,9 +47,7 @@ def test_schema(): def test_query_prompt(): - db = SessionDB( - storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" - ).load_folder() + db = SessionDB(settings, sid="test").load_folder() assert db.query_prompt("Find the average score of all students").startswith( "The following sentence" diff --git a/test/test_prompts.py b/test/test_prompts.py index ced5dd2..62d275d 100644 --- a/test/test_prompts.py +++ b/test/test_prompts.py @@ -5,11 +5,5 @@ from langchain.prompts import PromptTemplate @pytest.mark.asyncio async def test_get_general_prompt(): - general: str = await Prompts.general() - assert general.startswith("You're a helpful assistant") - - -@pytest.mark.asyncio -async def test_general_templated(): - prompt = PromptTemplate.from_template(await Prompts.general()) - assert "Do as I say" in prompt.format(query="Do as I say") + general: PromptTemplate = await Prompts.general() + assert general.format(query="whatever").startswith("You're a helpful assistant") diff --git a/test/test_query.py b/test/test_query.py index 759ae69..c4c188f 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -3,65 +3,89 @@ from pathlib import Path import hellocomputer import polars as pl import pytest -from hellocomputer.config import settings -from hellocomputer.db import StorageEngines +from hellocomputer.config import Settings, StorageEngines from hellocomputer.db.sessions import SessionDB from hellocomputer.extraction import extract_code_block -from hellocomputer.models import Chat +from hellocomputer.models import AvailableModels +from langchain_core.prompts import ChatPromptTemplate +from langchain_community.agent_toolkits import SQLDatabaseToolkit +from langchain_openai import ChatOpenAI + +settings = Settings( + storage_engine=StorageEngines.local, + path=Path(hellocomputer.__file__).parents[2] / "test" / "output", +) -TEST_STORAGE = StorageEngines.local -TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_XLS_PATH = ( Path(hellocomputer.__file__).parents[2] / "test" / "data" / "TestExcelHelloComputer.xlsx" ) + SID = "test" @pytest.mark.asyncio @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") async def test_chat_simple(): - chat = Chat(api_key=settings.llm_api_key, temperature=0) - chat = await chat.eval("Say literlly 'Hello'") - assert "Hello" in chat.last_response_content() + llm = ChatOpenAI( + base_url=settings.llm_base_url, + api_key=settings.llm_api_key, + model=AvailableModels.mixtral_8x7b, + temperature=0.5, + ) + prompt = ChatPromptTemplate.from_template( + """Say literally {word}, a single word. Don't be verbose, + I'll be disappointed if you say more than a single word""" + ) + chain = prompt | llm + response = await chain.ainvoke({"word": "Hello"}) + + assert response.content.lower().startswith("hello") @pytest.mark.asyncio @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") -async def test_simple_data_query(): - query = "write a query that finds the average score of all students in the current database" +async def test_query_context(): + db = SessionDB(settings, sid=SID).load_xls(TEST_XLS_PATH).llmsql - chat = Chat( + llm = ChatOpenAI( + base_url=settings.llm_base_url, api_key=settings.llm_api_key, + model=AvailableModels.mixtral_8x7b, temperature=0.5, ) - db = SessionDB( - storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID - ).load_xls(TEST_XLS_PATH) - chat = await chat.sql_eval(db.query_prompt(query)) - query = extract_code_block(chat.last_response_content()) - assert query.startswith("SELECT") + toolkit = SQLDatabaseToolkit(db=db, llm=llm) + context = toolkit.get_context() + assert "table_info" in context + assert "table_names" in context -@pytest.mark.asyncio -@pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") -async def test_data_query(): - q = "Find the average score of all the sudents" - - llm = Chat( - api_key=settings.llm_api_key, - temperature=0.5, - ) - db = SessionDB( - storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" - ).load_folder() - - chat = await llm.sql_eval(db.query_prompt(q)) - query = extract_code_block(chat.last_response_content()) - result: pl.DataFrame = db.query(query).pl() - - assert result.shape[0] == 1 - assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5 +# +# chat = await chat.sql_eval(db.query_prompt(query)) +# query = extract_code_block(chat.last_response_content()) +# assert query.startswith("SELECT") +# +# +# @pytest.mark.asyncio +# @pytest.mark.skipif(settings.llm_api_key == "Awesome API", reason="API Key not set") +# async def test_data_query(): +# q = "Find the average score of all the sudents" +# +# llm = Chat( +# api_key=settings.llm_api_key, +# temperature=0.5, +# ) +# db = SessionDB( +# storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" +# ).load_folder() +# +# chat = await llm.sql_eval(db.query_prompt(q)) +# query = extract_code_block(chat.last_response_content()) +# result: pl.DataFrame = db.query(query).pl() +# +# assert result.shape[0] == 1 +# assert result.select([pl.col("avg(Score)")]).to_series()[0] == 0.5 +# diff --git a/test/test_user.py b/test/test_user.py index eb38df0..959826d 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -1,15 +1,17 @@ from pathlib import Path import hellocomputer -from hellocomputer.db import StorageEngines +from hellocomputer.config import StorageEngines, Settings from hellocomputer.db.users import OwnershipDB, UserDB -TEST_STORAGE = StorageEngines.local -TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" +settings = Settings( + storage_engine=StorageEngines.local, + path=Path(hellocomputer.__file__).parents[2] / "test" / "output", +) def test_create_user(): - user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + user = UserDB(settings) user_data = {"name": "John Doe", "email": "[email protected]"} user_data = user.dump_user_record(user_data, record_id="test") @@ -17,7 +19,7 @@ def test_create_user(): def test_user_exists(): - user = UserDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + user = UserDB(settings) user_data = {"name": "John Doe", "email": "[email protected]"} user.dump_user_record(user_data, record_id="test") @@ -27,7 +29,7 @@ def test_user_exists(): def test_assign_owner(): assert ( - OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).set_ownersip( + OwnershipDB(settings).set_ownersip( "something.something@something", "testsession", "test" ) == "testsession" @@ -35,6 +37,6 @@ def test_assign_owner(): def test_get_sessions(): - assert OwnershipDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).sessions( - "something.something@something" - ) == ["testsession"] + assert OwnershipDB(settings).sessions("something.something@something") == [ + "testsession" + ]