This commit is contained in:
		
							parent
							
								
									e8755e627c
								
							
						
					
					
						commit
						144856e5c0
					
				|  | @ -1,8 +1,17 @@ | ||||||
| from starlette.requests import Request | from starlette.requests import Request | ||||||
|  | 
 | ||||||
| from .config import settings | from .config import settings | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_user(request: Request) -> dict: | def get_user(request: Request) -> dict: | ||||||
|  |     """_summary_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         request (Request): _description_ | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         dict: _description_ | ||||||
|  |     """ | ||||||
|     if settings.auth: |     if settings.auth: | ||||||
|         return request.session.get("user") |         return request.session.get("user") | ||||||
|     else: |     else: | ||||||
|  | @ -10,6 +19,14 @@ def get_user(request: Request) -> dict: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_user_email(request: Request) -> str: | def get_user_email(request: Request) -> str: | ||||||
|  |     """_summary_ | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         request (Request): _description_ | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         str: _description_ | ||||||
|  |     """ | ||||||
|     if settings.auth: |     if settings.auth: | ||||||
|         return request.session.get("user").get("email") |         return request.session.get("user").get("email") | ||||||
|     else: |     else: | ||||||
|  |  | ||||||
|  | @ -1,7 +1,8 @@ | ||||||
| from enum import StrEnum | from enum import StrEnum | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| from sqlalchemy import create_engine, text | import duckdb | ||||||
|  | from sqlalchemy import create_engine | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class StorageEngines(StrEnum): | class StorageEngines(StrEnum): | ||||||
|  | @ -19,13 +20,11 @@ class DDB: | ||||||
|         bucket: str | None = None, |         bucket: str | None = None, | ||||||
|         **kwargs, |         **kwargs, | ||||||
|     ): |     ): | ||||||
|         self.engine = create_engine( |         self.db = duckdb.connect(":memory:") | ||||||
|             "duckdb:///:memory:", |         # Assume extension autoloading | ||||||
|             connect_args={ |         # self.db.sql("load httpfs") | ||||||
|                 "preload_extensions": ["https", "spatial"], |         # self.db.sql("load spatial") | ||||||
|                 "config": {"memory_limit": "300mb"}, |         self.storage_engine = storage_engine | ||||||
|             }, |  | ||||||
|         ) |  | ||||||
|         self.sheets = tuple() |         self.sheets = tuple() | ||||||
|         self.loaded = False |         self.loaded = False | ||||||
| 
 | 
 | ||||||
|  | @ -37,18 +36,12 @@ class DDB: | ||||||
|                     bucket is not None, |                     bucket is not None, | ||||||
|                 ) |                 ) | ||||||
|             ): |             ): | ||||||
|                 with self.engine.connect() as conn: |                 self.db.sql(f""" | ||||||
|                     conn.execute( |  | ||||||
|                         text( |  | ||||||
|                             f""" |  | ||||||
|                     CREATE SECRET ( |                     CREATE SECRET ( | ||||||
|                     TYPE GCS, |                     TYPE GCS, | ||||||
|                     KEY_ID '{gcs_access}', |                     KEY_ID '{gcs_access}', | ||||||
|                     SECRET '{gcs_secret}') |                     SECRET '{gcs_secret}') | ||||||
|                     """ |                     """) | ||||||
|                         ) |  | ||||||
|                     ) |  | ||||||
|                     conn.execute(text("LOAD httpfs")) |  | ||||||
| 
 | 
 | ||||||
|                 self.path_prefix = f"gs://{bucket}" |                 self.path_prefix = f"gs://{bucket}" | ||||||
|             else: |             else: | ||||||
|  | @ -65,6 +58,9 @@ class DDB: | ||||||
|                     "With local storage you need to provide the path keyword argument" |                     "With local storage you need to provide the path keyword argument" | ||||||
|                 ) |                 ) | ||||||
| 
 | 
 | ||||||
|  |     def query(self, sql, *args, **kwargs): | ||||||
|  |         return self.db.query(sql, *args, **kwargs) | ||||||
|  | 
 | ||||||
|     @property |     @property | ||||||
|     def db(self): |     def engine(self): | ||||||
|         return self.engine.raw_connection() |         return create_engine("duckdb:///:memory:") | ||||||
|  |  | ||||||
|  | @ -71,10 +71,13 @@ class SessionDB(DDB): | ||||||
| 
 | 
 | ||||||
|         try: |         try: | ||||||
|             self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") |             self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") | ||||||
|         except duckdb.duckdb.IOException: |         except duckdb.duckdb.IOException as e: | ||||||
|             # Create the folder |             # Create the folder | ||||||
|             os.makedirs(self.path_prefix) |             if self.storage_engine == StorageEngines.local: | ||||||
|             self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") |                 os.makedirs(self.path_prefix) | ||||||
|  |                 self.db.query(f"copy metadata to '{self.path_prefix}/metadata.csv'") | ||||||
|  |             else: | ||||||
|  |                 raise e | ||||||
| 
 | 
 | ||||||
|         for sheet in self.sheets: |         for sheet in self.sheets: | ||||||
|             self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") |             self.db.query(f"copy {sheet} to '{self.path_prefix}/{sheet}.csv'") | ||||||
|  | @ -157,9 +160,6 @@ class SessionDB(DDB): | ||||||
|             + [self.table_schema(sheet) for sheet in self.sheets] |             + [self.table_schema(sheet) for sheet in self.sheets] | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     def query(self, sql, *args, **kwargs): |  | ||||||
|         return self.db.query(sql, *args, **kwargs) |  | ||||||
| 
 |  | ||||||
