Drop sqlalchemy
Some checks failed
ci/woodpecker/push/woodpecker Pipeline failed

This commit is contained in:
Guillem Borrell 2024-07-17 09:04:08 +02:00
parent e8755e627c
commit 144856e5c0
9 changed files with 81 additions and 42 deletions

View file

@ -1,8 +1,17 @@
from starlette.requests import Request from starlette.requests import Request
from .config import settings from .config import settings
def get_user(request: Request) -> dict: def get_user(request: Request) -> dict:
"""_summary_
Args:
request (Request): _description_
Returns:
dict: _description_
"""
if settings.auth: if settings.auth:
return request.session.get("user") return request.session.get("user")
else: else:
@ -10,6 +19,14 @@ def get_user(request: Request) -> dict:
def get_user_email(request: Request) -> str: def get_user_email(request: Request) -> str:
"""_summary_
Args:
request (Request): _description_
Returns:
str: _description_
"""
if settings.auth: if settings.auth:
return request.session.get("user").get("email") return request.session.get("user").get("email")
else: else:

View file

@ -1,7 +1,8 @@
from enum import StrEnum from enum import StrEnum
from pathlib import Path from pathlib import Path
from sqlalchemy import create_engine, text import duckdb
from sqlalchemy import create_engine
class StorageEngines(StrEnum): class StorageEngines(StrEnum):
@ -19,13 +20,11 @@ class DDB:
bucket: str | None = None, bucket: str | None = None,
**kwargs, **kwargs,
): ):
self.engine = create_engine( self.db = duckdb.connect(":memory:")
"duckdb:///:memory:", # Assume extension autoloading
connect_args={ # self.db.sql("load httpfs")
"preload_extensions": ["https", "spatial"], # self.db.sql("load spatial")
"config": {"memory_limit": "300mb"}, self.storage_engine = storage_engine
},
)
self.sheets = tuple() self.sheets = tuple()
self.loaded = False self.loaded = False
@ -37,18 +36,12 @@ class DDB:
bucket is not None, bucket is not None,
) )
): ):
with self.engine.connect() as conn: self.db.sql(f"""
conn.execute(
text(
f"""
CREATE SECRET ( CREATE SECRET (
TYPE GCS, TYPE GCS,
KEY_ID '{gcs_access}', KEY_ID '{gcs_access}',
SECRET '{gcs_secret}') SECRET '{gcs_secret}')
""" """)
)
)
conn.execute(text("LOAD httpfs"))
self.path_prefix = f"gs://{bucket}" self.path_prefix = f"gs://{bucket}"
else: else:
@ -65,6 +58,9 @@ class DDB:
"With local storage you need to provide the path keyword argument" "With local storage you need to provide the path keyword argument"
) )
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
@property @property
def db(self): def engine(self):
return self.engine.raw_connection() return create_engine("duckdb:///:memory:")

View file

@ -71,10 +71,13 @@ class SessionDB(DDB):
try: try:
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
except duckdb.duckdb.IOException: except duckdb.duckdb.IOException as e:
# Create the folder # Create the folder
os.makedirs(self.path_prefix) if self.storage_engine == StorageEngines.local:
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") os.makedirs(self.path_prefix)
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
else:
raise e
for sheet in self.sheets: for sheet in self.sheets:
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
@ -157,9 +160,6 @@ class SessionDB(DDB):
+ [self.table_schema(sheet) for sheet in self.sheets] + [self.table_schema(sheet) for sheet in self.sheets]
) )
def query(self, sql, *args, **kwargs):
return self.db.query(sql, *args, **kwargs)
def query_prompt(self, user_prompt: str) -> str: def query_prompt(self, user_prompt: str) -> str:
query = ( query = (
f"The following sentence is the description of a query that " f"The following sentence is the description of a query that "

View file

@ -1,8 +1,14 @@
from enum import StrEnum from enum import StrEnum
from pathlib import Path
from anyio import open_file
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_fireworks import Fireworks from langchain_fireworks import Fireworks
import hellocomputer
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
class AvailableModels(StrEnum): class AvailableModels(StrEnum):
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
@ -12,23 +18,19 @@ class AvailableModels(StrEnum):
firefunction_2 = "accounts/fireworks/models/firefunction-v2" firefunction_2 = "accounts/fireworks/models/firefunction-v2"
general_prompt = """ class Prompts:
You're a helpful assistant. Perform the following tasks: @classmethod
async def getter(cls, name):
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
return await f.read()
---- @classmethod
{query} async def general(cls):
---- return await cls.getter("general_prompt")
"""
sql_prompt = """ @classmethod
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: async def sql(cls):
return await cls.getter("sql_prompt")
----
{query}
----
Return only the sql statement without any additional text.
"""
class Chat: class Chat:
@ -59,14 +61,14 @@ class Chat:
) )
async def eval(self, task): async def eval(self, task):
prompt = PromptTemplate.from_template(general_prompt) prompt = PromptTemplate.from_template(await Prompts.general())
response = await self.model.ainvoke(prompt.format(query=task)) response = await self.model.ainvoke(prompt.format(query=task))
self.responses.append(response) self.responses.append(response)
return self return self
async def sql_eval(self, question): async def sql_eval(self, question):
prompt = PromptTemplate.from_template(sql_prompt) prompt = PromptTemplate.from_template(await Prompts.sql())
response = await self.model.ainvoke(prompt.format(query=question)) response = await self.model.ainvoke(prompt.format(query=question))
self.responses.append(response) self.responses.append(response)

View file

@ -0,0 +1 @@
Storage for separate prompts

View file

@ -0,0 +1,3 @@
You're a helpful assistant. Perform the following tasks:
* {query}

View file

@ -0,0 +1,5 @@
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
* {query}
Return only the sql statement without any additional text.

View file

@ -8,8 +8,8 @@ from starlette.requests import Request
from hellocomputer.db import StorageEngines from hellocomputer.db import StorageEngines
from hellocomputer.db.users import OwnershipDB from hellocomputer.db.users import OwnershipDB
from ..config import settings
from ..auth import get_user_email from ..auth import get_user_email
from ..config import settings
# Scheme for the Authorization header # Scheme for the Authorization header

15
test/test_prompts.py Normal file
View file

@ -0,0 +1,15 @@
import pytest
from hellocomputer.models import Prompts
from langchain.prompts import PromptTemplate
@pytest.mark.asyncio
async def test_get_general_prompt():
general: str = await Prompts.general()
assert general.startswith("You're a helpful assistant")
@pytest.mark.asyncio
async def test_general_templated():
prompt = PromptTemplate.from_template(await Prompts.general())
assert "Do as I say" in prompt.format(query="Do as I say")