diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/analytics.py
index 51abf17..607b2b4 100644
--- a/src/hellocomputer/analytics.py
+++ b/src/hellocomputer/analytics.py
@@ -1,36 +1,63 @@
-import duckdb
 import os
+from enum import StrEnum
+from pathlib import Path
+
+import duckdb
 from typing_extensions import Self
 
 
+class StorageEngines(StrEnum):
+    local = "Local"
+    gcs = "GCS"
+
+
 class DDB:
-    def __init__(self):
+    def __init__(self, storage_engine: StorageEngines, **kwargs):
+        """DuckDB-backed analytics store.
+
+        The GCS storage engine requires the gcs_access, gcs_secret, bucket
+        and sid keyword arguments; the local engine requires path.
+        """
         self.db = duckdb.connect()
         self.db.install_extension("spatial")
         self.db.install_extension("httpfs")
         self.db.load_extension("spatial")
         self.db.load_extension("httpfs")
         self.sheets = tuple()
-        self.path = ""
+        self.loaded = False
 
-    def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self:
-        self.db.sql(f"""
-        CREATE SECRET (
-            TYPE GCS,
-            KEY_ID '{gcs_access}',
-            SECRET '{gcs_secret}')
-        """)
+        if storage_engine == StorageEngines.gcs:
+            if (
+                "gcs_access" in kwargs
+                and "gcs_secret" in kwargs
+                and "bucket" in kwargs
+                and "sid" in kwargs
+            ):
+                self.db.sql(f"""
+                CREATE SECRET (
+                    TYPE GCS,
+                    KEY_ID '{kwargs["gcs_access"]}',
+                    SECRET '{kwargs["gcs_secret"]}')
+                """)
+                self.path_prefix = f"gcs://{kwargs['bucket']}/sessions/{kwargs['sid']}"
+            else:
+                raise ValueError(
+                    "With the GCS storage engine you need to provide "
+                    "the gcs_access, gcs_secret, bucket and sid keyword arguments"
+                )
 
-        return self
+        elif storage_engine == StorageEngines.local:
+            if "path" in kwargs:
+                self.path_prefix = kwargs["path"]
+            else:
+                raise ValueError(
+                    "With local storage you need to provide the path keyword argument"
+                )
 
-    def load_metadata(self, path: str = "") -> Self:
+    def load_xls(self, xls_path: Path) -> Self:
         """For some reason, the header is not loaded"""
         self.db.sql(f"""
             create table metadata as (
             select
                 *
             from
-                st_read('{path}',
+                st_read('{xls_path}',
                 layer='metadata'
                 )
             )""")
@@ -39,62 +66,45 @@ class DDB:
             .fetchall()[0][0]
             .split(";")
         )
-        self.path = path
-
-        return self
-
-    def dump_local(self, path) -> Self:
-        # TODO: Port to fsspec and have a single dump file
-        self.db.query(f"copy metadata to '{path}/metadata.csv'")
 
         for sheet in self.sheets:
             self.db.query(f"""
-                copy
+                create table {sheet} as
                 (
                     select
                         *
                     from
                         st_read
                         (
-                            '{self.path}',
+                            '{xls_path}',
                             layer = '{sheet}'
                         )
                 )
-                to '{path}/{sheet}.csv'
                 """)
+
+        self.loaded = True
+
         return self
 
-    def dump_gcs(self, bucketname, sid) -> Self:
-        self.db.sql(
-            f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'"
-        )
+    def dump(self) -> Self:
+        # TODO: Create a decorator
+        if not self.loaded:
+            raise ValueError("Data should be loaded first")
+
+        self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
 
         for sheet in self.sheets:
-            self.db.query(f"""
-                copy
-                (
-                    select
-                        *
-                    from
-                        st_read
-                        (
-                            '{self.path}',
-                            layer = '{sheet}'
-                        )
-                )
-                to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
-                """)
-
+            self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
 
         return self
 
-    def load_folder_local(self, path: str) -> Self:
+    def load_folder(self) -> Self:
         self.sheets = tuple(
             self.query(
                 f"""
                 select
                     Field2
                 from
-                    read_csv_auto('{path}/metadata.csv')
+                    read_csv_auto('{self.path_prefix}/metadata.csv')
                 where
                     Field1 = 'Sheets'
                 """
@@ -110,61 +120,21 @@ class DDB:
                 select
                     *
                 from
-                    read_csv_auto('{path}/{sheet}.csv')
+                    read_csv_auto('{self.path_prefix}/{sheet}.csv')
                 )
                 """)
 
-        return self
-
-    def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
-        self.sheets = tuple(
-            self.query(
-                f"""
-                select
-                    Field2
-                from
-                    read_csv_auto(
-                        'gcs://{bucketname}/sessions/{sid}/metadata.csv'
-                    )
-                where
-                    Field1 = 'Sheets'
-                """
-            )
-            .fetchall()[0][0]
-            .split(";")
-        )
-
-        # Load all the tables into the database
-        for sheet in self.sheets:
-            self.db.query(f"""
-                create table {sheet} as (
-                    select
-                        *
-                    from
-                        read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
-                )
-                """)
+        self.loaded = True
 
         return self
 
-    def load_description_local(self, path: str) -> Self:
+    def load_description(self) -> Self:
         return self.query(
             f"""
             select
                 Field2
             from
-                read_csv_auto('{path}/metadata.csv')
-            where
-                Field1 = 'Description'"""
-        ).fetchall()[0][0]
-
-    def load_description_gcs(self, bucketname: str, sid: str) -> Self:
-        return self.query(
-            f"""
-            select
-                Field2
-            from
-                read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
+                read_csv_auto('{self.path_prefix}/metadata.csv')
             where
                 Field1 = 'Description'"""
         ).fetchall()[0][0]
@@ -182,9 +152,11 @@ class DDB:
                 f"select column_name, column_type from (describe {table})"
             ).fetchall()
         )
+        + [os.linesep]
     )
 
-    def db_schema(self):
+    @property
+    def schema(self):
         return os.linesep.join(
             [
                 "The schema of the database is the following:",
@@ -194,3 +166,18 @@ class DDB:
 
     def query(self, sql, *args, **kwargs):
         return self.db.query(sql, *args, **kwargs)
+
+    def query_prompt(self, user_prompt: str) -> str:
+        query = (
+            f"The following sentence is the description of a query that "
+            f"needs to be executed in a database: {user_prompt}"
+        )
+
+        return os.linesep.join(
+            [
+                query,
+                self.schema,
+                self.load_description(),
+                "Return just the SQL statement",
+            ]
+        )
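Note on the `# TODO: Create a decorator` left in `dump()`: the intent seems to be to factor the loaded-state check out of the method body. A minimal sketch of what such a guard could look like, assuming it lives next to the `DDB` class; the name `requires_loaded` is hypothetical and not part of this patch.

from functools import wraps


def requires_loaded(method):
    """Fail fast when a DDB method needs data that has not been loaded yet."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        if not self.loaded:
            raise ValueError("Data should be loaded first")
        return method(self, *args, **kwargs)

    return wrapper


# Applied to dump(), the explicit check disappears from the method body:
#
#     @requires_loaded
#     def dump(self) -> Self:
#         self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
#         ...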
diff --git a/src/hellocomputer/config.py b/src/hellocomputer/config.py
index f6c0ef6..6e86baa 100644
--- a/src/hellocomputer/config.py
+++ b/src/hellocomputer/config.py
@@ -6,7 +6,7 @@ class Settings(BaseSettings):
     gcs_access: str = "access"
     gcs_secret: str = "secret"
     gcs_bucketname: str = "bucket"
-    auth: bool = False
+    auth: bool = True
 
     model_config = SettingsConfigDict(env_file=".env")
diff --git a/src/hellocomputer/main.py b/src/hellocomputer/main.py
index 2cafa65..127dfc7 100644
--- a/src/hellocomputer/main.py
+++ b/src/hellocomputer/main.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel
 
 import hellocomputer
 
-from .routers import files, sessions, analysis
+from .routers import analysis, files, sessions
 
 static_path = Path(hellocomputer.__file__).parent / "static"
diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py
index 809745c..b59b1bb 100644
--- a/src/hellocomputer/models.py
+++ b/src/hellocomputer/models.py
@@ -1,4 +1,5 @@
 from enum import StrEnum
+
 from langchain_community.chat_models import ChatAnyscale
 from langchain_core.messages import HumanMessage, SystemMessage
diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py
index 07ded5a..e54bf58 100644
--- a/src/hellocomputer/routers/analysis.py
+++ b/src/hellocomputer/routers/analysis.py
@@ -1,41 +1,29 @@
+import os
+
 from fastapi import APIRouter
 from fastapi.responses import PlainTextResponse
 
-from ..config import settings
-from ..models import Chat
-
-from hellocomputer.analytics import DDB
+from hellocomputer.analytics import DDB, StorageEngines
 from hellocomputer.extraction import extract_code_block
 
-import os
+from ..config import settings
+from ..models import Chat
 
 router = APIRouter()
 
 
 @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
 async def query(sid: str = "", q: str = "") -> str:
-    print(q)
-    query = f"Write a query that {q} in the current database"
-
     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
-    db = (
-        DDB()
-        .gcs_secret(settings.gcs_access, settings.gcs_secret)
-        .load_folder_gcs(settings.gcs_bucketname, sid)
-    )
+    db = DDB(
+        StorageEngines.gcs,
+        gcs_access=settings.gcs_access,
+        gcs_secret=settings.gcs_secret,
+        bucket=settings.gcs_bucketname,
+        sid=sid,
+    ).load_folder()
 
-    prompt = os.linesep.join(
-        [
-            query,
-            db.db_schema(),
-            db.load_description_gcs(settings.gcs_bucketname, sid),
-            "Return just the SQL statement",
-        ]
-    )
-
-    print(prompt)
-
-    chat = await chat.eval("You're an expert sql developer", prompt)
+    chat = await chat.eval("You're an expert sql developer", db.query_prompt(q))
     query = extract_code_block(chat.last_response_content())
     result = str(db.query(query))
     print(result)
diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py
index 5000d9d..b114e4b 100644
--- a/src/hellocomputer/routers/files.py
+++ b/src/hellocomputer/routers/files.py
@@ -1,21 +1,22 @@
 import aiofiles
-import s3fs
+
+# import s3fs
 from fastapi import APIRouter, File, UploadFile
 from fastapi.responses import JSONResponse
 
+from ..analytics import DDB, StorageEngines
 from ..config import settings
-from ..analytics import DDB
 
 router = APIRouter()
 
 
 # Configure the S3FS with your Google Cloud Storage credentials
-gcs = s3fs.S3FileSystem(
-    key=settings.gcs_access,
-    secret=settings.gcs_secret,
-    client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
-)
-bucket_name = settings.gcs_bucketname
+# gcs = s3fs.S3FileSystem(
+#     key=settings.gcs_access,
+#     secret=settings.gcs_secret,
+#     client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
+# )
+# bucket_name = settings.gcs_bucketname
 
 
 @router.post("/upload", tags=["files"])
@@ -26,10 +27,15 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
         await f.flush()
 
     (
-        DDB()
-        .gcs_secret(settings.gcs_access, settings.gcs_secret)
-        .load_metadata(f.name)
-        .dump_gcs(settings.gcs_bucketname, sid)
+        DDB(
+            StorageEngines.gcs,
+            gcs_access=settings.gcs_access,
+            gcs_secret=settings.gcs_secret,
+            bucket=settings.gcs_bucketname,
+            sid=sid,
+        )
+        .load_xls(f.name)
+        .dump()
     )
 
     return JSONResponse(
diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py
index 381f919..9d68dd1 100644
--- a/src/hellocomputer/routers/sessions.py
+++ b/src/hellocomputer/routers/sessions.py
@@ -1,22 +1,16 @@
-from uuid import uuid4
 from typing import Annotated
+from uuid import uuid4
+
 from fastapi import APIRouter, Depends
 from fastapi.responses import PlainTextResponse
 
-from ..config import settings
 from ..security import oauth2_scheme
 
+
 # Scheme for the Authorization header
 router = APIRouter()
 
-if settings.auth:
-
-    @router.get("/token")
-    async def get_token() -> str:
-        return str(uuid4())
-
-
 @router.get("/new_session")
 async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
     return str(uuid4())
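Side note: routers/analysis.py and routers/files.py now build the same session-scoped, GCS-backed DDB from settings. If that duplication grows, a small helper could centralise it; a sketch under the assumption it lives in a new module, with the hypothetical name ddb_from_settings (not part of this patch):

from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.config import settings


def ddb_from_settings(sid: str) -> DDB:
    """Build the GCS-backed DDB that both routers currently construct by hand."""
    return DDB(
        StorageEngines.gcs,
        gcs_access=settings.gcs_access,
        gcs_secret=settings.gcs_secret,
        bucket=settings.gcs_bucketname,
        sid=sid,
    )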
"test" / "output" -def test_dump(): - db = ( - DDB() - .load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx") - .dump_local(TEST_OUTPUT_FOLDER) - ) +def test_0_dump(): + db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) + db.load_xls(TEST_XLS_PATH).dump() assert db.sheets == ("answers",) assert (TEST_OUTPUT_FOLDER / "answers.csv").exists() def test_load(): - db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) - - assert db.sheets == ("answers",) - + db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() results = db.query("select * from answers").fetchall() assert len(results) == 6 def test_load_description(): - file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER) + db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + file_description = db.load_description() assert file_description.startswith("answers") def test_schema(): - db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) + db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() schema = [] for sheet in db.sheets: schema.append(db.table_schema(sheet)) - assert schema[0].startswith("Table name:") + print(db.schema) + + assert db.schema.startswith("The schema of the database") + + +def test_query_prompt(): + db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() + + assert db.query_prompt("Find the average score of all students").startswith( + "The following sentence" + ) diff --git a/test/test_query.py b/test/test_query.py index 4fe1b8b..59956a3 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,14 +1,19 @@ -import pytest import os -import hellocomputer -from hellocomputer.config import settings -from hellocomputer.models import Chat -from hellocomputer.extraction import extract_code_block from pathlib import Path + +import hellocomputer +import pytest from hellocomputer.analytics import DDB +from hellocomputer.config import settings +from hellocomputer.extraction import extract_code_block +from hellocomputer.models import Chat - -TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" +TEST_XLS_PATH = ( + Path(hellocomputer.__file__).parents[2] + / "test" + / "data" + / "TestExcelHelloComputer.xlsx" +) @pytest.mark.asyncio @@ -29,13 +34,13 @@ async def test_simple_data_query(): query = "write a query that finds the average score of all students in the current database" chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) - db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) + db = DDB().load_xls(TEST_XLS_PATH) prompt = os.linesep.join( [ query, - db.db_schema(), - db.load_description_local(TEST_OUTPUT_FOLDER), + db.schema(), + db.load_description(), "Return just the SQL statement", ] )