Staging
This commit is contained in:
parent
48d2a9fc57
commit
f210b811af
|
@ -1,4 +1,5 @@
|
||||||
import duckdb
|
import duckdb
|
||||||
|
import os
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,7 +65,7 @@ class DDB:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def dump_gcs(self, bucketname, sid) -> Self:
|
def dump_gcs(self, bucketname, sid) -> Self:
|
||||||
self.db.sql(f"copy metadata to 'gcs://{bucketname}/{sid}/data.csv'")
|
self.db.sql(f"copy metadata to 'gcs://{bucketname}/{sid}/metadata.csv'")
|
||||||
|
|
||||||
for sheet in self.sheets:
|
for sheet in self.sheets:
|
||||||
self.db.query(f"""
|
self.db.query(f"""
|
||||||
|
@ -106,8 +107,60 @@ class DDB:
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def load_folder_gcs(self, path: str) -> Self:
|
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
|
||||||
|
self.sheets = tuple(
|
||||||
|
self.query(
|
||||||
|
f"select Field2 from read_csv_auto('gcs://{bucketname}/{sid}/metadata.csv') where Field1 = 'Sheets'"
|
||||||
|
)
|
||||||
|
.fetchall()[0][0]
|
||||||
|
.split(";")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load all the tables into the database
|
||||||
|
for sheet in self.sheets:
|
||||||
|
self.db.query(f"""
|
||||||
|
create table {sheet} as (
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from
|
||||||
|
read_csv_auto('gcs://{bucketname}/{sid}/{sheet}.csv')
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def load_description_local(self, path: str) -> str:
    """Return the dataset description stored in a local ``metadata.csv``.

    Reads ``{path}/metadata.csv`` and returns the ``Field2`` value of the
    row whose ``Field1`` is ``'Description'``.

    Args:
        path: Folder containing ``metadata.csv``. It is interpolated into
            the SQL text, so anything whose ``str()`` is the folder path
            works.

    Returns:
        The description cell. The previous ``Self`` annotation was wrong:
        the method returns the queried value, not the instance.

    Raises:
        IndexError: If the metadata file has no ``Description`` row.
    """
    rows = self.query(
        f"select Field2 from read_csv_auto('{path}/metadata.csv') where Field1 = 'Description'"
    ).fetchall()
    # First column of the first matching row is the description.
    return rows[0][0]
|
||||||
|
|
||||||
|
def load_description_gcs(self, bucketname: str, sid: str) -> str:
    """Return the dataset description stored in ``metadata.csv`` on GCS.

    Reads ``gcs://{bucketname}/{sid}/metadata.csv`` and returns the
    ``Field2`` value of the row whose ``Field1`` is ``'Description'``.

    Args:
        bucketname: Name of the bucket the files live in.
        sid: Folder inside the bucket that holds ``metadata.csv``.

    Returns:
        The description cell. The previous ``Self`` annotation was wrong:
        the method returns the queried value, not the instance.

    Raises:
        IndexError: If the metadata file has no ``Description`` row.
    """
    # NOTE(review): bucketname and sid are interpolated straight into the SQL
    # text; confirm callers only pass trusted values.
    rows = self.query(
        f"select Field2 from read_csv_auto('gcs://{bucketname}/{sid}/metadata.csv') where Field1 = 'Description'"
    ).fetchall()
    # First column of the first matching row is the description.
    return rows[0][0]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_schema_row(row):
|
||||||
|
return f"Column name: {row[0]}, Column type: {row[1]}"
|
||||||
|
|
||||||
|
def table_schema(self, table: str):
    """Describe one table as text: a title line plus one line per column.

    Args:
        table: Name of the table to describe.

    Returns:
        ``"Table name: <table>"`` followed by the formatted columns, joined
        with ``os.linesep``.
    """
    columns = self.query(
        f"select column_name, column_type from (describe {table})"
    ).fetchall()
    lines = [f"Table name: {table}"]
    lines.extend(self.process_schema_row(column) for column in columns)
    return os.linesep.join(lines)
|
||||||
|
|
||||||
|
def db_schema(self):
    """Describe every loaded sheet as text.

    Returns:
        A header line followed by ``table_schema`` of each entry in
        ``self.sheets``, joined with ``os.linesep``.
    """
    header = "The schema of the database is the following:"
    sections = [header, *(self.table_schema(sheet) for sheet in self.sheets)]
    return os.linesep.join(sections)
|
||||||
|
|
||||||
def query(self, sql, *args, **kwargs):
    """Run ``sql`` on the wrapped connection and return its result.

    Extra positional and keyword arguments are forwarded unchanged to
    ``self.db.query``.
    """
    result = self.db.query(sql, *args, **kwargs)
    return result
|
||||||
|
|
10
src/hellocomputer/extraction.py
Normal file
10
src/hellocomputer/extraction.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def extract_code_block(response):
    """Return the contents of the single fenced code block in ``response``.

    The text between one pair of triple-backtick fences is returned with a
    leading ``sql`` language tag, and the newline right after it, removed.

    Args:
        response: Text expected to contain exactly one markdown code block.

    Returns:
        The code inside the fences. Trailing whitespace is left untouched.

    Raises:
        ValueError: If ``response`` contains no code block, or more than one.
    """
    # DOTALL lets "." match newlines so multi-line blocks are captured; the
    # lazy quantifier stops at the first closing fence.
    pattern = r"```(.*?)```"
    matches = re.findall(pattern, response, re.DOTALL)
    if not matches:
        # Without this guard the lookup below fails with an opaque IndexError.
        raise ValueError("No code block found")
    if len(matches) > 1:
        raise ValueError("More than one code block")
    return matches[0].removeprefix("sql").removeprefix("\n")
|
|
@ -5,6 +5,9 @@ from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
class AvailableModels(StrEnum):
|
class AvailableModels(StrEnum):
|
||||||
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
|
llama3_8b = "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||||
|
llama3_70b = "meta-llama/Meta-Llama-3-70B-Instruct"
|
||||||
|
# Function calling model
|
||||||
|
mixtral_8x7b = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
|
||||||
|
|
||||||
class Chat:
|
class Chat:
|
||||||
|
|
|
@ -4,12 +4,35 @@ from fastapi.responses import PlainTextResponse
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
from ..models import Chat
|
from ..models import Chat
|
||||||
|
|
||||||
|
from hellocomputer.analytics import DDB
|
||||||
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
    """Answer a natural-language request about the dataset stored under ``sid``.

    The sheets for ``sid`` are loaded from the configured bucket. A prompt is
    built from the request, the database schema and the dataset description,
    the chat model is asked for SQL, and that SQL is run on the loaded data.

    Args:
        sid: Folder in the bucket that holds the dataset.
        q: What the generated query should do, in plain language.

    Returns:
        ``str()`` of the result of running the generated SQL.
    """
    request_text = f"Write a query that {q} in the current database"

    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    # NOTE(review): the same setting is passed for both gcs_secret arguments;
    # confirm that is intended.
    connection = DDB().gcs_secret(settings.gcs_secret, settings.gcs_secret)
    db = connection.load_folder_gcs(settings.gcs_bucketname, sid)

    prompt_parts = [
        request_text,
        db.db_schema(),
        db.load_description_gcs(settings.gcs_bucketname, sid),
        "Return just the SQL statement",
    ]
    prompt = os.linesep.join(prompt_parts)

    chat = await chat.eval("You're an expert sql developer", prompt)
    # The model's SQL is executed as returned, with no further checks here.
    generated_sql = extract_code_block(chat.last_response_content())

    return str(db.query(generated_sql))
|
||||||
|
|
|
@ -26,7 +26,7 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
||||||
await f.flush()
|
await f.flush()
|
||||||
|
|
||||||
gcs.makedir(f"{settings.gcs_bucketname}/{sid}")
|
gcs.makedir(f"{settings.gcs_bucketname}/{sid}")
|
||||||
|
print("successfully created directory")
|
||||||
(
|
(
|
||||||
DDB()
|
DDB()
|
||||||
.gcs_secret(settings.gcs_secret, settings.gcs_secret)
|
.gcs_secret(settings.gcs_secret, settings.gcs_secret)
|
||||||
|
|
|
@ -24,3 +24,17 @@ def test_load():
|
||||||
|
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_description():
    """The description read from the local fixture folder starts with "answers"."""
    ddb = DDB()
    description = ddb.load_description_local(TEST_OUTPUT_FOLDER)
    assert description.startswith("answers")
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema():
    """Each loaded sheet yields a schema text that begins with its title line."""
    db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
    schema = [db.table_schema(sheet) for sheet in db.sheets]

    assert schema[0].startswith("Table name:")
|
||||||
|
|
|
@ -1,6 +1,14 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
import os
|
||||||
|
import hellocomputer
|
||||||
from hellocomputer.config import settings
|
from hellocomputer.config import settings
|
||||||
from hellocomputer.models import Chat
|
from hellocomputer.models import Chat
|
||||||
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
from pathlib import Path
|
||||||
|
from hellocomputer.analytics import DDB
|
||||||
|
|
||||||
|
|
||||||
|
# Fixture folder: "test/output" under the directory two levels above the
# installed hellocomputer package directory.
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2].joinpath("test", "output")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -11,3 +19,27 @@ async def test_chat_simple():
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0)
|
||||||
chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'")
|
chat = await chat.eval("Your're a helpful assistant", "Say literlly 'Hello'")
|
||||||
assert chat.last_response_content() == "Hello!"
|
assert chat.last_response_content() == "Hello!"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
@pytest.mark.skipif(
    settings.anyscale_api_key == "Awesome API", reason="API Key not set"
)
async def test_simple_data_query():
    """The model answers a plain-language request with SQL that starts with SELECT.

    Skipped while the API key still has its placeholder value.
    """
    request_text = "write a query that finds the average score of all students in the current database"

    chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
    db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)

    prompt_parts = [
        request_text,
        db.db_schema(),
        db.load_description_local(TEST_OUTPUT_FOLDER),
        "Return just the SQL statement",
    ]
    prompt = os.linesep.join(prompt_parts)

    chat = await chat.eval("You're an expert sql developer", prompt)
    generated_sql = extract_code_block(chat.last_response_content())
    # NOTE(review): this depends on the model's output at temperature 0.5, so
    # the check may not be stable between runs.
    assert generated_sql.startswith("SELECT")
|
||||||
|
|
Loading…
Reference in a new issue