diff --git a/src/hellocomputer/main.py b/src/hellocomputer/main.py index fd54fc9..86f90c0 100644 --- a/src/hellocomputer/main.py +++ b/src/hellocomputer/main.py @@ -6,7 +6,7 @@ from pydantic import BaseModel import hellocomputer -from .routers import files, sessions +from .routers import files, sessions, analysis static_path = Path(hellocomputer.__file__).parent / "static" @@ -42,6 +42,7 @@ def get_health() -> HealthCheck: app.include_router(sessions.router) app.include_router(files.router) +app.include_router(analysis.router) app.mount( "/", StaticFiles(directory=static_path, html=True, packages=["bootstrap4"]), diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index b403d6f..07ded5a 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -14,12 +14,13 @@ router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) async def query(sid: str = "", q: str = "") -> str: + print(q) query = f"Write a query that {q} in the current database" chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) db = ( DDB() - .gcs_secret(settings.gcs_secret, settings.gcs_secret) + .gcs_secret(settings.gcs_access, settings.gcs_secret) .load_folder_gcs(settings.gcs_bucketname, sid) ) @@ -32,7 +33,11 @@ async def query(sid: str = "", q: str = "") -> str: ] ) + print(prompt) + chat = await chat.eval("You're an expert sql developer", prompt) query = extract_code_block(chat.last_response_content()) + result = str(db.query(query)) + print(result) - return str(db.query(query)) + return result diff --git a/src/hellocomputer/routers/files.py b/src/hellocomputer/routers/files.py index 5787c5d..5000d9d 100644 --- a/src/hellocomputer/routers/files.py +++ b/src/hellocomputer/routers/files.py @@ -25,11 +25,9 @@ async def upload_file(file: UploadFile = File(...), sid: str = ""): await f.write(content) await f.flush() - gcs.makedir(f"{settings.gcs_bucketname}/{sid}") - print("successfully created directory") ( DDB() - .gcs_secret(settings.gcs_secret, settings.gcs_secret) + .gcs_secret(settings.gcs_access, settings.gcs_secret) .load_metadata(f.name) .dump_gcs(settings.gcs_bucketname, sid) ) diff --git a/src/hellocomputer/static/script.js b/src/hellocomputer/static/script.js index 9693814..01a941a 100644 --- a/src/hellocomputer/static/script.js +++ b/src/hellocomputer/static/script.js @@ -28,27 +28,27 @@ textarea.addEventListener('input', function () { }); // Function to fetch response -async function fetchResponse(message) { +async function fetchResponse(message, newMessage) { try { - const response = await fetch('/greetings'); + const response = await fetch('/query?sid=' + sessionStorage.getItem('helloComputerSession') + '&q=' + message); if (!response.ok) { throw new Error('Network response was not ok ' + response.statusText); } const data = await response.text(); // Hide spinner and display result - message.innerHTML = '
' + data + '
'; + newMessage.innerHTML = '
' + data + '
'; } catch (error) { - message.innerHTML = '' + 'Error: ' + error.message; + newMessage.innerHTML = '' + 'Error: ' + error.message; } } -function addAIMessage() { +function addAIMessage(messageContent) { const newMessage = document.createElement('div'); newMessage.classList.add('message', 'bg-white', 'p-2', 'mb-2', 'rounded'); newMessage.innerHTML = '
'; chatMessages.prepend(newMessage); // Add new message at the top - fetchResponse(newMessage); + fetchResponse(messageContent, newMessage); } function addAIManualMessage(m) { @@ -68,7 +68,7 @@ function addUserMessage() { textarea.value = ''; // Clear the textarea textarea.style.height = 'auto'; // Reset the textarea height textarea.style.overflowY = 'hidden'; - addAIMessage(); + addAIMessage(messageContent); } }; diff --git a/test/data/TestExcelHelloComputer.xlsx b/test/data/TestExcelHelloComputer.xlsx index c9e1c31..5568616 100644 Binary files a/test/data/TestExcelHelloComputer.xlsx and b/test/data/TestExcelHelloComputer.xlsx differ diff --git a/test/data/~$TestExcelHelloComputer.xlsx b/test/data/~$TestExcelHelloComputer.xlsx new file mode 100644 index 0000000..0715d93 Binary files /dev/null and b/test/data/~$TestExcelHelloComputer.xlsx differ diff --git a/test/test_data.py b/test/test_data.py index 6f750f1..641c15b 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -23,7 +23,7 @@ def test_load(): assert db.sheets == ("answers",) results = db.query("select * from answers").fetchall() - assert len(results) == 2 + assert len(results) == 6 def test_load_description():