|     def query_prompt(self, user_prompt: str) -> str: |     def query_prompt(self, user_prompt: str) -> str: | ||||||
|         query = ( |         query = ( | ||||||
|             f"The following sentence is the description of a query that " |             f"The following sentence is the description of a query that " | ||||||
|  |  | ||||||
|  | @ -1,8 +1,14 @@ | ||||||
| from enum import StrEnum | from enum import StrEnum | ||||||
|  | from pathlib import Path | ||||||
| 
 | 
 | ||||||
|  | from anyio import open_file | ||||||
| from langchain_core.prompts import PromptTemplate | from langchain_core.prompts import PromptTemplate | ||||||
| from langchain_fireworks import Fireworks | from langchain_fireworks import Fireworks | ||||||
| 
 | 
 | ||||||
|  | import hellocomputer | ||||||
|  | 
 | ||||||
|  | PROMPT_DIR = Path(hellocomputer.__file__).parent / "prompts" | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class AvailableModels(StrEnum): | class AvailableModels(StrEnum): | ||||||
|     llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" |     llama3_70b = "accounts/fireworks/models/llama-v3-70b-instruct" | ||||||
|  | @ -12,23 +18,19 @@ class AvailableModels(StrEnum): | ||||||
|     firefunction_2 = "accounts/fireworks/models/firefunction-v2" |     firefunction_2 = "accounts/fireworks/models/firefunction-v2" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| general_prompt = """ | class Prompts: | ||||||
| You're a helpful assistant. Perform the following tasks: |     @classmethod | ||||||
|  |     async def getter(cls, name): | ||||||
|  |         async with await open_file(PROMPT_DIR / f"{name}.md") as f: | ||||||
|  |             return await f.read() | ||||||
| 
 | 
 | ||||||
| ---- |     @classmethod | ||||||
| {query} |     async def general(cls): | ||||||
| ---- |         return await cls.getter("general_prompt") | ||||||
| """ |  | ||||||
| 
 | 
 | ||||||
| sql_prompt = """ |     @classmethod | ||||||
| You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: |     async def sql(cls): | ||||||
| 
 |         return await cls.getter("sql_prompt") | ||||||
| ---- |  | ||||||
| {query} |  | ||||||
| ---- |  | ||||||
| 
 |  | ||||||
| Return only the sql statement without any additional text. |  | ||||||
| """ |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Chat: | class Chat: | ||||||
|  | @ -59,14 +61,14 @@ class Chat: | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|     async def eval(self, task): |     async def eval(self, task): | ||||||
|         prompt = PromptTemplate.from_template(general_prompt) |         prompt = PromptTemplate.from_template(await Prompts.general()) | ||||||
| 
 | 
 | ||||||
|         response = await self.model.ainvoke(prompt.format(query=task)) |         response = await self.model.ainvoke(prompt.format(query=task)) | ||||||
|         self.responses.append(response) |         self.responses.append(response) | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|     async def sql_eval(self, question): |     async def sql_eval(self, question): | ||||||
|         prompt = PromptTemplate.from_template(sql_prompt) |         prompt = PromptTemplate.from_template(await Prompts.sql()) | ||||||
| 
 | 
 | ||||||
|         response = await self.model.ainvoke(prompt.format(query=question)) |         response = await self.model.ainvoke(prompt.format(query=question)) | ||||||
|         self.responses.append(response) |         self.responses.append(response) | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								src/hellocomputer/prompts/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								src/hellocomputer/prompts/README.md
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1 @@ | ||||||
|  | Storage for separate prompts | ||||||
							
								
								
									
										3
									
								
								src/hellocomputer/prompts/general_prompt.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								src/hellocomputer/prompts/general_prompt.md
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,3 @@ | ||||||
|  | You're a helpful assistant. Perform the following tasks: | ||||||
|  | 
 | ||||||
|  | * {query} | ||||||
							
								
								
									
										5
									
								
								src/hellocomputer/prompts/sql_prompt.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								src/hellocomputer/prompts/sql_prompt.md
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,5 @@ | ||||||
|  | You're a SQL expert. Write a query using the duckdb dialect. The goal of the query is the following: | ||||||
|  | 
 | ||||||
|  | * {query} | ||||||
|  | 
 | ||||||
|  | Return only the sql statement without any additional text. | ||||||
|  | @ -8,8 +8,8 @@ from starlette.requests import Request | ||||||
| from hellocomputer.db import StorageEngines | from hellocomputer.db import StorageEngines | ||||||
| from hellocomputer.db.users import OwnershipDB | from hellocomputer.db.users import OwnershipDB | ||||||
| 
 | 
 | ||||||
| from ..config import settings |  | ||||||
| from ..auth import get_user_email | from ..auth import get_user_email | ||||||
|  | from ..config import settings | ||||||
| 
 | 
 | ||||||
| # Scheme for the Authorization header | # Scheme for the Authorization header | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										15
									
								
								test/test_prompts.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								test/test_prompts.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | ||||||
|  | import pytest | ||||||
|  | from hellocomputer.models import Prompts | ||||||
|  | from langchain.prompts import PromptTemplate | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_get_general_prompt(): | ||||||
|  |     general: str = await Prompts.general() | ||||||
|  |     assert general.startswith("You're a helpful assistant") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_general_templated(): | ||||||
|  |     prompt = PromptTemplate.from_template(await Prompts.general()) | ||||||
|  |     assert "Do as I say" in prompt.format(query="Do as I say") | ||||||
		Loading…
	
		Reference in a new issue