Compare commits

..

No commits in common. "d7ba280a2f7733171aa4f54e3e83e08ee62dea69" and "06ac295e17fba4ea6c55b75284d3fdf567d4aaff" have entirely different histories.

9 changed files with 171 additions and 173 deletions

View file

@ -1,71 +1,36 @@
import os
from enum import StrEnum
from pathlib import Path
import duckdb
import os
from typing_extensions import Self
class StorageEngines(StrEnum):
    """Storage backends that a ``DDB`` instance can be told to use."""

    # Data lives under a filesystem path supplied by the caller.
    local = "Local"
    # Data lives in a Google Cloud Storage bucket, addressed with gcs:// URLs.
    gcs = "GCS"
class DDB:
def __init__(
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
def __init__(self):
    """Open a DuckDB connection and prepare the extensions this class uses.

    The connection comes from ``duckdb.connect()`` called with no arguments.
    Both extensions are installed first and loaded afterwards, in that order.
    """
    # NOTE(review): "spatial" presumably backs the st_read() calls and
    # "httpfs" the gcs:// URLs used elsewhere in this class; confirm.
    required_extensions = ("spatial", "httpfs")

    self.db = duckdb.connect()
    for extension in required_extensions:
        self.db.install_extension(extension)
    for extension in required_extensions:
        self.db.load_extension(extension)

    # Nothing has been read yet: no sheet names, no source path.
    self.sheets = ()
    self.loaded = False
    self.path = ""
if storage_engine == StorageEngines.gcs:
if all(
gcs_access is not None,
gcs_secret is not None,
bucket is not None,
sid is not None,
):
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self:
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
elif storage_engine == StorageEngines.local:
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
return self
def load_xls(self, xls_path: Path) -> Self:
def load_metadata(self, path: str = "") -> Self:
"""For some reason, the header is not loaded"""
self.db.sql(f"""
create table metadata as (
select
*
from
st_read('{xls_path}',
st_read('{path}',
layer='metadata'
)
)""")
@ -74,55 +39,62 @@ class DDB:
.fetchall()[0][0]
.split(";")
)
self.path = path
return self
def dump_local(self, path) -> Self:
# TODO: Port to fsspec and have a single dump file
self.db.query(f"copy metadata to '{path}/metadata.csv'")
for sheet in self.sheets:
self.db.query(f"""
create table {sheet} as
copy
(
select
*
from
st_read
(
'{xls_path}',
'{self.path}',
layer = '{sheet}'
)
)
to '{path}/{sheet}.csv'
""")
self.loaded = True
return self
def dump(self) -> Self:
# TODO: Create a decorator
if not self.loaded:
raise ValueError("Data should be loaded first")
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
def dump_gcs(self, bucketname, sid) -> Self:
self.db.sql(
f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'"
)
for sheet in self.sheets:
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
return self
def load_folder(self) -> Self:
self.query(
f"""
create table metadata as (
self.db.query(f"""
copy
(
select
*
from
read_csv_auto('{self.path_prefix}/metadata.csv')
st_read
(
'{self.path}',
layer = '{sheet}'
)
)
"""
)
to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
""")
return self
def load_folder_local(self, path: str) -> Self:
self.sheets = tuple(
self.query(
"""
f"""
select
Field2
from
metadata
read_csv_auto('{path}/metadata.csv')
where
Field1 = 'Sheets'
"""
@ -138,21 +110,61 @@ class DDB:
select
*
from
read_csv_auto('{self.path_prefix}/{sheet}.csv')
read_csv_auto('{path}/{sheet}.csv')
)
""")
self.loaded = True
return self
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
self.sheets = tuple(
self.query(
f"""
select
Field2
from
read_csv_auto(
'gcs://{bucketname}/sessions/{sid}/metadata.csv'
)
where
Field1 = 'Sheets'
"""
)
.fetchall()[0][0]
.split(";")
)
# Load all the tables into the database
for sheet in self.sheets:
self.db.query(f"""
create table {sheet} as (
select
*
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
)
""")
return self
def load_description(self) -> Self:
def load_description_local(self, path: str) -> Self:
return self.query(
"""
f"""
select
Field2
from
metadata
read_csv_auto('{path}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
def load_description_gcs(self, bucketname: str, sid: str) -> Self:
return self.query(
f"""
select
Field2
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
@ -170,11 +182,9 @@ class DDB:
f"select column_name, column_type from (describe {table})"
).fetchall()
)
+ [os.linesep]
)
@property
def schema(self):
def db_schema(self):
return os.linesep.join(
[
"The schema of the database is the following:",
@ -184,18 +194,3 @@ class DDB:
def query(self, sql, *args, **kwargs):
    """Run ``sql`` through ``self.db.query`` and hand back whatever it returns.

    Any extra positional or keyword arguments are forwarded unchanged.
    """
    connection = self.db
    outcome = connection.query(sql, *args, **kwargs)
    return outcome
def query_prompt(self, user_prompt: str) -> str:
    """Assemble the text prompt that asks for a SQL statement.

    The prompt has four parts joined with ``os.linesep``: a sentence
    wrapping ``user_prompt``, the value of ``self.schema``, the result of
    ``self.load_description()``, and a closing instruction.

    Args:
        user_prompt: Natural-language description of the wanted query.

    Returns:
        The assembled prompt.
    """
    sections = [
        "The following sentence is the description of a query that "
        f"needs to be executed in a database: {user_prompt}"
    ]
    # Schema is read before the description, as in the original list literal.
    sections.append(self.schema)
    sections.append(self.load_description())
    sections.append("Return just the SQL statement")
    return os.linesep.join(sections)

View file

@ -6,7 +6,7 @@ class Settings(BaseSettings):
gcs_access: str = "access"
gcs_secret: str = "secret"
gcs_bucketname: str = "bucket"
auth: bool = True
auth: bool = False
model_config = SettingsConfigDict(env_file=".env")

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
import hellocomputer
from .routers import analysis, files, sessions
from .routers import files, sessions, analysis
static_path = Path(hellocomputer.__file__).parent / "static"

View file

@ -1,5 +1,4 @@
from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage

View file

@ -1,29 +1,41 @@
import os
from fastapi import APIRouter
from fastapi.responses import PlainTextResponse
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.extraction import extract_code_block
from ..config import settings
from ..models import Chat
from hellocomputer.analytics import DDB
from hellocomputer.extraction import extract_code_block
import os
router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
).load_folder()
print(q)
query = f"Write a query that {q} in the current database"
chat = await chat.eval("You're an expert sql developer", db.query_prompt())
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = (
DDB()
.gcs_secret(settings.gcs_access, settings.gcs_secret)
.load_folder_gcs(settings.gcs_bucketname, sid)
)
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_gcs(settings.gcs_bucketname, sid),
"Return just the SQL statement",
]
)
print(prompt)
chat = await chat.eval("You're an expert sql developer", prompt)
query = extract_code_block(chat.last_response_content())
result = str(db.query(query))
print(result)

View file

@ -1,22 +1,21 @@
import aiofiles
# import s3fs
import s3fs
from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines
from ..config import settings
from ..analytics import DDB
router = APIRouter()
# Configure the S3FS with your Google Cloud Storage credentials
# gcs = s3fs.S3FileSystem(
# key=settings.gcs_access,
# secret=settings.gcs_secret,
# client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
# )
# bucket_name = settings.gcs_bucketname
gcs = s3fs.S3FileSystem(
key=settings.gcs_access,
secret=settings.gcs_secret,
client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
)
bucket_name = settings.gcs_bucketname
@router.post("/upload", tags=["files"])
@ -27,15 +26,10 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush()
(
DDB(
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
)
.load_xls(f.name)
.dump()
DDB()
.gcs_secret(settings.gcs_access, settings.gcs_secret)
.load_metadata(f.name)
.dump_gcs(settings.gcs_bucketname, sid)
)
return JSONResponse(

View file

@ -1,16 +1,22 @@
from typing import Annotated
from uuid import uuid4
from typing import Annotated
from fastapi import APIRouter, Depends
from fastapi.responses import PlainTextResponse
from ..config import settings
from ..security import oauth2_scheme
# Scheme for the Authorization header
router = APIRouter()
if settings.auth:
@router.get("/token")
async def get_token() -> str:
return str(uuid4())
@router.get("/new_session")
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return str(uuid4())

View file

@ -1,50 +1,40 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path
import hellocomputer
from hellocomputer.analytics import DDB, StorageEngines
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
db.load_xls(TEST_XLS_PATH).dump()
def test_dump():
    """Dump the sample workbook locally and check the exported sheet."""
    workbook = TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx"
    exported_csv = TEST_OUTPUT_FOLDER / "answers.csv"

    loaded = DDB().load_metadata(workbook)
    db = loaded.dump_local(TEST_OUTPUT_FOLDER)

    assert db.sheets == ("answers",)
    assert exported_csv.exists()
def test_load():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
assert db.sheets == ("answers",)
results = db.query("select * from answers").fetchall()
assert len(results) == 6
def test_load_description():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
file_description = db.load_description()
file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER)
assert file_description.startswith("answers")
def test_schema():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
schema = []
for sheet in db.sheets:
schema.append(db.table_schema(sheet))
assert db.schema.startswith("The schema of the database")
def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"
)
assert schema[0].startswith("Table name:")

View file

@ -1,19 +1,14 @@
import os
from pathlib import Path
import hellocomputer
import pytest
from hellocomputer.analytics import DDB, StorageEngines
import os
import hellocomputer
from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat
from hellocomputer.extraction import extract_code_block
from pathlib import Path
from hellocomputer.analytics import DDB
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
@pytest.mark.asyncio
@ -34,10 +29,17 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
TEST_XLS_PATH
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_local(TEST_OUTPUT_FOLDER),
"Return just the SQL statement",
]
)
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
chat = await chat.eval("You're an expert sql developer", prompt)
query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT")