diff --git a/src/hellocomputer/analytics.py b/src/hellocomputer/analytics.py index a841e40..55a0cb3 100644 --- a/src/hellocomputer/analytics.py +++ b/src/hellocomputer/analytics.py @@ -1,4 +1,5 @@ import duckdb +import os from typing_extensions import Self @@ -64,7 +65,7 @@ class DDB: return self def dump_gcs(self, bucketname, sid) -> Self: - self.db.sql(f"copy metadata to 'gcs://{bucketname}/{sid}/data.csv'") + self.db.sql(f"copy metadata to 'gcs://{bucketname}/{sid}/metadata.csv'") for sheet in self.sheets: self.db.query(f""" @@ -106,8 +107,60 @@ class DDB: return self - def load_folder_gcs(self, path: str) -> Self: + def load_folder_gcs(self, bucketname: str, sid: str) -> Self: + self.sheets = tuple( + self.query( + f"select Field2 from read_csv_auto('gcs://{bucketname}/{sid}/metadata.csv') where Field1 = 'Sheets'" + ) + .fetchall()[0][0] + .split(";") + ) + + # Load all the tables into the database + for sheet in self.sheets: + self.db.query(f""" + create table {sheet} as ( + select + * + from + read_csv_auto('gcs://{bucketname}/{sid}/{sheet}.csv') + ) + """) + return self + def load_description_local(self, path: str) -> Self: + return self.query( + f"select Field2 from read_csv_auto('{path}/metadata.csv') where Field1 = 'Description'" + ).fetchall()[0][0] + + def load_description_gcs(self, bucketname: str, sid: str) -> Self: + return self.query( + f"select Field2 from read_csv_auto('gcs://{bucketname}/{sid}/metadata.csv') where Field1 = 'Description'" + ).fetchall()[0][0] + + @staticmethod + def process_schema_row(row): + return f"Column name: {row[0]}, Column type: {row[1]}" + + def table_schema(self, table: str): + return os.linesep.join( + [f"Table name: {table}"] + + list( + self.process_schema_row(r) + for r in self.query( + f"select column_name, column_type from (describe {table})" + ).fetchall() + ) + ) + + def db_schema(self): + return os.linesep.join( + [ + "The schema of the database is the following:", + ] + + [self.table_schema(sheet) for sheet in self.sheets] + ) + def query(self, sql, *args, **kwargs): return self.db.query(sql, *args, **kwargs) diff --git a/src/hellocomputer/extraction.py b/src/hellocomputer/extraction.py new file mode 100644 index 0000000..d860272 --- /dev/null +++ b/src/hellocomputer/extraction.py @@ -0,0 +1,10 @@ +import re + + +def extract_code_block(response): + # python regex to extract markdown clode block contained in the response string + pattern = r"```(.*?)```" + matches = re.findall(pattern, response, re.DOTALL) + if len(matches) > 1: + raise ValueError("More than one code block") + return matches[0].removeprefix("sql").removeprefix("\n") diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index b98e41d..809745c 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -5,6 +5,9 @@ from langchain_core.messages import HumanMessage, SystemMessage class AvailableModels(StrEnum): llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct" + llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct" + # Function calling model + mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1" class Chat: diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 739ffae..b403d6f 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -4,12 +4,35 @@ from fastapi.responses import PlainTextResponse from ..config import settings from ..models import Chat +from hellocomputer.analytics import DDB +from hellocomputer.extraction import extract_code_block + +import os + router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) -async def query(sid: str = "") -> str: - model = Chat(api_key=settings.anyscale_api_key).eval( - system="You're an expert analyst", human="Do some analysis" +async def query(sid: str = "", q: str = "") -> str: + query = f"Write a query that {q} in the current database" + + chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + db = ( + DDB() + .gcs_secret(settings.gcs_secret, settings.gcs_secret) + .load_folder_gcs(settings.gcs_bucketname, sid) ) - return model.last_response_content() + + prompt = os.linesep.join( + [ + query, + db.db_schema(), + db.load_description_gcs(settings.gcs_bucketname, sid), + "Return just the SQL statement", + ] + ) + + chat = await chat.eval("You're an expert sql developer", prompt) + query = extract_code_block(chat.last_response_content()) + + return str(db.query(query)) diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index 23b4d55..5787c5d 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -26,7 +26,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""): await f.flush() gcs.makedir(f"{settings.gcs_bucketname}/{sid}") - + print("successfully created directory") ( DDB() .gcs_secret(settings.gcs_secret, settings.gcs_secret) diff --git a/test/test_data.py b/test/test_data.py index b3a351e..6f750f1 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -24,3 +24,17 @@ def test_load(): results = db.query("select * from answers").fetchall() assert len(results) == 2 + + +def test_load_description(): + file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER) + assert file_description.startswith("answers") + + +def test_schema(): + db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) + schema = [] + for sheet in db.sheets: + schema.append(db.table_schema(sheet)) + + assert schema[0].startswith("Table name:") diff --git a/test/test_query.py b/test/test_query.py index 0107db0..4fe1b8b 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -1,6 +1,14 @@ import pytest +import os +import hellocomputer from hellocomputer.config import settings from hellocomputer.models import Chat +from hellocomputer.extraction import extract_code_block +from pathlib import Path +from hellocomputer.analytics import DDB + + +TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" @pytest.mark.asyncio @@ -11,3 +19,27 @@ async def test_chat_simple(): chat = Chat(api_key=settings.anyscale_api_key, temperature=0) chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'") assert chat.last_response_content() == "Hello!" + + +@pytest.mark.asyncio +@pytest.mark.skipif( + settings.anyscale_api_key == "Awesome API", reason="API Key not set" +) +async def test_simple_data_query(): + query = "write a query that finds the average score of all students in the current database" + + chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + db = DDB().load_folder_local(TEST_OUTPUT_FOLDER) + + prompt = os.linesep.join( + [ + query, + db.db_schema(), + db.load_description_local(TEST_OUTPUT_FOLDER), + "Return just the SQL statement", + ] + ) + + chat = await chat.eval("You're an expert sql developer", prompt) + query = extract_code_block(chat.last_response_content()) + assert query.startswith("SELECT")