This commit is contained in:
parent
e8755e627c
commit
144856e5c0
|
@ -1,8 +1,17 @@
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
|
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
|
|
||||||
def get_user(request: Request) -> dict:
|
def get_user(request: Request) -> dict:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: _description_
|
||||||
|
"""
|
||||||
if settings.auth:
|
if settings.auth:
|
||||||
return request.session.get("user")
|
return request.session.get("user")
|
||||||
else:
|
else:
|
||||||
|
@ -10,6 +19,14 @@ def get_user(request: Request) -> dict:
|
||||||
|
|
||||||
|
|
||||||
def get_user_email(request: Request) -> str:
|
def get_user_email(request: Request) -> str:
|
||||||
|
"""_summary_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request (Request): _description_
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: _description_
|
||||||
|
"""
|
||||||
if settings.auth:
|
if settings.auth:
|
||||||
return request.session.get("user").get("email")
|
return request.session.get("user").get("email")
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import create_engine, text
|
import duckdb
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
|
||||||
|
|
||||||
class StorageEngines(StrEnum):
|
class StorageEngines(StrEnum):
|
||||||
|
@ -19,13 +20,11 @@ class DDB:
|
||||||
bucket: str | None = None,
|
bucket: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.engine = create_engine(
|
self.db = duckdb.connect(":memory:")
|
||||||
"duckdb:///:memory:",
|
# Assume extension autoloading
|
||||||
connect_args={
|
# self.db.sql("load httpfs")
|
||||||
"preload_extensions": ["https", "spatial"],
|
# self.db.sql("load spatial")
|
||||||
"config": {"memory_limit": "300mb"},
|
self.storage_engine = storage_engine
|
||||||
},
|
|
||||||
)
|
|
||||||
self.sheets = tuple()
|
self.sheets = tuple()
|
||||||
self.loaded = False
|
self.loaded = False
|
||||||
|
|
||||||
|
@ -37,18 +36,12 @@ class DDB:
|
||||||
bucket is not None,
|
bucket is not None,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
with self.engine.connect() as conn:
|
self.db.sql(f"""
|
||||||
conn.execute(
|
|
||||||
text(
|
|
||||||
f"""
|
|
||||||
CREATE SECRET (
|
CREATE SECRET (
|
||||||
TYPE GCS,
|
TYPE GCS,
|
||||||
KEY_ID '{gcs_access}',
|
KEY_ID '{gcs_access}',
|
||||||
SECRET '{gcs_secret}')
|
SECRET '{gcs_secret}')
|
||||||
"""
|
""")
|
||||||
)
|
|
||||||
)
|
|
||||||
conn.execute(text("LOAD httpfs"))
|
|
||||||
|
|
||||||
self.path_prefix = f"gs://{bucket}"
|
self.path_prefix = f"gs://{bucket}"
|
||||||
else:
|
else:
|
||||||
|
@ -65,6 +58,9 @@ class DDB:
|
||||||
"With local storage you need to provide the path keyword argument"
|
"With local storage you need to provide the path keyword argument"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def query(self, sql, *args, **kwargs):
|
||||||
|
return self.db.query(sql, *args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self):
|
def engine(self):
|
||||||
return self.engine.raw_connection()
|
return create_engine("duckdb:///:memory:")
|
||||||
|
|
|
@ -71,10 +71,13 @@ class SessionDB(DDB):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||||
except duckdb.duckdb.IOException:
|
except duckdb.duckdb.IOException as e:
|
||||||
# Create the folder
|
# Create the folder
|
||||||
os.makedirs(self.path_prefix)
|
if self.storage_engine == StorageEngines.local:
|
||||||
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
os.makedirs(self.path_prefix)
|
||||||
|
self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'")
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
for sheet in self.sheets:
|
for sheet in self.sheets:
|
||||||
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'")
|
||||||
|
@ -157,9 +160,6 @@ class SessionDB(DDB):
|
||||||
+ [self.table_schema(sheet) for sheet in self.sheets]
|
+ [self.table_schema(sheet) for sheet in self.sheets]
|
||||||
)
|
)
|
||||||
|
|
||||||
def query(self, sql, *args, **kwargs):
|
|
||||||
return self.db.query(sql, *args, **kwargs)
|
|
||||||
|
|
||||||
def query_prompt(self, user_prompt: str) -> str:
|
def query_prompt(self, user_prompt: str) -> str:
|
||||||
query = (
|
query = (
|
||||||
f"The following sentence is the description of a query that "
|
f"The following sentence is the description of a query that "
|
||||||
|
|
|
@ -1,8 +1,14 @@
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from anyio import open_file
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_fireworks import Fireworks
|
from langchain_fireworks import Fireworks
|
||||||
|
|
||||||
|
import hellocomputer
|
||||||
|
|
||||||
|
PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts"
|
||||||
|
|
||||||
|
|
||||||
class AvailableModels(StrEnum):
|
class AvailableModels(StrEnum):
|
||||||
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
|
llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct"
|
||||||
|
@ -12,23 +18,19 @@ class AvailableModels(StrEnum):
|
||||||
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
firefunction_2 = "accounts/fireworks/models/firefunction-v2"
|
||||||
|
|
||||||
|
|
||||||
general_prompt = """
|
class Prompts:
|
||||||
You're a helpful assistant. Perform the following tasks:
|
@classmethod
|
||||||
|
async def getter(cls, name):
|
||||||
|
async with await open_file(PROMPT_DIR / f"{name}.md") as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
----
|
@classmethod
|
||||||
{query}
|
async def general(cls):
|
||||||
----
|
return await cls.getter("general_prompt")
|
||||||
"""
|
|
||||||
|
|
||||||
sql_prompt = """
|
@classmethod
|
||||||
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
async def sql(cls):
|
||||||
|
return await cls.getter("sql_prompt")
|
||||||
----
|
|
||||||
{query}
|
|
||||||
----
|
|
||||||
|
|
||||||
Return only the sql statement without any additional text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Chat:
|
class Chat:
|
||||||
|
@ -59,14 +61,14 @@ class Chat:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def eval(self, task):
|
async def eval(self, task):
|
||||||
prompt = PromptTemplate.from_template(general_prompt)
|
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||||
|
|
||||||
response = await self.model.ainvoke(prompt.format(query=task))
|
response = await self.model.ainvoke(prompt.format(query=task))
|
||||||
self.responses.append(response)
|
self.responses.append(response)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def sql_eval(self, question):
|
async def sql_eval(self, question):
|
||||||
prompt = PromptTemplate.from_template(sql_prompt)
|
prompt = PromptTemplate.from_template(await Prompts.sql())
|
||||||
|
|
||||||
response = await self.model.ainvoke(prompt.format(query=question))
|
response = await self.model.ainvoke(prompt.format(query=question))
|
||||||
self.responses.append(response)
|
self.responses.append(response)
|
||||||
|
|
1
src/hellocomputer/prompts/README.md
Normal file
1
src/hellocomputer/prompts/README.md
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Storage for separate prompts
|
3
src/hellocomputer/prompts/general_prompt.md
Normal file
3
src/hellocomputer/prompts/general_prompt.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
You're a helpful assistant. Perform the following tasks:
|
||||||
|
|
||||||
|
* {query}
|
5
src/hellocomputer/prompts/sql_prompt.md
Normal file
5
src/hellocomputer/prompts/sql_prompt.md
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following:
|
||||||
|
|
||||||
|
* {query}
|
||||||
|
|
||||||
|
Return only the sql statement without any additional text.
|
|
@ -8,8 +8,8 @@ from starlette.requests import Request
|
||||||
from hellocomputer.db import StorageEngines
|
from hellocomputer.db import StorageEngines
|
||||||
from hellocomputer.db.users import OwnershipDB
|
from hellocomputer.db.users import OwnershipDB
|
||||||
|
|
||||||
from ..config import settings
|
|
||||||
from ..auth import get_user_email
|
from ..auth import get_user_email
|
||||||
|
from ..config import settings
|
||||||
|
|
||||||
# Scheme for the Authorization header
|
# Scheme for the Authorization header
|
||||||
|
|
||||||
|
|
15
test/test_prompts.py
Normal file
15
test/test_prompts.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
import pytest
|
||||||
|
from hellocomputer.models import Prompts
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_general_prompt():
|
||||||
|
general: str = await Prompts.general()
|
||||||
|
assert general.startswith("You're a helpful assistant")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_general_templated():
|
||||||
|
prompt = PromptTemplate.from_template(await Prompts.general())
|
||||||
|
assert "Do as I say" in prompt.format(query="Do as I say")
|
Loading…
Reference in a new issue