Compare commits

..

No commits in common. "d7ba280a2f7733171aa4f54e3e83e08ee62dea69" and "06ac295e17fba4ea6c55b75284d3fdf567d4aaff" have entirely different histories.

9 changed files with 171 additions and 173 deletions

View file

@ -1,71 +1,36 @@
import os
from enum import StrEnum
from pathlib import Path
import duckdb import duckdb
import os
from typing_extensions import Self from typing_extensions import Self
class StorageEngines(StrEnum):
local = "Local"
gcs = "GCS"
class DDB: class DDB:
def __init__( def __init__(self):
self,
storage_engine: StorageEngines,
sid: str | None = None,
path: Path | None = None,
gcs_access: str | None = None,
gcs_secret: str | None = None,
bucket: str | None = None,
**kwargs,
):
self.db = duckdb.connect() self.db = duckdb.connect()
self.db.install_extension("spatial") self.db.install_extension("spatial")
self.db.install_extension("httpfs") self.db.install_extension("httpfs")
self.db.load_extension("spatial") self.db.load_extension("spatial")
self.db.load_extension("httpfs") self.db.load_extension("httpfs")
self.sheets = tuple() self.sheets = tuple()
self.loaded = False self.path = ""
if storage_engine == StorageEngines.gcs: def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self:
if all( self.db.sql(f"""
gcs_access is not None, CREATE SECRET (
gcs_secret is not None, TYPE GCS,
bucket is not None, KEY_ID '{gcs_access}',
sid is not None, SECRET '{gcs_secret}')
): """)
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
""")
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
else:
raise ValueError(
"With GCS storage engine you need to provide "
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
)
elif storage_engine == StorageEngines.local: return self
if path is not None:
self.path_prefix = path
else:
raise ValueError(
"With local storage you need to provide the path keyword argument"
)
def load_xls(self, xls_path: Path) -> Self: def load_metadata(self, path: str = "") -> Self:
"""For some reason, the header is not loaded""" """For some reason, the header is not loaded"""
self.db.sql(f""" self.db.sql(f"""
create table metadata as ( create table metadata as (
select select
* *
from from
st_read('{xls_path}', st_read('{path}',
layer='metadata' layer='metadata'
) )
)""") )""")
@ -74,55 +39,62 @@ class DDB:
.fetchall()[0][0] .fetchall()[0][0]
.split(";") .split(";")
) )
self.path = path
return self
def dump_local(self, path) -> Self:
# TODO: Port to fsspec and have a single dump file
self.db.query(f"copy metadata to '{path}/metadata.csv'")
for sheet in self.sheets: for sheet in self.sheets:
self.db.query(f""" self.db.query(f"""
create table {sheet} as copy
( (
select select
* *
from from
st_read st_read
( (
'{xls_path}', '{self.path}',
layer = '{sheet}' layer = '{sheet}'
) )
) )
to '{path}/{sheet}.csv'
""") """)
self.loaded = True
return self return self
def dump(self) -> Self: def dump_gcs(self, bucketname, sid) -> Self:
# TODO: Create a decorator self.db.sql(
if not self.loaded: f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'"
raise ValueError("Data should be loaded first") )
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
for sheet in self.sheets: for sheet in self.sheets:
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") self.db.query(f"""
return self copy
(
def load_folder(self) -> Self:
self.query(
f"""
create table metadata as (
select select
* *
from from
read_csv_auto('{self.path_prefix}/metadata.csv') st_read
(
'{self.path}',
layer = '{sheet}'
)
) )
""" to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
) """)
return self
def load_folder_local(self, path: str) -> Self:
self.sheets = tuple( self.sheets = tuple(
self.query( self.query(
""" f"""
select select
Field2 Field2
from from
metadata read_csv_auto('{path}/metadata.csv')
where where
Field1 = 'Sheets' Field1 = 'Sheets'
""" """
@ -138,21 +110,61 @@ class DDB:
select select
* *
from from
read_csv_auto('{self.path_prefix}/{sheet}.csv') read_csv_auto('{path}/{sheet}.csv')
) )
""") """)
self.loaded = True return self
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
self.sheets = tuple(
self.query(
f"""
select
Field2
from
read_csv_auto(
'gcs://{bucketname}/sessions/{sid}/metadata.csv'
)
where
Field1 = 'Sheets'
"""
)
.fetchall()[0][0]
.split(";")
)
# Load all the tables into the database
for sheet in self.sheets:
self.db.query(f"""
create table {sheet} as (
select
*
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
)
""")
return self return self
def load_description(self) -> Self: def load_description_local(self, path: str) -> Self:
return self.query( return self.query(
""" f"""
select select
Field2 Field2
from from
metadata read_csv_auto('{path}/metadata.csv')
where
Field1 = 'Description'"""
).fetchall()[0][0]
def load_description_gcs(self, bucketname: str, sid: str) -> Self:
return self.query(
f"""
select
Field2
from
read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
where where
Field1 = 'Description'""" Field1 = 'Description'"""
).fetchall()[0][0] ).fetchall()[0][0]
@ -170,11 +182,9 @@ class DDB:
f"select column_name, column_type from (describe {table})" f"select column_name, column_type from (describe {table})"
).fetchall() ).fetchall()
) )
+ [os.linesep]
) )
@property def db_schema(self):
def schema(self):
return os.linesep.join( return os.linesep.join(
[ [
"The schema of the database is the following:", "The schema of the database is the following:",
@ -184,18 +194,3 @@ class DDB:
def query(self, sql, *args, **kwargs): def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs) return self.db.query(sql, *args, **kwargs)
def query_prompt(self, user_prompt: str) -> str:
query = (
f"The following sentence is the description of a query that "
f"needs to be executed in a database: {user_prompt}"
)
return os.linesep.join(
[
query,
self.schema,
self.load_description(),
"Return just the SQL statement",
]
)

View file

@ -6,7 +6,7 @@ class Settings(BaseSettings):
gcs_access: str = "access" gcs_access: str = "access"
gcs_secret: str = "secret" gcs_secret: str = "secret"
gcs_bucketname: str = "bucket" gcs_bucketname: str = "bucket"
auth: bool = True auth: bool = False
model_config = SettingsConfigDict(env_file=".env") model_config = SettingsConfigDict(env_file=".env")

View file

@ -6,7 +6,7 @@ from pydantic import BaseModel
import hellocomputer import hellocomputer
from .routers import analysis, files, sessions from .routers import files, sessions, analysis
static_path = Path(hellocomputer.__file__).parent / "static" static_path = Path(hellocomputer.__file__).parent / "static"

View file

@ -1,5 +1,4 @@
from enum import StrEnum from enum import StrEnum
from langchain_community.chat_models import ChatAnyscale from langchain_community.chat_models import ChatAnyscale
from langchain_core.messages import HumanMessage, SystemMessage from langchain_core.messages import HumanMessage, SystemMessage

View file

@ -1,29 +1,41 @@
import os
from fastapi import APIRouter from fastapi import APIRouter
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from hellocomputer.analytics import DDB, StorageEngines
from hellocomputer.extraction import extract_code_block
from ..config import settings from ..config import settings
from ..models import Chat from ..models import Chat
from hellocomputer.analytics import DDB
from hellocomputer.extraction import extract_code_block
import os
router = APIRouter() router = APIRouter()
@router.get("/query", response_class=PlainTextResponse, tags=["queries"]) @router.get("/query", response_class=PlainTextResponse, tags=["queries"])
async def query(sid: str = "", q: str = "") -> str: async def query(sid: str = "", q: str = "") -> str:
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) print(q)
db = DDB( query = f"Write a query that {q} in the current database"
StorageEngines.gcs,
gcs_access=settings.gcs_access,
gcs_secret=settings.gcs_secret,
bucket=settings.gcs_bucketname,
sid=sid,
).load_folder()
chat = await chat.eval("You're an expert sql developer", db.query_prompt()) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = (
DDB()
.gcs_secret(settings.gcs_access, settings.gcs_secret)
.load_folder_gcs(settings.gcs_bucketname, sid)
)
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_gcs(settings.gcs_bucketname, sid),
"Return just the SQL statement",
]
)
print(prompt)
chat = await chat.eval("You're an expert sql developer", prompt)
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
result = str(db.query(query)) result = str(db.query(query))
print(result) print(result)

View file

@ -1,22 +1,21 @@
import aiofiles import aiofiles
import s3fs
# import s3fs
from fastapi import APIRouter, File, UploadFile from fastapi import APIRouter, File, UploadFile
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from ..analytics import DDB, StorageEngines
from ..config import settings from ..config import settings
from ..analytics import DDB
router = APIRouter() router = APIRouter()
# Configure the S3FS with your Google Cloud Storage credentials # Configure the S3FS with your Google Cloud Storage credentials
# gcs = s3fs.S3FileSystem( gcs = s3fs.S3FileSystem(
# key=settings.gcs_access, key=settings.gcs_access,
# secret=settings.gcs_secret, secret=settings.gcs_secret,
# client_kwargs={"endpoint_url": "https://storage.googleapis.com"}, client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
# ) )
# bucket_name = settings.gcs_bucketname bucket_name = settings.gcs_bucketname
@router.post("/upload", tags=["files"]) @router.post("/upload", tags=["files"])
@ -27,15 +26,10 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
await f.flush() await f.flush()
( (
DDB( DDB()
StorageEngines.gcs, .gcs_secret(settings.gcs_access, settings.gcs_secret)
gcs_access=settings.gcs_access, .load_metadata(f.name)
gcs_secret=settings.gcs_secret, .dump_gcs(settings.gcs_bucketname, sid)
bucket=settings.gcs_bucketname,
sid=sid,
)
.load_xls(f.name)
.dump()
) )
return JSONResponse( return JSONResponse(

View file

@ -1,16 +1,22 @@
from typing import Annotated
from uuid import uuid4 from uuid import uuid4
from typing import Annotated
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from ..config import settings
from ..security import oauth2_scheme from ..security import oauth2_scheme
# Scheme for the Authorization header # Scheme for the Authorization header
router = APIRouter() router = APIRouter()
if settings.auth:
@router.get("/token")
async def get_token() -> str:
return str(uuid4())
@router.get("/new_session") @router.get("/new_session")
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str: async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
return str(uuid4()) return str(uuid4())

View file

@ -1,50 +1,40 @@
import hellocomputer
from hellocomputer.analytics import DDB
from pathlib import Path from pathlib import Path
import hellocomputer TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
from hellocomputer.analytics import DDB, StorageEngines
TEST_STORAGE = StorageEngines.local
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2]
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
def test_0_dump(): def test_dump():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER) db = (
db.load_xls(TEST_XLS_PATH).dump() DDB()
.load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx")
.dump_local(TEST_OUTPUT_FOLDER)
)
assert db.sheets == ("answers",) assert db.sheets == ("answers",)
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists() assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()
def test_load(): def test_load():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
assert db.sheets == ("answers",)
results = db.query("select * from answers").fetchall() results = db.query("select * from answers").fetchall()
assert len(results) == 6 assert len(results) == 6
def test_load_description(): def test_load_description():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER)
file_description = db.load_description()
assert file_description.startswith("answers") assert file_description.startswith("answers")
def test_schema(): def test_schema():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder() db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
schema = [] schema = []
for sheet in db.sheets: for sheet in db.sheets:
schema.append(db.table_schema(sheet)) schema.append(db.table_schema(sheet))
assert db.schema.startswith("The schema of the database") assert schema[0].startswith("Table name:")
def test_query_prompt():
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
assert db.query_prompt("Find the average score of all students").startswith(
"The following sentence"
)

View file

@ -1,19 +1,14 @@
import os
from pathlib import Path
import hellocomputer
import pytest import pytest
from hellocomputer.analytics import DDB, StorageEngines import os
import hellocomputer
from hellocomputer.config import settings from hellocomputer.config import settings
from hellocomputer.extraction import extract_code_block
from hellocomputer.models import Chat from hellocomputer.models import Chat
from hellocomputer.extraction import extract_code_block
from pathlib import Path
from hellocomputer.analytics import DDB
TEST_XLS_PATH = (
Path(hellocomputer.__file__).parents[2] TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
/ "test"
/ "data"
/ "TestExcelHelloComputer.xlsx"
)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -34,10 +29,17 @@ async def test_simple_data_query():
query = "write a query that finds the average score of all students in the current database" query = "write a query that finds the average score of all students in the current database"
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls( db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
TEST_XLS_PATH
prompt = os.linesep.join(
[
query,
db.db_schema(),
db.load_description_local(TEST_OUTPUT_FOLDER),
"Return just the SQL statement",
]
) )
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) chat = await chat.eval("You're an expert sql developer", prompt)
query = extract_code_block(chat.last_response_content()) query = extract_code_block(chat.last_response_content())
assert query.startswith("SELECT") assert query.startswith("SELECT")