Drop sqlalchemy
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-17 09:04:08 +02:00
parent e8755e627c
commit 144856e5c0
9 changed files with 81 additions and 42 deletions

View file

@ -1,8 +1,17 @@
from starlette.requests import Request
from .config import settings
def get_user(request: Request) -> dict:
"""_summary_
Args:
request (Request): _description_
Returns:
dict: _description_
"""
if settings.auth:
return request.session.get("user")
else:
@ -10,6 +19,14 @@ def get_user(request: Request) -> dict:
def get_user_email(request: Request) -> str:
"""_summary_
Args:
request (Request): _description_
Returns:
str: _description_
"""
if settings.auth:
return request.session.get("user").get("email")
else:

View file

@ -1,7 +1,8 @@
from enum import StrEnum
from pathlib import Path
from sqlalchemy import create_engine, text
import duckdb
from sqlalchemy import create_engine
class StorageEngines(StrEnum):
@ -19,13 +20,11 @@ class DDB:
bucket: str | None = None,
**kwargs,
):
self.engine = create_engine(
"duckdb:///:memory:",
connect_args={
"preload_extensions": ["https", "spatial"],
"config": {"memory_limit": "300mb"},
},
)
self.db = duckdb.connect(":memory:")
# Assume extension autoloading
# self.db.sql("load httpfs")
# self.db.sql("load spatial")
self.storage_engine = storage_engine
self.sheets = tuple()
self.loaded = False
@ -37,18 +36,12 @@ class DDB:
bucket is not None,
)
):
with self.engine.connect() as conn:
conn.execute(
text(
f"""
self.db.sql(f"""
CREATE SECRET (
TYPE GCS,
KEY_ID '{gcs_access}',
SECRET '{gcs_secret}')
"""
)
)
conn.execute(text("LOAD httpfs"))
""")
self.path_prefix = f"gs://{bucket}"
else:
@ -65,6 +58,9 @@ class DDB:
"With local storage you need to provide the path keyword argument"
)
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
@property
def db(self):
return self.engine.raw_connection()
def engine(self):
return create_engine("duckdb:///:memory:")

View file

@ -71,10 +71,13 @@ class SessionDB(DDB):
try:
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
except duckdb.duckdb.IOException:
except duckdb.duckdb.IOException as e:
# Create the folder
os.makedirs(self.path_prefix)
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
if self.storage_engine == StorageEngines.local:
os.makedirs(self.path_prefix)
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
else:
raise e
for sheet in self.sheets:
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
@ -157,9 +160,6 @@ class SessionDB(DDB):
+ [self.table_schema(sheet) for sheet in self.sheets]
)
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
def query_prompt(self, user_prompt: str) -> str:
query = (
f"The following sentence is the description of a query that "

View file

@ -1,8 +1,14 @@
from enum import StrEnum
from pathlib import Path
from anyio import open_file
from langchain_core.prompts import PromptTemplate
from langchain_fireworks import Fireworks
import hellocomputer
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
class AvailableModels(StrEnum):
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
@ -12,23 +18,19 @@ class AvailableModels(StrEnum):
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
general_prompt = """
You're a helpful assistant. Perform the following tasks:
class Prompts:
@classmethod
async def getter(cls, name):
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
return await f.read()
----
{query}
----
"""
@classmethod
async def general(cls):
return await cls.getter("general_prompt")
sql_prompt = """
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
----
{query}
----
Return only the sql statement without any additional text.
"""
@classmethod
async def sql(cls):
return await cls.getter("sql_prompt")
class Chat:
@ -59,14 +61,14 @@ class Chat:
)
async def eval(self, task):
prompt = PromptTemplate.from_template(general_prompt)
prompt = PromptTemplate.from_template(await Prompts.general())
response = await self.model.ainvoke(prompt.format(query=task))
self.responses.append(response)
return self
async def sql_eval(self, question):
prompt = PromptTemplate.from_template(sql_prompt)
prompt = PromptTemplate.from_template(await Prompts.sql())
response = await self.model.ainvoke(prompt.format(query=question))
self.responses.append(response)

View file

@ -0,0 +1 @@
Storage for separate prompts

View file

@ -0,0 +1,3 @@
You're a helpful assistant. Perform the following tasks:
* {query}

View file

@ -0,0 +1,5 @@
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
* {query}
Return only the sql statement without any additional text.

View file

@ -8,8 +8,8 @@ from starlette.requests import Request
from hellocomputer.db import StorageEngines
from hellocomputer.db.users import OwnershipDB
from ..config import settings
from ..auth import get_user_email
from ..config import settings
# Scheme for the Authorization header

15
test/test_prompts.py Normal file
View file

@ -0,0 +1,15 @@
import pytest
from hellocomputer.models import Prompts
from langchain.prompts import PromptTemplate
@pytest.mark.asyncio
async def test_get_general_prompt():
general: str = await Prompts.general()
assert general.startswith("You're a helpful assistant")
@pytest.mark.asyncio
async def test_general_templated():
prompt = PromptTemplate.from_template(await Prompts.general())
assert "Do as I say" in prompt.format(query="Do as I say")