Refactored analytics class
This commit is contained in:
		
							parent
							
								
									06ac295e17
								
							
						
					
					
						commit
						0c21073e88
					
				|  | @ -1,36 +1,63 @@ | |||
| import duckdb | ||||
| import os | ||||
| from enum import StrEnum | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import duckdb | ||||
| from typing_extensions import Self | ||||
| 
 | ||||
| 
 | ||||
| class StorageEngines(StrEnum): | ||||
|     local = "Local" | ||||
|     gcs = "GCS" | ||||
| 
 | ||||
| 
 | ||||
| class DDB: | ||||
|     def __init__(self): | ||||
|     def __init__(self, storage_engine: StorageEngines, **kwargs): | ||||
|         """Write documentation""" | ||||
|         self.db = duckdb.connect() | ||||
|         self.db.install_extension("spatial") | ||||
|         self.db.install_extension("httpfs") | ||||
|         self.db.load_extension("spatial") | ||||
|         self.db.load_extension("httpfs") | ||||
|         self.sheets = tuple() | ||||
|         self.path = "" | ||||
|         self.loaded = False | ||||
| 
 | ||||
|     def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self: | ||||
|         self.db.sql(f""" | ||||
|             CREATE SECRET ( | ||||
|                TYPE GCS, | ||||
|                KEY_ID '{gcs_access}', | ||||
|                SECRET '{gcs_secret}') | ||||
|                """) | ||||
|         if storage_engine == StorageEngines.gcs: | ||||
|             if ( | ||||
|                 "gcs_access" in kwargs | ||||
|                 and "gcs_secret" in kwargs | ||||
|                 and "bucketname" in kwargs | ||||
|                 and "sid" in kwargs | ||||
|             ): | ||||
|                 self.db.sql(f""" | ||||
|                     CREATE SECRET ( | ||||
|                     TYPE GCS, | ||||
|                     KEY_ID '{kwargs["gcs_access"]}', | ||||
|                     SECRET '{kwargs["gcs_secret"]}') | ||||
|                     """) | ||||
|                 self.path_prefix = f"gcs://{kwargs["bucket"]}/sessions/{kwargs['sid']}" | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "With GCS storage engine you need to provide " | ||||
|                     "the gcs_access, gcs_secret and bucket keyword arguments" | ||||
|                 ) | ||||
| 
 | ||||
|         return self | ||||
|         elif storage_engine == StorageEngines.local: | ||||
|             if "path" in kwargs: | ||||
|                 self.path_prefix = kwargs["path"] | ||||
|             else: | ||||
|                 raise ValueError( | ||||
|                     "With local storage you need to provide the path keyword argument" | ||||
|                 ) | ||||
| 
 | ||||
|     def load_metadata(self, path: str = "") -> Self: | ||||
|     def load_xls(self, xls_path: Path) -> Self: | ||||
|         """For some reason, the header is not loaded""" | ||||
|         self.db.sql(f""" | ||||
|             create table metadata as ( | ||||
|             select | ||||
|                 * | ||||
|             from | ||||
|                 st_read('{path}',  | ||||
|                 st_read('{xls_path}',  | ||||
|                         layer='metadata' | ||||
|                         ) | ||||
|             )""") | ||||
|  | @ -39,62 +66,45 @@ class DDB: | |||
|             .fetchall()[0][0] | ||||
|             .split(";") | ||||
|         ) | ||||
|         self.path = path | ||||
| 
 | ||||
|         return self | ||||
| 
 | ||||
|     def dump_local(self, path) -> Self: | ||||
|         # TODO: Port to fsspec and have a single dump file | ||||
|         self.db.query(f"copy metadata to '{path}/metadata.csv'") | ||||
| 
 | ||||
