This commit is contained in:
parent
e8755e627c
commit
144856e5c0
|
@ -1,8 +1,17 @@
|
|||
from starlette.requests import Request
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
def get_user(request: Request) -> dict:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
request (Request): _description_
|
||||
|
||||
Returns:
|
||||
dict: _description_
|
||||
"""
|
||||
if settings.auth:
|
||||
return request.session.get("user")
|
||||
else:
|
||||
|
@ -10,6 +19,14 @@ def get_user(request: Request) -> dict:
|
|||
|
||||
|
||||
def get_user_email(request: Request) -> str:
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
request (Request): _description_
|
||||
|
||||
Returns:
|
||||
str: _description_
|
||||
"""
|
||||
if settings.auth:
|
||||
return request.session.get("user").get("email")
|
||||
else:
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
import duckdb
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
|
||||
class StorageEngines(StrEnum):
|
||||
|
@ -19,13 +20,11 @@ class DDB:
|
|||
bucket: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.engine = create_engine(
|
||||
"duckdb:///:memory:",
|
||||
connect_args={
|
||||
"preload_extensions": ["https", "spatial"],
|
||||
"config": {"memory_limit": "300mb"},
|
||||
},
|
||||
)
|
||||
self.db = duckdb.connect(":memory:")
|
||||
# Assume extension autoloading
|
||||
# self.db.sql("load httpfs")
|
||||
# self.db.sql("load spatial")
|
||||
self.storage_engine = storage_engine
|
||||
self.sheets = tuple()
|
||||
self.loaded = False
|
||||
|
||||
|
@ -37,18 +36,12 @@ class DDB:
|
|||
bucket is not None,
|
||||
)
|
||||
):
|
||||
with self.engine.connect() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
f"""
|
||||
self.db.sql(f"""
|
||||
CREATE SECRET (
|
||||
TYPE GCS,
|
||||
KEY_ID '{gcs_access}',
|
||||
SECRET '{gcs_secret}')
|
||||
"""
|
||||
)
|
||||
)
|
||||
conn.execute(text("LOAD httpfs"))
|
||||
""")
|
||||
|
||||
self.path_prefix = f"gs://{bucket}"
|
||||
else:
|
||||
|
@ -65,6 +58,9 @@ class DDB:
|
|||
"With local storage you need to provide the path keyword argument"
|
||||
)
|
||||
|
||||
def query(self, sql, *args, **kwargs):
|
||||
return self.db.query(sql, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def db(self):
|
||||
return self.engine.raw_connection()
|
||||
def engine(self):
|
||||
return create_engine("duckdb:///:memory:")
|
||||
|
|
|
@ -71,10 +71,13 @@ class SessionDB(DDB):
|
|||
|
||||
try:
|
||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||
except duckdb.duckdb.IOException:
|
||||
except duckdb.duckdb.IOException as e:
|
||||
# Create the folder
|
||||
if self.storage_engine == StorageEngines.local:
|
||||
os.makedirs(self.path_prefix)
|
||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||
else:
|
||||
raise e
|
||||
|
||||
for sheet in self.sheets:
|
||||
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
||||
|
@ -157,9 +160,6 @@ class SessionDB(DDB):
|
|||
+ [self.table_schema(sheet) for sheet in self.sheets]
|
||||
)
|
||||
|
||||
def query(self, sql, *args, **kwargs):
|
||||
return self.db.query(sql, *args, **kwargs)
|
||||
|
||||
def query_prompt(self, user_prompt: str) -> str:
|
||||
query = (
|
||||
f"The following sentence is the description of a query that "
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
|
||||
from anyio import open_file
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_fireworks import Fireworks
|
||||
|
||||
import hellocomputer
|
||||
|
||||
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
||||
|
||||
|
||||
class AvailableModels(StrEnum):
|
||||
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
|
||||
|
@ -12,23 +18,19 @@ class AvailableModels(StrEnum):
|
|||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||
|
||||
|
||||
general_prompt = """
|
||||
You're a helpful assistant. Perform the following tasks:
|
||||
class Prompts:
|
||||
@classmethod
|
||||
async def getter(cls, name):
|
||||
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||
return await f.read()
|
||||
|
||||
----
|
||||
{query}
|
||||
----
|
||||
"""
|
||||
@classmethod
|
||||
async def general(cls):
|
||||
return await cls.getter("general_prompt")
|
||||
|
||||
sql_prompt = """
|
||||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
||||
|
||||
----
|
||||
{query}
|
||||
----
|
||||
|
||||
Return only the sql statement without any additional text.
|
||||
"""
|
||||
@classmethod
|
||||
async def sql(cls):
|
||||
return await cls.getter("sql_prompt")
|
||||
|
||||
|
||||
class Chat:
|
||||
|
@ -59,14 +61,14 @@ class Chat:
|
|||
)
|
||||
|
||||
async def eval(self, task):
|
||||
prompt = PromptTemplate.from_template(general_prompt)
|
||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||
|
||||
response = await self.model.ainvoke(prompt.format(query=task))
|
||||
self.responses.append(response)
|
||||
return self
|
||||
|
||||
async def sql_eval(self, question):
|
||||
prompt = PromptTemplate.from_template(sql_prompt)
|
||||
prompt = PromptTemplate.from_template(await Prompts.sql())
|
||||
|
||||
response = await self.model.ainvoke(prompt.format(query=question))
|
||||
self.responses.append(response)
|
||||
|
|
1
src/hellocomputer/prompts/README.md
Normal file
1
src/hellocomputer/prompts/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Storage for separate prompts
|
3
src/hellocomputer/prompts/general_prompt.md
Normal file
3
src/hellocomputer/prompts/general_prompt.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
You're a helpful assistant. Perform the following tasks:
|
||||
|
||||
* {query}
|
5
src/hellocomputer/prompts/sql_prompt.md
Normal file
5
src/hellocomputer/prompts/sql_prompt.md
Normal file
|
@ -0,0 +1,5 @@
|
|||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
||||
|
||||
* {query}
|
||||
|
||||
Return only the sql statement without any additional text.
|
|
@ -8,8 +8,8 @@ from starlette.requests import Request
|
|||
from hellocomputer.db import StorageEngines
|
||||
from hellocomputer.db.users import OwnershipDB
|
||||
|
||||
from ..config import settings
|
||||
from ..auth import get_user_email
|
||||
from ..config import settings
|
||||
|
||||
# Scheme for the Authorization header
|
||||
|
||||
|
|
15
test/test_prompts.py
Normal file
15
test/test_prompts.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import pytest
|
||||
from hellocomputer.models import Prompts
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_general_prompt():
|
||||
general: str = await Prompts.general()
|
||||
assert general.startswith("You're a helpful assistant")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_general_templated():
|
||||
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||
assert "Do as I say" in prompt.format(query="Do as I say")
|
Loading…
Reference in a new issue