diff --git a/src/hellocomputer/auth.py b/src/hellocomputer/auth.py index c97564b..ca04316 100644 --- a/src/hellocomputer/auth.py +++ b/src/hellocomputer/auth.py @@ -1,8 +1,17 @@ from starlette.requests import Request + from .config import settings def get_user(request: Request) -> dict: + """_summary_ + + Args: + request (Request): _description_ + + Returns: + dict: _description_ + """ if settings.auth: return request.session.get("user") else: @@ -10,6 +19,14 @@ def get_user(request: Request) -> dict: def get_user_email(request: Request) -> str: + """_summary_ + + Args: + request (Request): _description_ + + Returns: + str: _description_ + """ if settings.auth: return request.session.get("user").get("email") else: diff --git a/src/hellocomputer/db/__init__.py b/src/hellocomputer/db/__init__.py index 113e4e3..2d03641 100644 --- a/src/hellocomputer/db/__init__.py +++ b/src/hellocomputer/db/__init__.py @@ -1,7 +1,8 @@ from enum import StrEnum from pathlib import Path -from sqlalchemy import create_engine, text +import duckdb +from sqlalchemy import create_engine class StorageEngines(StrEnum): @@ -19,13 +20,11 @@ class DDB: bucket: str | None = None, **kwargs, ): - self.engine = create_engine( - "duckdb:///:memory:", - connect_args={ - "preload_extensions": ["https", "spatial"], - "config": {"memory_limit": "300mb"}, - }, - ) + self.db = duckdb.connect(":memory:") + # Assume extension autoloading + # self.db.sql("load httpfs") + # self.db.sql("load spatial") + self.storage_engine = storage_engine self.sheets = tuple() self.loaded = False @@ -37,18 +36,12 @@ class DDB: bucket is not None, ) ): - with self.engine.connect() as conn: - conn.execute( - text( - f""" + self.db.sql(f""" CREATE SECRET ( TYPE GCS, KEY_ID '{gcs_access}', SECRET '{gcs_secret}') - """ - ) - ) - conn.execute(text("LOAD httpfs")) + """) self.path_prefix = f"gs://{bucket}" else: @@ -65,6 +58,9 @@ class DDB: "With local storage you need to provide the path keyword argument" ) + def query(self, sql, *args, **kwargs): + return self.db.query(sql, *args, **kwargs) + @property - def db(self): - return self.engine.raw_connection() + def engine(self): + return create_engine("duckdb:///:memory:") diff --git a/src/hellocomputer/db/sessions.py b/src/hellocomputer/db/sessions.py index a728334..3de4db6 100644 --- a/src/hellocomputer/db/sessions.py +++ b/src/hellocomputer/db/sessions.py @@ -71,10 +71,13 @@ class SessionDB(DDB): try: self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") - except duckdb.duckdb.IOException: + except duckdb.duckdb.IOException as e: # Create the folder - os.makedirs(self.path_prefix) - self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") + if self.storage_engine == StorageEngines.local: + os.makedirs(self.path_prefix) + self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") + else: + raise e for sheet in self.sheets: self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") @@ -157,9 +160,6 @@ class SessionDB(DDB): + [self.table_schema(sheet) for sheet in self.sheets] ) - def query(self, sql, *args, **kwargs): - return self.db.query(sql, *args, **kwargs) - def query_prompt(self, user_prompt: str) -> str: query = ( f"The following sentence is the description of a query that " diff --git a/src/hellocomputer/models.py b/src/hellocomputer/models.py index 511b3b8..5840901 100644 --- a/src/hellocomputer/models.py +++ b/src/hellocomputer/models.py @@ -1,8 +1,14 @@ from enum import StrEnum +from pathlib import Path +from anyio import open_file from langchain_core.prompts import PromptTemplate from langchain_fireworks import Fireworks +import hellocomputer + +PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts" + class AvailableModels(StrEnum): llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" @@ -12,23 +18,19 @@ class AvailableModels(StrEnum): firefunction_2 = "accounts/fireworks/models/firefunction-v2" -general_prompt = """ -You're a helpful assistant. Perform the following tasks: +class Prompts: + @classmethod + async def getter(cls, name): + async with await open_file(PROMPT_DIR / f"{name}.md") as f: + return await f.read() ----- -{query} ----- -""" + @classmethod + async def general(cls): + return await cls.getter("general_prompt") -sql_prompt = """ -You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: - ----- -{query} ----- - -Return only the sql statement without any additional text. -""" + @classmethod + async def sql(cls): + return await cls.getter("sql_prompt") class Chat: @@ -59,14 +61,14 @@ class Chat: ) async def eval(self, task): - prompt = PromptTemplate.from_template(general_prompt) + prompt = PromptTemplate.from_template(await Prompts.general()) response = await self.model.ainvoke(prompt.format(query=task)) self.responses.append(response) return self async def sql_eval(self, question): - prompt = PromptTemplate.from_template(sql_prompt) + prompt = PromptTemplate.from_template(await Prompts.sql()) response = await self.model.ainvoke(prompt.format(query=question)) self.responses.append(response) diff --git a/src/hellocomputer/prompts/README.md b/src/hellocomputer/prompts/README.md new file mode 100644 index 0000000..78ed6a7 --- /dev/null +++ b/src/hellocomputer/prompts/README.md @@ -0,0 +1 @@ +Storage for separate prompts \ No newline at end of file diff --git a/src/hellocomputer/prompts/general_prompt.md b/src/hellocomputer/prompts/general_prompt.md new file mode 100644 index 0000000..73e6c97 --- /dev/null +++ b/src/hellocomputer/prompts/general_prompt.md @@ -0,0 +1,3 @@ +You're a helpful assistant. Perform the following tasks: + +* {query} diff --git a/src/hellocomputer/prompts/sql_prompt.md b/src/hellocomputer/prompts/sql_prompt.md new file mode 100644 index 0000000..0853c71 --- /dev/null +++ b/src/hellocomputer/prompts/sql_prompt.md @@ -0,0 +1,5 @@ +You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: + +* {query} + +Return only the sql statement without any additional text. \ No newline at end of file diff --git a/src/hellocomputer/routers/sessions.py b/src/hellocomputer/routers/sessions.py index c1fcfd9..c5e12ca 100644 --- a/src/hellocomputer/routers/sessions.py +++ b/src/hellocomputer/routers/sessions.py @@ -8,8 +8,8 @@ from starlette.requests import Request from hellocomputer.db import StorageEngines from hellocomputer.db.users import OwnershipDB -from ..config import settings from ..auth import get_user_email +from ..config import settings # Scheme for the Authorization header diff --git a/test/test_prompts.py b/test/test_prompts.py new file mode 100644 index 0000000..ced5dd2 --- /dev/null +++ b/test/test_prompts.py @@ -0,0 +1,15 @@ +import pytest +from hellocomputer.models import Prompts +from langchain.prompts import PromptTemplate + + +@pytest.mark.asyncio +async def test_get_general_prompt(): + general: str = await Prompts.general() + assert general.startswith("You're a helpful assistant") + + +@pytest.mark.asyncio +async def test_general_templated(): + prompt = PromptTemplate.from_template(await Prompts.general()) + assert "Do as I say" in prompt.format(query="Do as I say")