|         for sheet in self.sheets: | ||||
|             self.db.query(f""" | ||||
|             copy  | ||||
|             create table {sheet} as   | ||||
|                 ( | ||||
|                 select | ||||
|                     * | ||||
|                 from | ||||
|                     st_read | ||||
|                         ( | ||||
|                         '{self.path}', | ||||
|                         '{xls_path}', | ||||
|                         layer = '{sheet}' | ||||
|                         ) | ||||
|                 ) | ||||
|             to '{path}/{sheet}.csv' | ||||
|                           """) | ||||
| 
 | ||||
|         self.loaded = True | ||||
| 
 | ||||
|         return self | ||||
| 
 | ||||
|     def dump_gcs(self, bucketname, sid) -> Self: | ||||
|         self.db.sql( | ||||
|             f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'" | ||||
|         ) | ||||
|     def dump(self) -> Self: | ||||
|         # TODO: Create a decorator | ||||
|         if not self.loaded: | ||||
|             raise ValueError("Data should be loaded first") | ||||
| 
 | ||||
|         self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") | ||||
| 
 | ||||
|         for sheet in self.sheets: | ||||
|             self.db.query(f""" | ||||
|             copy  | ||||
|                 ( | ||||
|                 select | ||||
|                     * | ||||
|                 from | ||||
|                     st_read | ||||
|                         ( | ||||
|                         '{self.path}', | ||||
|                         layer = '{sheet}' | ||||
|                         ) | ||||
|                 ) | ||||
|             to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv' | ||||
|                           """) | ||||
| 
 | ||||
|             self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") | ||||
|         return self | ||||
| 
 | ||||
