From e8c7600ed7caf0770c6d9b2a7c4506b3b925eff8 Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Sun, 16 Jun 2024 08:56:45 +0200 Subject: [PATCH] Better tests and docs --- requirements.in | 4 ++- src/hellocomputer/db.py | 29 +++++++++++------ src/hellocomputer/routers/analysis.py | 6 ++-- src/hellocomputer/sessions.py | 2 +- src/hellocomputer/static/about.html | 2 +- src/hellocomputer/static/index.html | 6 ++-- src/hellocomputer/static/script.js | 30 ++++++++++++------ .../templates/TestExcelHelloComputer.xlsx | Bin 0 -> 10937 bytes src/hellocomputer/users.py | 4 +-- test/test_query.py | 24 +++++++++++++- 10 files changed, 78 insertions(+), 29 deletions(-) create mode 100644 src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx diff --git a/requirements.in b/requirements.in index 640ff7d..9521c80 100644 --- a/requirements.in +++ b/requirements.in @@ -6,9 +6,11 @@ pydantic-settings s3fs aiofiles duckdb +duckdb-engine polars pyarrow pyjwt[crypto] python-multipart authlib -itsdangerous \ No newline at end of file +itsdangerous +sqlalchemy \ No newline at end of file diff --git a/src/hellocomputer/db.py b/src/hellocomputer/db.py index 3630616..528ac00 100644 --- a/src/hellocomputer/db.py +++ b/src/hellocomputer/db.py @@ -1,8 +1,7 @@ from enum import StrEnum +from sqlalchemy import create_engine, text from pathlib import Path -import duckdb - class StorageEngines(StrEnum): local = "Local" @@ -19,11 +18,13 @@ class DDB: bucket: str | None = None, **kwargs, ): - self.db = duckdb.connect() - self.db.install_extension("spatial") - self.db.install_extension("httpfs") - self.db.load_extension("spatial") - self.db.load_extension("httpfs") + self.engine = create_engine( + "duckdb:///:memory:", + connect_args={ + "preload_extensions": ["https", "spatial"], + "config": {"memory_limit": "300mb"}, + }, + ) self.sheets = tuple() self.loaded = False @@ -35,12 +36,18 @@ class DDB: bucket is not None, ) ): - self.db.sql(f""" + with self.engine.connect() as conn: + conn.execute( + text( + f""" CREATE SECRET ( TYPE GCS, KEY_ID '{gcs_access}', SECRET '{gcs_secret}') - """) + """ + ) + ) + self.path_prefix = f"gcs://{bucket}" else: raise ValueError( @@ -55,3 +62,7 @@ class DDB: raise ValueError( "With local storage you need to provide the path keyword argument" ) + + @property + def db(self): + return self.engine.raw_connection() diff --git a/src/hellocomputer/routers/analysis.py b/src/hellocomputer/routers/analysis.py index 13e9331..ab39b98 100644 --- a/src/hellocomputer/routers/analysis.py +++ b/src/hellocomputer/routers/analysis.py @@ -1,6 +1,7 @@ from fastapi import APIRouter from fastapi.responses import PlainTextResponse + from hellocomputer.db import StorageEngines from hellocomputer.extraction import extract_code_block from hellocomputer.sessions import SessionDB @@ -13,7 +14,7 @@ router = APIRouter() @router.get("/query", response_class=PlainTextResponse, tags=["queries"]) async def query(sid: str = "", q: str = "") -> str: - chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) db = SessionDB( StorageEngines.gcs, gcs_access=settings.gcs_access, @@ -22,9 +23,8 @@ async def query(sid: str = "", q: str = "") -> str: sid=sid, ).load_folder() - chat = await chat.eval("You're an expert sql developer", db.query_prompt(q)) + chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) query = extract_code_block(chat.last_response_content()) result = str(db.query(query)) - print(result) return result diff --git a/src/hellocomputer/sessions.py b/src/hellocomputer/sessions.py index bb1c0d9..dd96104 100644 --- a/src/hellocomputer/sessions.py +++ b/src/hellocomputer/sessions.py @@ -149,7 +149,7 @@ class SessionDB(DDB): ) @property - def schema(self): + def schema(self) -> str: return os.linesep.join( [ "The schema of the database is the following:", diff --git a/src/hellocomputer/static/about.html b/src/hellocomputer/static/about.html index 04ef787..ae7ccdf 100644 --- a/src/hellocomputer/static/about.html +++ b/src/hellocomputer/static/about.html @@ -37,7 +37,7 @@

Hola, computer! is a web assistant that allows you to query excel files using natural language. It may not be as powerful as Excel, but it has an efficient query backend that can process your data faster - and more efficiently than Excel. + than Excel.

diff --git a/src/hellocomputer/static/index.html b/src/hellocomputer/static/index.html index 30bbe4f..1e1c522 100644 --- a/src/hellocomputer/static/index.html +++ b/src/hellocomputer/static/index.html @@ -25,9 +25,11 @@

How to - + File templates - + About Config diff --git a/src/hellocomputer/static/script.js b/src/hellocomputer/static/script.js index 81f15d7..3c3bc97 100644 --- a/src/hellocomputer/static/script.js +++ b/src/hellocomputer/static/script.js @@ -71,17 +71,27 @@ function addAIManualMessage(m) { chatMessages.prepend(newMessage); // Add new message at the top } +function addUserMessageBlock(messageContent) { + const newMessage = document.createElement('div'); + newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); + newMessage.textContent = messageContent; + chatMessages.prepend(newMessage); // Add new message at the top + textarea.value = ''; // Clear the textarea + textarea.style.height = 'auto'; // Reset the textarea height + textarea.style.overflowY = 'hidden'; +}; + function addUserMessage() { const messageContent = textarea.value.trim(); - if (messageContent) { - const newMessage = document.createElement('div'); - newMessage.classList.add('message', 'bg-light', 'p-2', 'mb-2', 'rounded'); - newMessage.textContent = messageContent; - chatMessages.prepend(newMessage); // Add new message at the top - textarea.value = ''; // Clear the textarea - textarea.style.height = 'auto'; // Reset the textarea height - textarea.style.overflowY = 'hidden'; - addAIMessage(messageContent); + if (sessionStorage.getItem("helloComputerSessionLoaded") == 'false') { + textarea.value = ''; + addAIManualMessage('Please upload a data file or select a session first!'); + } + else { + if (messageContent) { + addUserMessageBlock(messageContent); + addAIMessage(messageContent); + } } }; @@ -104,6 +114,7 @@ document.addEventListener("DOMContentLoaded", function () { try { const session_response = await fetch('/new_session'); sessionStorage.setItem("helloComputerSession", JSON.parse(await session_response.text())); + sessionStorage.setItem("helloComputerSessionLoaded", false); const response = await fetch('/greetings?sid=' + sessionStorage.getItem('helloComputerSession')); @@ -155,6 +166,7 @@ document.addEventListener("DOMContentLoaded", function () { const data = await response.text(); uploadResultDiv.textContent = 'Upload successful: ' + JSON.parse(data)['message']; + sessionStorage.setItem("helloComputerSessionLoaded", true); addAIManualMessage('File uploaded and processed!'); } catch (error) { diff --git a/src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx b/src/hellocomputer/static/templates/TestExcelHelloComputer.xlsx new file mode 100644 index 0000000000000000000000000000000000000000..556861628c906a6b96c15c7ad9a62ad73200d09b GIT binary patch literal 10937 zcmeHtWmFv7wsqs~5Ijh5cM0w;!JQx-XsmH}4;l#W?(Xg$JRw+cPYCV=`1<7D`%X^I zc|X5j@9rL>x@&aJHL7;axz^fyEmZ{=SX=-+01*HHPyk9&O261a0RV|`000gE5n5jY zZ08KLb2irSum?IBvbfvYkmtcd)8zo5A=m$R{2!iyfrJlseQcOg*NJx+am|{huU~26 z1xS&7cvgA=W8g*p%y-m`?uQR8d|8>i7-E5)DRoxswfko3hP$m&ZD9CPd)j6BbH*OI z>W1|BL!Uc?)mrHH_VNVBId6zXd7c{@qZOwEbX1%A)41tFye41qyoS>f_Rr5ujn$&^ zjCOh+Y!J?WB3AA8q0kMySGr#&6xBx?n>6ZzaP4ER`O$r*wQ(01BV~}X&`QY6>4hgTp&LKKKh zjU9nDPOL0{9RDlN|HGR6<E66s6#sh8ou`#sNYz{8J9ZF0 zc28*r^PE=iA$$MWdPYr6nmQy6CUf`pkPoh*qd>|D$L~w3wD%rc2M0#6jvec0-Zguz zklCd66fgGVYTX@V7XB_7tycKCTo?d=2nhhdfL!5j!|Dchv^E8Ut^bH>6`F?Nd_G(s zW6%@KX6Di^qrbciW_XjDK-tSw*XvTp9waR&i)fIV=K0!QT5v;Vg;wgM-zNvs8f)rs>vEBw?;yKak6JEjMwJ#x^9%`F(SavO#i>+! zF+WhV>d-`VnN%(X$ZDy>U9?A3w_{b(D+gGYcpuJdQKWH;GPzI8odc|hl&}n|^e7JU z;1~5b3H6bT$$Vk^U&jjc&dz9|u@>wPsAWQNEyW4Xs<2UR8zMIi^YfvNO2rE~81}<3 z9{9aW^y@h@?nQhAazTG4>cAE|6As@cZCTtMyp5kfw6ib_s+!LR6z)O$N z3c$E-^PO{oxoW^(x8F*7SaurCsqHb&!IRdLbe>sJ)gT%+olDEX9I@6j z(oNyt!45(zB-|TyeVVg>z9cTay)l>`KRrm*ND}X%^{k~S@osWp%yBst9w&P||Q+;v-@O&uo0mOV8qIP`1fNgv3rCd(SW!oM$`>by-+((*0 zBeGyhR)xHsPBd^z-V!Hxb7Nd!0K+h7$Ol;ufBNMe2LJ(fIak_T-({|J1G#D^2jLS9 z4oTayNzy6Dpy?*)`P$0T^u<=*(R#s4XV9fJ^&M&0w`swdo32?VMus_q*-flf$7Y+< zS@aG*Dge}4v$;B}Y%L7#3Sp|)nS@PY>9L*>k5dQIXi+P#7j^j#tB898oaGH~epy)4 zJgJDFjAzDo3yk+4RtCCij+c1aCY)EVpB}q=SaZ7G(85n0DWhiic4Zl76L=*~ksVnS z!_&@#k}{rWe_5q@8b9wIW}fz=an`mMLg$gnS7PacX!h@FmM<+b`9tS2e<&J&2nA8| zzhdC8n*L8bgo322kYM?LdvwH4SPihD1|NrX11OVNgO*xva``*oAO`r^G2)``k)Mur3s%Jor_4#^N%g8k$a^HaQ~z00a<9{5uXfSptF1POLw_*#8I!8Sj_vw#XszQ~Rp``4Oqj zup9}aa-)O3?Xgz5kO7Wh%qzYYQDwdF?Z*qRX60MoG6WEadUJ~WV0s>_7bnEWT1gM$ zh@gjhVZ%I8RM73)CCN`MY_IMFL`%4U%@SrI>taIxZ0m!9iHIG=3L+NQX4{oQNxP5x z-m|W3mbnX6Y}fC-=Re<1QK>^TtX|a0b)iW2MF@47OReWUo&Xw61_e4UhE6vaxW~%H zeAEj%YHDY})^ZEiQK0Sz`K#qgowl)mK5Z0V8K&9_Bq43Wy}yZnqtiWFP-y>onI`Jp z_$2q*Sav=C)~0;hoF6{TXi8A$A=wFOqHk4!xg)UfxaO4B5B zIi^&#DW52olr38PCr>PKxwg7!0$cFf3|BY~HxQwt+x=W7#E0M`Z<)55^A!76hCW$y zL`V{go9RXi7 zl_*S|@>Ud=_t3i`t`YL&`9#3@qZ`J%6}%JOK-k$AP|wK$y0rFUXg8?5Jr)7(UU0hp z0CP7u?@;2A7}zra)W}PyCR8rJ;!ET_(rp<5I%rh#cJR8y7qQw!C_~=~)}N2`^H4`G ztEh6R#3Vm8&Zspar^HyR$uZ5TtwuMdIOSifaE8*8S0s-tcS}uu+sbAnZDlHNg%xg5 zQSdG)+ZcLle|d2~c@H0mZ&6T`wD5=f`@0uCy6B)9fM@{-S!Db#THyG%7GjpzAX;F& zBxPDdpbePD#L{xs{nT0R*}*dzth3-?mHxi0>plhtDzQb3ac5moa81PQDN7`?ffAuc zIjyN+l&%O?&xB#oPT}r+_~jVALsm1Mc2P)tR_{#l;4sE{eyzYfz&$Ydb7I437bAgF z9@NCji1YMj_L9zokna++Qzo3zrJP@)8cG^>VgbGQB|RTB9<3gB^miM+-sd~ zYx6x8F6QUGAUn5v7Hny^@C8}wht{y4TCk1YArA4$jz_}ZHb0d=eQs=icas^P)JJ4t zB_hs2Y}Ez!8+sEJ$#~H96vj3htjjl26Tqu+*A*XsI3^FSITe6tfr%-Z`KpfLU`_6& z?$9VxjM16H^Fr45Ok>*Rz=TSR+hUuffp3BjGWh56o9JOx;UR4W9|~?V)S6PRA8)BQ z_8*CHrg+!H0#5vikf?jJh|OG0HW@=rsoH$S$sIR#pt8hD<0%tcL^_Aad#aElZ*9F_ z>0gIFT8)MZxbu8%uJ8_uceiIKxiz{4@7};|H=Pn>nr`*-OSu>Y;8^N10Y6<+M2XV~ zzfIcM&)vEKmH3Djozmq?9%l01Bm^L>FP8&G}P!=WXBod>nfrWa5?1tQJttmkR{#^n3pHWCx6=9{v?=_ z6Oa`bpQ*$l8~H*NL$!clgteZtX@<4Fenus9j6Yx;CLLMZ1X(*5qXEZ`eVW5Ll35W+ z2lW#S?`Ob-={t{-XT(H#ixsga!oS|h|8M4By>1X|fkXf-%0KqGentUjOQ0=~_2=qZ!P5*;`EJihN`!arTidv)Zwk!ZLrE0+*zG-)rY*3O6vg}nDkApjaQ zqY$v_Nf!Zc=?~w!9n1duh0;^Fprdr!MnpvCqiDRT8!P9hW1+|0&f#I+Sf}gvn3^2-ux8|E>DDzfwXI0rtC$6K*Sn=p8BDdg`|wxNzZ z`&k2|fwyG3`g?ywy5b zS*9anZ&6D41j9))*7JG1a}y6&%gZlRxZ0pPx^U(8FWPd%X^XsED) z(~|4KHN{LJ=WsT+{Ufp!Pgg;AQ$PPc*S*QgELb%4@$QH-Mj+p$zmHfo6nwbu<-2hF ziKzn$c(aS%CH%FqFD`^PCwykx8-K;7ae0EIGlYpxz`0LgqG*k1_XHYlrY4$2XfTl0 z%D7_ih7jKjU`gH|K_gqKUf07JzG+m}A}PV39y0X6Mj{EY3;P-u!L&X~#ED;}VuK@$_YP># zb6;-pY`$K$0@r2y2U`U5I{RM5RLDmF#-#x@ii+0gtWc{?qC+fum;HoL&)c2#8HSJx zc|!&StXj(~jtsgEDx_DkPeW5k4uUHd33{$VMV!6V#&j2iQa*~AG^x|2c(NhW6w$q` z*7#!fjllG+;*zP3?|ebRk(J>z1&;wiwPd-p>7LpA*>8x;CbLh5MtdJg#+AY7@Og;E|u_L4!L>m zcKWk$c}DZ8k@({cr;e^9_s@K9YZr=dVLR+Xjx>5qomJfj-s}O_kwVTeHU~ioGT6=u zfnE(N-$22YoeFCTZXH!2JKemxW;HC?E6#?2sI9fAR9C!B6R@1tft0UlQu1DCx4~h3 zcnAq!)s9{mqi?Rkx7~oJi>mOXEgJn6k&VLy+x^4CDjP%VvIVbOp|hT~n#8Ew@oK|2 zq;UeXI$j-mdDL#RjoH>tf%y2Tt6{M@35=4q9Jt$*;eCP2R8l@R`VT8PQ4qOs3Ntq$;N5^J3 zUO2AFN-f_S+!RO1^Pu6V)C>}?+%yxOTFNZy-nO7g(81gCg&|T|E)V5MK(ebnd}7NBy_P{oaFzB9Hsn+2tj_oP9pkB=IKF+KP+ zk^aOS7ToXQoOFLAf@jZr=~8T8Dy5QcT2F*bRCyJ4(x^;1?+BwXmKsRCkNw#lF>+FZ zGsCqL?jm}FLz%+8!1ub(1}Th{le!M?iKmnrPYta$iX$?;;{7dcQ3T?E<(JEgb8vqE zHO*V|?Uyyqex=*uo$J1!?VfK<$)l!XZ^{m@XTzJET~X+T6`oBRN8c<5HBuP|v`^*T zQm54=x<^qA5LJqLxhYeVkin*w@8M4NH>FmDcUudnXFXI(R8f+dB}wzGsf{fP2Bo6` z6^6gq+N!pF$Rv+g@kJFYqI_}=AP{W*#8YFN1zqz94$rytqeZMfcImz&}@>Hu5s zJtx20#FN67Iiop!z zm@qSqT{@>V`FF#HJ-BL6w9+6-_&}4*LVmE8mB8*`Xk4j;1 z8Jv6BZ=5W});sdDjDexSlyH?ZvidBROzqoMppYGb6C!!S#K}f0M< zL#tp_ZS%31UlB(uMd>+p6r6Xp$(+qvmRq|)wMfo$vo2b_2)=+iTYi@p9DEx2 zMgLs70E^vbxYL_Y#kw*>s&JJ7!4&$Hap2DTvoa z=w>6aGvlSY204DL(+2y$5yF%%XU2~yUrRt>E59;0nf~_uhxKs*-wL!#wT@DAAG1n2 zp>VZGeFPZj9dXy^w&a2dIY0I>mNfH2Fn0&%rfx(wg<9MUfB*5(*+GAxmnuUgaE4kT z5X4~&JG%d-O9%HO8#>q*{mrBRjWRD~E?nNsIC6@@?z=;Z!Qpz=ewh5?wNAtwp>5&l z;SCq;5HSZr>0V1k+Fo<-R9GqO!N^lf-*ntjf3LHlQQUaCuNE1!+2bV1iuNv_zLrc+ zqfiFihEa{YLUMe38uk0R6#l(L#NWfQrU?0`4?=$2e-q!*#1Uw&?(7J%v-rt>=XfRM zem2aIYuPU*+b*RaDJewd+5ojO2{0+lw+mBS$GX{QjdT_li@n5)^jsU%`#X{Zit5 zIu9R+lgG)ltOK%okrrV1w>7ti+F?p7%thIzAGu_s`CG+4w3=r0zwZdD0Ve`p?6 zgMuhtpU>q4@9C;htrJmDx}f0%D|JU9d?DryYm_A>-l;WRFjSUZ8rGkRIS&qGWEHR+ znI{h8G)Wp(uCiqk^X>P87tJD_#^ug9b<**da9G74%ZWCzRuIh1tF(G%Pj-$PKiXx8 z+C_Bt$T4$xQ;IYfG;wn-*2ds#0r#)~N|y}Vqj_UCwZy>8|;`QSKd zyR}%$FvCk{C>!^B?UD_;mx@&~YrWZKLy>(p6j@n;p1hZK+C6D?ZLlIHMdPH*9Bm9Y zCypA4wDx7MfZ5k@%n4W3SKsx*YBsgXB8a>+_)F>U^Z4KOoYNFQOYn0iaETTly}gOr zu8{fcrP&0SL`+Z(U`xBHQt|DC*C>1qh1>LyQp?Ph+cmz1T}_oqRN z1q&vou?cUba>VDxSLCf7xtl~$KnMpK?2SF~MV&F4JEW`9*R!UtJIBu%qY;KwO+99m zLam$)E`M|v0i9vUgpOh|0s2lKU83E3&JwRy+i~f&GFF1l176H6&Ik1m?YE92?mb&k zIP-33ojLzxL(0{XL~8?wmOBH6a|&aN0>cQ$aSK&@I+LhNKQ^P*+HB<1im1u4N+Z zq`GVnp?dF^(F}FYlnQG^FD7H|W8(fe!5F{5&JMlpY{$jBiZ>U_6U!oy_wx6ROC*Pf zkPZl&=l<2E^ox6SadHOR{+H1Ip)&x`kf0E`%m#qeeI60=+ds!MiI^WUr4RRsH#1Rv z&rK%(R?EH+Uz^a7d(o|Zz*sOH85t{XdNIMK=<-x}ocEr{&{_%YgcLfCqi^hU$LIC_ zBdllSj#C z;rRut739XI%>~9Svq05nk@@FMVf|$MC+gLLx25s35h$7WEdm%&Zm%A&MP%BpD=T{E zUE1`j)f2?)Oc=wljKq^_Nb;a5>Zr>{6|iNb^KO^I}g`RrCsjfM>SNH zOeT?nxaiduwgp2Bcq7&m91{;7hE9sqIez$^0eQV$eGP6-OUiCMt2Wncp5#&$3%@XQ z?;5SBVoTDY9=vYqZ;_OHTdz{D`myxW)EL>JZNoAdGwBXf&Vk1Cb?!^v#j1*#O4LiR z2QP{hPb*TTwn}papkHCa1fFK0H}};0)(d-w11ra EKLePf?*IS* literal 0 HcmV?d00001 diff --git a/src/hellocomputer/users.py b/src/hellocomputer/users.py index 2f65e94..a5cf93d 100644 --- a/src/hellocomputer/users.py +++ b/src/hellocomputer/users.py @@ -99,10 +99,10 @@ class OwnershipDB(DDB): FROM '{self.path_prefix}/*.csv' WHERE - email = '{user_email} + email = '{user_email}' ORDER BY timestamp ASC - LIMIT 10' + LIMIT 10 """) .pl() .to_series() diff --git a/test/test_query.py b/test/test_query.py index d4d56a2..65b306d 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -8,12 +8,15 @@ from hellocomputer.extraction import extract_code_block from hellocomputer.models import Chat from hellocomputer.sessions import SessionDB +TEST_STORAGE = StorageEngines.local +TEST_OUTPUT_FOLDER = Path(hellocomputer.__file__).parents[2] / "test" / "output" TEST_XLS_PATH = ( Path(hellocomputer.__file__).parents[2] / "test" / "data" / "TestExcelHelloComputer.xlsx" ) +SID = "test" @pytest.mark.asyncio @@ -35,9 +38,28 @@ async def test_simple_data_query(): chat = Chat(api_key=settings.anyscale_api_key, temperature=0.5) db = SessionDB( - storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent + storage_engine=StorageEngines.local, path=TEST_XLS_PATH.parent, sid=SID ).load_xls(TEST_XLS_PATH) chat = await chat.eval("You're an expert sql developer", db.query_prompt(query)) query = extract_code_block(chat.last_response_content()) assert query.startswith("SELECT") + + +@pytest.mark.asyncio +@pytest.mark.skipif( + settings.anyscale_api_key == "Awesome API", reason="API Key not set" +) +async def test_data_query(): + q = "find the average score of all the sudents" + + llm = Chat(api_key=settings.anyscale_api_key, temperature=0.5) + db = SessionDB( + storage_engine=TEST_STORAGE, path=TEST_OUTPUT_FOLDER, sid="test" + ).load_folder() + + chat = await llm.eval("You're a DUCKDB expert", db.query_prompt(q)) + query = extract_code_block(chat.last_response_content()) + result = db.query(query).pl() + + assert result.shape[0] == 1