Compare commits
2 commits
06ac295e17
...
d7ba280a2f
Author | SHA1 | Date | |
---|---|---|---|
d7ba280a2f | |||
0c21073e88 |
|
@ -1,36 +1,71 @@
|
||||||
import duckdb
|
|
||||||
import os
|
import os
|
||||||
|
from enum import StrEnum
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import duckdb
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
||||||
|
class StorageEngines(StrEnum):
|
||||||
|
local = "Local"
|
||||||
|
gcs = "GCS"
|
||||||
|
|
||||||
|
|
||||||
class DDB:
|
class DDB:
|
||||||
def __init__(self):
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_engine: StorageEngines,
|
||||||
|
sid: str | None = None,
|
||||||
|
path: Path | None = None,
|
||||||
|
gcs_access: str | None = None,
|
||||||
|
gcs_secret: str | None = None,
|
||||||
|
bucket: str | None = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
self.db = duckdb.connect()
|
self.db = duckdb.connect()
|
||||||
self.db.install_extension("spatial")
|
self.db.install_extension("spatial")
|
||||||
self.db.install_extension("httpfs")
|
self.db.install_extension("httpfs")
|
||||||
self.db.load_extension("spatial")
|
self.db.load_extension("spatial")
|
||||||
self.db.load_extension("httpfs")
|
self.db.load_extension("httpfs")
|
||||||
self.sheets = tuple()
|
self.sheets = tuple()
|
||||||
self.path = ""
|
self.loaded = False
|
||||||
|
|
||||||
def gcs_secret(self, gcs_access: str, gcs_secret: str) -> Self:
|
if storage_engine == StorageEngines.gcs:
|
||||||
self.db.sql(f"""
|
if all(
|
||||||
CREATE SECRET (
|
gcs_access is not None,
|
||||||
TYPE GCS,
|
gcs_secret is not None,
|
||||||
KEY_ID '{gcs_access}',
|
bucket is not None,
|
||||||
SECRET '{gcs_secret}')
|
sid is not None,
|
||||||
""")
|
):
|
||||||
|
self.db.sql(f"""
|
||||||
|
CREATE SECRET (
|
||||||
|
TYPE GCS,
|
||||||
|
KEY_ID '{gcs_access}',
|
||||||
|
SECRET '{gcs_secret}')
|
||||||
|
""")
|
||||||
|
self.path_prefix = f"gcs://{bucket}/sessions/{sid}"
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"With GCS storage engine you need to provide "
|
||||||
|
"the gcs_access, gcs_secret, sid, and bucket keyword arguments"
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
elif storage_engine == StorageEngines.local:
|
||||||
|
if path is not None:
|
||||||
|
self.path_prefix = path
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"With local storage you need to provide the path keyword argument"
|
||||||
|
)
|
||||||
|
|
||||||
def load_metadata(self, path: str = "") -> Self:
|
def load_xls(self, xls_path: Path) -> Self:
|
||||||
"""For some reason, the header is not loaded"""
|
"""For some reason, the header is not loaded"""
|
||||||
self.db.sql(f"""
|
self.db.sql(f"""
|
||||||
create table metadata as (
|
create table metadata as (
|
||||||
select
|
select
|
||||||
*
|
*
|
||||||
from
|
from
|
||||||
st_read('{path}',
|
st_read('{xls_path}',
|
||||||
layer='metadata'
|
layer='metadata'
|
||||||
)
|
)
|
||||||
)""")
|
)""")
|
||||||
|
@ -39,62 +74,55 @@ class DDB:
|
||||||
.fetchall()[0][0]
|
.fetchall()[0][0]
|
||||||
.split(";")
|
.split(";")
|
||||||
)
|
)
|
||||||
self.path = path
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def dump_local(self, path) -> Self:
|
|
||||||
# TODO: Port to fsspec and have a single dump file
|
|
||||||
self.db.query(f"copy metadata to '{path}/metadata.csv'")
|
|
||||||
|
|
||||||
for sheet in self.sheets:
|
for sheet in self.sheets:
|
||||||
self.db.query(f"""
|
self.db.query(f"""
|
||||||
copy
|
create table {sheet} as
|
||||||
(
|
(
|
||||||
select
|
select
|
||||||
*
|
*
|
||||||
from
|
from
|
||||||
st_read
|
st_read
|
||||||
(
|
(
|
||||||
'{self.path}',
|
'{xls_path}',
|
||||||
layer = '{sheet}'
|
layer = '{sheet}'
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
to '{path}/{sheet}.csv'
|
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
self.loaded = True
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def dump_gcs(self, bucketname, sid) -> Self:
|
def dump(self) -> Self:
|
||||||
self.db.sql(
|
# TODO: Create a decorator
|
||||||
f"copy metadata to 'gcs://{bucketname}/sessions/{sid}/metadata.csv'"
|
if not self.loaded:
|
||||||
|
raise ValueError("Data should be loaded first")
|
||||||
|
|
||||||
|
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||||
|
|
||||||
|
for sheet in self.sheets:
|
||||||
|
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def load_folder(self) -> Self:
|
||||||
|
self.query(
|
||||||
|
f"""
|
||||||
|
create table metadata as (
|
||||||
|
select
|
||||||
|
*
|
||||||
|
from
|
||||||
|
read_csv_auto('{self.path_prefix}/metadata.csv')
|
||||||
|
)
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
for sheet in self.sheets:
|
|
||||||
self.db.query(f"""
|
|
||||||
copy
|
|
||||||
(
|
|
||||||
select
|
|
||||||
*
|
|
||||||
from
|
|
||||||
st_read
|
|
||||||
(
|
|
||||||
'{self.path}',
|
|
||||||
layer = '{sheet}'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
to 'gcs://{bucketname}/sessions/{sid}/{sheet}.csv'
|
|
||||||
""")
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def load_folder_local(self, path: str) -> Self:
|
|
||||||
self.sheets = tuple(
|
self.sheets = tuple(
|
||||||
self.query(
|
self.query(
|
||||||
f"""
|
"""
|
||||||
select
|
select
|
||||||
Field2
|
Field2
|
||||||
from
|
from
|
||||||
read_csv_auto('{path}/metadata.csv')
|
metadata
|
||||||
where
|
where
|
||||||
Field1 = 'Sheets'
|
Field1 = 'Sheets'
|
||||||
"""
|
"""
|
||||||
|
@ -110,61 +138,21 @@ class DDB:
|
||||||
select
|
select
|
||||||
*
|
*
|
||||||
from
|
from
|
||||||
read_csv_auto('{path}/{sheet}.csv')
|
read_csv_auto('{self.path_prefix}/{sheet}.csv')
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
return self
|
self.loaded = True
|
||||||
|
|
||||||
def load_folder_gcs(self, bucketname: str, sid: str) -> Self:
|
|
||||||
self.sheets = tuple(
|
|
||||||
self.query(
|
|
||||||
f"""
|
|
||||||
select
|
|
||||||
Field2
|
|
||||||
from
|
|
||||||
read_csv_auto(
|
|
||||||
'gcs://{bucketname}/sessions/{sid}/metadata.csv'
|
|
||||||
)
|
|
||||||
where
|
|
||||||
Field1 = 'Sheets'
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
.fetchall()[0][0]
|
|
||||||
.split(";")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load all the tables into the database
|
|
||||||
for sheet in self.sheets:
|
|
||||||
self.db.query(f"""
|
|
||||||
create table {sheet} as (
|
|
||||||
select
|
|
||||||
*
|
|
||||||
from
|
|
||||||
read_csv_auto('gcs://{bucketname}/sessions/{sid}/{sheet}.csv')
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def load_description_local(self, path: str) -> Self:
|
def load_description(self) -> Self:
|
||||||
return self.query(
|
return self.query(
|
||||||
f"""
|
"""
|
||||||
select
|
select
|
||||||
Field2
|
Field2
|
||||||
from
|
from
|
||||||
read_csv_auto('{path}/metadata.csv')
|
metadata
|
||||||
where
|
|
||||||
Field1 = 'Description'"""
|
|
||||||
).fetchall()[0][0]
|
|
||||||
|
|
||||||
def load_description_gcs(self, bucketname: str, sid: str) -> Self:
|
|
||||||
return self.query(
|
|
||||||
f"""
|
|
||||||
select
|
|
||||||
Field2
|
|
||||||
from
|
|
||||||
read_csv_auto('gcs://{bucketname}/sessions/{sid}/metadata.csv')
|
|
||||||
where
|
where
|
||||||
Field1 = 'Description'"""
|
Field1 = 'Description'"""
|
||||||
).fetchall()[0][0]
|
).fetchall()[0][0]
|
||||||
|
@ -182,9 +170,11 @@ class DDB:
|
||||||
f"select column_name, column_type from (describe {table})"
|
f"select column_name, column_type from (describe {table})"
|
||||||
).fetchall()
|
).fetchall()
|
||||||
)
|
)
|
||||||
|
+ [os.linesep]
|
||||||
)
|
)
|
||||||
|
|
||||||
def db_schema(self):
|
@property
|
||||||
|
def schema(self):
|
||||||
return os.linesep.join(
|
return os.linesep.join(
|
||||||
[
|
[
|
||||||
"The schema of the database is the following:",
|
"The schema of the database is the following:",
|
||||||
|
@ -194,3 +184,18 @@ class DDB:
|
||||||
|
|
||||||
def query(self, sql, *args, **kwargs):
|
def query(self, sql, *args, **kwargs):
|
||||||
return self.db.query(sql, *args, **kwargs)
|
return self.db.query(sql, *args, **kwargs)
|
||||||
|
|
||||||
|
def query_prompt(self, user_prompt: str) -> str:
|
||||||
|
query = (
|
||||||
|
f"The following sentence is the description of a query that "
|
||||||
|
f"needs to be executed in a database: {user_prompt}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return os.linesep.join(
|
||||||
|
[
|
||||||
|
query,
|
||||||
|
self.schema,
|
||||||
|
self.load_description(),
|
||||||
|
"Return just the SQL statement",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
|
@ -6,7 +6,7 @@ class Settings(BaseSettings):
|
||||||
gcs_access: str = "access"
|
gcs_access: str = "access"
|
||||||
gcs_secret: str = "secret"
|
gcs_secret: str = "secret"
|
||||||
gcs_bucketname: str = "bucket"
|
gcs_bucketname: str = "bucket"
|
||||||
auth: bool = False
|
auth: bool = True
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env")
|
model_config = SettingsConfigDict(env_file=".env")
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
import hellocomputer
|
import hellocomputer
|
||||||
|
|
||||||
from .routers import files, sessions, analysis
|
from .routers import analysis, files, sessions
|
||||||
|
|
||||||
static_path = Path(hellocomputer.__file__).parent / "static"
|
static_path = Path(hellocomputer.__file__).parent / "static"
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from langchain_community.chat_models import ChatAnyscale
|
from langchain_community.chat_models import ChatAnyscale
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
|
@ -1,41 +1,29 @@
|
||||||
|
import os
|
||||||
|
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from ..config import settings
|
from hellocomputer.analytics import DDB, StorageEngines
|
||||||
from ..models import Chat
|
|
||||||
|
|
||||||
from hellocomputer.analytics import DDB
|
|
||||||
from hellocomputer.extraction import extract_code_block
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
|
||||||
import os
|
from ..config import settings
|
||||||
|
from ..models import Chat
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
@router.get("/query", response_class=PlainTextResponse, tags=["queries"])
|
||||||
async def query(sid: str = "", q: str = "") -> str:
|
async def query(sid: str = "", q: str = "") -> str:
|
||||||
print(q)
|
|
||||||
query = f"Write a query that {q} in the current database"
|
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = (
|
db = DDB(
|
||||||
DDB()
|
StorageEngines.gcs,
|
||||||
.gcs_secret(settings.gcs_access, settings.gcs_secret)
|
gcs_access=settings.gcs_access,
|
||||||
.load_folder_gcs(settings.gcs_bucketname, sid)
|
gcs_secret=settings.gcs_secret,
|
||||||
)
|
bucket=settings.gcs_bucketname,
|
||||||
|
sid=sid,
|
||||||
|
).load_folder()
|
||||||
|
|
||||||
prompt = os.linesep.join(
|
chat = await chat.eval("You're an expert sql developer", db.query_prompt())
|
||||||
[
|
|
||||||
query,
|
|
||||||
db.db_schema(),
|
|
||||||
db.load_description_gcs(settings.gcs_bucketname, sid),
|
|
||||||
"Return just the SQL statement",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(prompt)
|
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", prompt)
|
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
result = str(db.query(query))
|
result = str(db.query(query))
|
||||||
print(result)
|
print(result)
|
||||||
|
|
|
@ -1,21 +1,22 @@
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import s3fs
|
|
||||||
|
# import s3fs
|
||||||
from fastapi import APIRouter, File, UploadFile
|
from fastapi import APIRouter, File, UploadFile
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from ..analytics import DDB, StorageEngines
|
||||||
from ..config import settings
|
from ..config import settings
|
||||||
from ..analytics import DDB
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
# Configure the S3FS with your Google Cloud Storage credentials
|
# Configure the S3FS with your Google Cloud Storage credentials
|
||||||
gcs = s3fs.S3FileSystem(
|
# gcs = s3fs.S3FileSystem(
|
||||||
key=settings.gcs_access,
|
# key=settings.gcs_access,
|
||||||
secret=settings.gcs_secret,
|
# secret=settings.gcs_secret,
|
||||||
client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
|
# client_kwargs={"endpoint_url": "https://storage.googleapis.com"},
|
||||||
)
|
# )
|
||||||
bucket_name = settings.gcs_bucketname
|
# bucket_name = settings.gcs_bucketname
|
||||||
|
|
||||||
|
|
||||||
@router.post("/upload", tags=["files"])
|
@router.post("/upload", tags=["files"])
|
||||||
|
@ -26,10 +27,15 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""):
|
||||||
await f.flush()
|
await f.flush()
|
||||||
|
|
||||||
(
|
(
|
||||||
DDB()
|
DDB(
|
||||||
.gcs_secret(settings.gcs_access, settings.gcs_secret)
|
StorageEngines.gcs,
|
||||||
.load_metadata(f.name)
|
gcs_access=settings.gcs_access,
|
||||||
.dump_gcs(settings.gcs_bucketname, sid)
|
gcs_secret=settings.gcs_secret,
|
||||||
|
bucket=settings.gcs_bucketname,
|
||||||
|
sid=sid,
|
||||||
|
)
|
||||||
|
.load_xls(f.name)
|
||||||
|
.dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
|
|
@ -1,22 +1,16 @@
|
||||||
from uuid import uuid4
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..security import oauth2_scheme
|
from ..security import oauth2_scheme
|
||||||
|
|
||||||
# Scheme for the Authorization header
|
# Scheme for the Authorization header
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
if settings.auth:
|
|
||||||
|
|
||||||
@router.get("/token")
|
|
||||||
async def get_token() -> str:
|
|
||||||
return str(uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/new_session")
|
@router.get("/new_session")
|
||||||
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
|
async def get_new_session(token: Annotated[str, Depends(oauth2_scheme)]) -> str:
|
||||||
return str(uuid4())
|
return str(uuid4())
|
||||||
|
|
|
@ -1,40 +1,50 @@
|
||||||
import hellocomputer
|
|
||||||
from hellocomputer.analytics import DDB
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
TEST_DATA_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "data"
|
import hellocomputer
|
||||||
|
from hellocomputer.analytics import DDB, StorageEngines
|
||||||
|
|
||||||
|
TEST_STORAGE = StorageEngines.local
|
||||||
|
TEST_XLS_PATH = (
|
||||||
|
Path(hellocomputer.__file__).parents[2]
|
||||||
|
/ "test"
|
||||||
|
/ "data"
|
||||||
|
/ "TestExcelHelloComputer.xlsx"
|
||||||
|
)
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
||||||
|
|
||||||
|
|
||||||
def test_dump():
|
def test_0_dump():
|
||||||
db = (
|
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER)
|
||||||
DDB()
|
db.load_xls(TEST_XLS_PATH).dump()
|
||||||
.load_metadata(TEST_DATA_FOLDER / "TestExcelHelloComputer.xlsx")
|
|
||||||
.dump_local(TEST_OUTPUT_FOLDER)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
assert db.sheets == ("answers",)
|
||||||
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()
|
assert (TEST_OUTPUT_FOLDER / "answers.csv").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_load():
|
def test_load():
|
||||||
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
|
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
|
|
||||||
assert db.sheets == ("answers",)
|
|
||||||
|
|
||||||
results = db.query("select * from answers").fetchall()
|
results = db.query("select * from answers").fetchall()
|
||||||
assert len(results) == 6
|
assert len(results) == 6
|
||||||
|
|
||||||
|
|
||||||
def test_load_description():
|
def test_load_description():
|
||||||
file_description = DDB().load_description_local(TEST_OUTPUT_FOLDER)
|
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
|
file_description = db.load_description()
|
||||||
assert file_description.startswith("answers")
|
assert file_description.startswith("answers")
|
||||||
|
|
||||||
|
|
||||||
def test_schema():
|
def test_schema():
|
||||||
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
|
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
schema = []
|
schema = []
|
||||||
for sheet in db.sheets:
|
for sheet in db.sheets:
|
||||||
schema.append(db.table_schema(sheet))
|
schema.append(db.table_schema(sheet))
|
||||||
|
|
||||||
assert schema[0].startswith("Table name:")
|
assert db.schema.startswith("The schema of the database")
|
||||||
|
|
||||||
|
|
||||||
|
def test_query_prompt():
|
||||||
|
db = DDB(storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER).load_folder()
|
||||||
|
|
||||||
|
assert db.query_prompt("Find the average score of all students").startswith(
|
||||||
|
"The following sentence"
|
||||||
|
)
|
||||||
|
|
|
@ -1,14 +1,19 @@
|
||||||
import pytest
|
|
||||||
import os
|
import os
|
||||||
import hellocomputer
|
|
||||||
from hellocomputer.config import settings
|
|
||||||
from hellocomputer.models import Chat
|
|
||||||
from hellocomputer.extraction import extract_code_block
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from hellocomputer.analytics import DDB
|
|
||||||
|
|
||||||
|
import hellocomputer
|
||||||
|
import pytest
|
||||||
|
from hellocomputer.analytics import DDB, StorageEngines
|
||||||
|
from hellocomputer.config import settings
|
||||||
|
from hellocomputer.extraction import extract_code_block
|
||||||
|
from hellocomputer.models import Chat
|
||||||
|
|
||||||
TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output"
|
TEST_XLS_PATH = (
|
||||||
|
Path(hellocomputer.__file__).parents[2]
|
||||||
|
/ "test"
|
||||||
|
/ "data"
|
||||||
|
/ "TestExcelHelloComputer.xlsx"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -29,17 +34,10 @@ async def test_simple_data_query():
|
||||||
query = "write a query that finds the average score of all students in the current database"
|
query = "write a query that finds the average score of all students in the current database"
|
||||||
|
|
||||||
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5)
|
||||||
db = DDB().load_folder_local(TEST_OUTPUT_FOLDER)
|
db = DDB(storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent).load_xls(
|
||||||
|
TEST_XLS_PATH
|
||||||
prompt = os.linesep.join(
|
|
||||||
[
|
|
||||||
query,
|
|
||||||
db.db_schema(),
|
|
||||||
db.load_description_local(TEST_OUTPUT_FOLDER),
|
|
||||||
"Return just the SQL statement",
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
chat = await chat.eval("You're an expert sql developer", prompt)
|
chat = await chat.eval("You're an expert sql developer", db.query_prompt(query))
|
||||||
query = extract_code_block(chat.last_response_content())
|
query = extract_code_block(chat.last_response_content())
|
||||||
assert query.startswith("SELECT")
|
assert query.startswith("SELECT")
|
||||||
|
|
Loading…
Reference in a new issue