|     def load_folder_local(self, path: str) -> Self: | ||||
|     def load_folder(self) -> Self: | ||||
|         self.sheets = tuple( | ||||
|             self.query( | ||||
|                 f""" | ||||
|                 select | ||||
|                     Field2 | ||||
|                 from | ||||
|                     read_csv_auto('{path}/metadata.csv') | ||||
|                     read_csv_auto('{self.path_prefix}/metadata.csv') | ||||
|                 where | ||||
|                     Field1 = 'Sheets' | ||||
|                 """ | ||||
|  | @ -110,61 +120,21 @@ class DDB: | |||
|             select | ||||
|                 * | ||||
|             from | ||||
|                 read_csv_auto('{path}/{sheet}.csv') | ||||
|                 read_csv_auto('{self.path_prefix}/{sheet}.csv') | ||||
|             ) | ||||
|             """) | ||||
| 
 | ||||
|         return self | ||||
| 
 | ||||
|     def load_folder_gcs(self, bucketname: str, sid: str) -> Self: | ||||
|         self.sheets = tuple( | ||||
|             self.query( | ||||
|                 f""" | ||||
|                 select | ||||
|                     Field2 | ||||
|                 from | ||||
|                     read_csv_auto( | ||||
|                         'gcs://{bucketname}/sessions/{sid}/metadata.csv' | ||||
|                         ) | ||||
|                 where | ||||
|                     Field1 = 'Sheets' | ||||
|                     """ | ||||
|             ) | ||||
|             .fetchall()[0][0] | ||||
|             .split(";") | ||||
|         ) | ||||
| 
 | ||||
|         # Load all the tables into the database | ||||
|         for sheet in self.sheets: | ||||
|             self.db.query(f""" | ||||
|             create table {sheet} as ( | ||||
|             select | ||||
|                 * | ||||
|             from | ||||
|                 read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv') | ||||
|             ) | ||||
|             """) | ||||
|         self.loaded = True | ||||
| 
 | ||||
|         return self | ||||
| 
 | ||||
|     def load_description_local(self, path: str) -> Self: | ||||
|     def load_description(self) -> Self: | ||||
|         return self.query( | ||||
|             f""" | ||||
|             select | ||||
|                 Field2 | ||||
|             from | ||||
|                 read_csv_auto('{path}/metadata.csv') | ||||
|             where | ||||
|                 Field1 = 'Description'""" | ||||
|         ).fetchall()[0][0] | ||||
| 
 | ||||
|     def load_description_gcs(self, bucketname: str, sid: str) -> Self: | ||||
|         return self.query( | ||||
|             f""" | ||||
|             select | ||||
|                 Field2 | ||||
|             from | ||||
|                 read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv') | ||||
|                 read_csv_auto('{self.path_prefix}/metadata.csv') | ||||
|             where | ||||
|                 Field1 = 'Description'""" | ||||
|         ).fetchall()[0][0] | ||||
|  | @ -182,9 +152,11 @@ class DDB: | |||
|                     f"select column_name, column_type from (describe {table})" | ||||
|                 ).fetchall() | ||||
|             ) | ||||
|             + [os.linesep] | ||||
|         ) | ||||
| 
 | ||||
|     def db_schema(self): | ||||
|     @property | ||||
|     def schema(self): | ||||
|         return os.linesep.join( | ||||
|             [ | ||||
|                 "The schema of the database is the following:", | ||||
|  | @ -194,3 +166,18 @@ class DDB: | |||
| 
 | ||||
|     def query(self, sql, *args, **kwargs): | ||||
|         return self.db.query(sql, *args, **kwargs) | ||||
| 
 | ||||
|     def query_prompt(self, user_prompt: str) -> str: | ||||
|         query = ( | ||||
|             f"The following sentence is the description of a query that " | ||||
|             f"needs to be executed in a database: {user_prompt}" | ||||
|         ) | ||||
| 
 | ||||
|         return os.linesep.join( | ||||
|             [ | ||||
|                 query, | ||||
|                 self.schema, | ||||
|                 self.load_description(), | ||||
|                 "Return just the SQL statement", | ||||
|             ] | ||||
|         ) | ||||
|  |  | |||
|  | @ -6,7 +6,7 @@ class Settings(BaseSettings): | |||
|     gcs_access: str = "access" | ||||
|     gcs_secret: str = "secret" | ||||
|     gcs_bucketname: str = "bucket" | ||||
|     auth: bool = False | ||||
|     auth: bool = True | ||||
| 
 | ||||
|     model_config = SettingsConfigDict(env_file=".env") | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,7 +6,7 @@ from pydantic import BaseModel | |||
| 
 | ||||
| import hellocomputer | ||||
| 
 | ||||
| from .routers import files, sessions, analysis | ||||
| from .routers import analysis, files, sessions | ||||
| 
 | ||||
| static_path = Path(hellocomputer.__file__).parent / "static" | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| from enum import StrEnum | ||||
| 
 | ||||
| from langchain_community.chat_models import ChatAnyscale | ||||
| from langchain_core.messages import HumanMessage, SystemMessage | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,41 +1,29 @@ | |||
| import os | ||||
| 
 | ||||
| from fastapi import APIRouter | ||||
| from fastapi.responses import PlainTextResponse | ||||
| 
 | ||||
| from ..config import settings | ||||
| from ..models import Chat | ||||
| 
 | ||||
| from hellocomputer.analytics import DDB | ||||
| from hellocomputer.analytics import DDB, StorageEngines | ||||
| from hellocomputer.extraction import extract_code_block | ||||
| 
 | ||||
| import os | ||||
| from ..config import settings | ||||
| from ..models import Chat | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) | ||||
| async def query(sid: str = "", q: str = "") -> str: | ||||
|     print(q) | ||||
|     query = f"Write a query that {q} in the current database" | ||||
| 
 | ||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     db = ( | ||||
|         DDB() | ||||
|         .gcs_secret(settings.gcs_access, settings.gcs_secret) | ||||
|         .load_folder_gcs(settings.gcs_bucketname, sid) | ||||
|     ) | ||||
|     db = DDB( | ||||
|         StorageEngines.gcs, | ||||
|         gcs_access=settings.gcs_access, | ||||
|         gcs_secret=settings.gcs_secret, | ||||
|         bucket=settings.gcs_bucketname, | ||||
|         sid=sid, | ||||
|     ).load_folder() | ||||
| 
 | ||||
|     prompt = os.linesep.join( | ||||
|         [ | ||||
|             query, | ||||
|             db.db_schema(), | ||||
|             db.load_description_gcs(settings.gcs_bucketname, sid), | ||||
|             "Return just the SQL statement", | ||||
|         ] | ||||
|     ) | ||||
| 
 | ||||
|     print(prompt) | ||||
| 
 | ||||
|     chat = await chat.eval("You're an expert sql developer", prompt) | ||||
|     chat = await chat.eval("You're an expert sql developer", db.query_prompt()) | ||||
|     query = extract_code_block(chat.last_response_content()) | ||||
|     result = str(db.query(query)) | ||||
|     print(result) | ||||
|  |  | |||
|  | @ -1,21 +1,22 @@ | |||
| import aiofiles | ||||
| import s3fs | ||||
| 
 | ||||
| # import s3fs | ||||
| from fastapi import APIRouter, File, UploadFile | ||||
| from fastapi.responses import JSONResponse | ||||
| 
 | ||||
| from ..analytics import DDB, StorageEngines | ||||
| from ..config import settings | ||||
| from ..analytics import DDB | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| # Configure the S3FS with your Google Cloud Storage credentials | ||||
| gcs = s3fs.S3FileSystem( | ||||
|     key=settings.gcs_access, | ||||
|     secret=settings.gcs_secret, | ||||
|     client_kwargs={"endpoint_url": "https://storage.googleapis.com"}, | ||||
| ) | ||||
| bucket_name = settings.gcs_bucketname | ||||
| # gcs = s3fs.S3FileSystem( | ||||
| #     key=settings.gcs_access, | ||||
| #     secret=settings.gcs_secret, | ||||
| #     client_kwargs={"endpoint_url": "https://storage.googleapis.com"}, | ||||
| # ) | ||||
| # bucket_name = settings.gcs_bucketname | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/upload", tags=["files"]) | ||||
|  | @ -26,10 +27,15 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""): | |||
|         await f.flush() | ||||
| 
 | ||||
|         ( | ||||
|             DDB() | ||||
|             .gcs_secret(settings.gcs_access, settings.gcs_secret) | ||||
|             .load_metadata(f.name) | ||||
|             .dump_gcs(settings.gcs_bucketname, sid) | ||||
|             DDB( | ||||
|                 StorageEngines.gcs, | ||||
|                 gcs_access=settings.gcs_access, | ||||
|                 gcs_secret=settings.gcs_secret, | ||||
|                 bucket=settings.gcs_bucketname, | ||||
|                 sid=sid, | ||||
|             ) | ||||
|             .load_xls(f.name) | ||||
|             .dump() | ||||
|         ) | ||||
| 
 | ||||
|         return JSONResponse( | ||||
|  |  | |||
|  | @ -1,22 +1,16 @@ | |||
| from uuid import uuid4 | ||||
| from typing import Annotated | ||||
| from uuid import uuid4 | ||||
| 
 | ||||
| from fastapi import APIRouter, Depends | ||||
| from fastapi.responses import PlainTextResponse | ||||
| 
 | ||||
| from ..config import settings | ||||
| from ..security import oauth2_scheme | ||||
| 
 | ||||
| # Scheme for the Authorization header | ||||
| 
 | ||||
| router = APIRouter() | ||||
| 
 | ||||
| 
 | ||||
| if settings.auth: | ||||
| 
 | ||||
|     @router.get("/token") | ||||
|     async def get_token() -> str: | ||||
|         return str(uuid4()) | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/new_session") | ||||
| async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str: | ||||
|     return str(uuid4()) | ||||
|  |  | |||
|  | @ -1,40 +1,52 @@ | |||
| import hellocomputer | ||||
| from hellocomputer.analytics import DDB | ||||
| from pathlib import Path | ||||
| 
 | ||||
| TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data" | ||||
| import hellocomputer | ||||
| from hellocomputer.analytics import DDB, StorageEngines | ||||
| 
 | ||||
| TEST_STORAGE = StorageEngines.local | ||||
| TEST_XLS_PATH = ( | ||||
|     Path(hellocomputer.__file__).parents[2] | ||||
|     / "test" | ||||
|     / "data" | ||||
|     / "TestExcelHelloComputer.xlsx" | ||||
| ) | ||||
| TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" | ||||
| 
 | ||||
| 
 | ||||
| def test_dump(): | ||||
|     db = ( | ||||
|         DDB() | ||||
|         .load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx") | ||||
|         .dump_local(TEST_OUTPUT_FOLDER) | ||||
|     ) | ||||
| def test_0_dump(): | ||||
|     db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) | ||||
|     db.load_xls(TEST_XLS_PATH).dump() | ||||
| 
 | ||||
|     assert db.sheets == ("answers",) | ||||
|     assert (TEST_OUTPUT_FOLDER / "answers.csv").exists() | ||||
| 
 | ||||
| 
 | ||||
| def test_load(): | ||||
|     db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) | ||||
| 
 | ||||
|     assert db.sheets == ("answers",) | ||||
| 
 | ||||
|     db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() | ||||
|     results = db.query("select * from answers").fetchall() | ||||
|     assert len(results) == 6 | ||||
| 
 | ||||
| 
 | ||||
| def test_load_description(): | ||||
|     file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER) | ||||
|     db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() | ||||
|     file_description = db.load_description() | ||||
|     assert file_description.startswith("answers") | ||||
| 
 | ||||
| 
 | ||||
| def test_schema(): | ||||
|     db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) | ||||
|     db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() | ||||
|     schema = [] | ||||
|     for sheet in db.sheets: | ||||
|         schema.append(db.table_schema(sheet)) | ||||
| 
 | ||||
|     assert schema[0].startswith("Table name:") | ||||
|     print(db.schema) | ||||
| 
 | ||||
|     assert db.schema.startswith("The schema of the database") | ||||
| 
 | ||||
| 
 | ||||
| def test_query_prompt(): | ||||
|     db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() | ||||
| 
 | ||||
|     assert db.query_prompt("Find the average score of all students").startswith( | ||||
|         "The following sentence" | ||||
|     ) | ||||
|  |  | |||
|  | @ -1,14 +1,19 @@ | |||
| import pytest | ||||
| import os | ||||
| import hellocomputer | ||||
| from hellocomputer.config import settings | ||||
| from hellocomputer.models import Chat | ||||
| from hellocomputer.extraction import extract_code_block | ||||
| from pathlib import Path | ||||
| 
 | ||||
| import hellocomputer | ||||
| import pytest | ||||
| from hellocomputer.analytics import DDB | ||||
| from hellocomputer.config import settings | ||||
| from hellocomputer.extraction import extract_code_block | ||||
| from hellocomputer.models import Chat | ||||
| 
 | ||||
| 
 | ||||
| TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" | ||||
| TEST_XLS_PATH = ( | ||||
|     Path(hellocomputer.__file__).parents[2] | ||||
|     / "test" | ||||
|     / "data" | ||||
|     / "TestExcelHelloComputer.xlsx" | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
|  | @ -29,13 +34,13 @@ async def test_simple_data_query(): | |||
|     query = "write a query that finds the average score of all students in the current database" | ||||
| 
 | ||||
|     chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) | ||||
|     db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) | ||||
|     db = DDB().load_xls(TEST_XLS_PATH) | ||||
| 
 | ||||
|     prompt = os.linesep.join( | ||||
|         [ | ||||
|             query, | ||||
|             db.db_schema(), | ||||
|             db.load_description_local(TEST_OUTPUT_FOLDER), | ||||
|             db.schema(), | ||||
|             db.load_description(), | ||||
|             "Return just the SQL statement", | ||||
|         ] | ||||
|     ) | ||||
|  |  | |||
		Loading…
	
		Reference in a new issue