From e43260704ab3071f5bfc110fe155e13dda361cf7 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 4 Jul 2023 16:38:53 +0800 Subject: [PATCH] HuggingFace support and June updates (#2) * Improve SQL select experience implementation * support pgvector * Feature: Add huggingface model/hyperparameters space suggestion * Improve packaging * Minor fix * Generalize prompt template and knowledge * Minor fix --- .github/workflows/ci.yml | 8 +- README.md | 3 +- assets/mlcopilot.db | Bin 3985408 -> 5976064 bytes mlcopilot/.env.template | 16 +++- mlcopilot/constants.py | 36 +++++++++ mlcopilot/experience.py | 34 ++++---- mlcopilot/knowledge.py | 98 +++++++++++++++++++---- mlcopilot/orm.py | 158 +++++++++++++++++++++++++------------ mlcopilot/space.py | 7 +- mlcopilot/suggest.py | 165 ++++++++++++++++----------------------- mlcopilot/utils.py | 32 ++++++-- pyproject.toml | 13 ++- requirements.txt | 17 ---- test/llm.py | 2 +- test/test_experience.py | 40 +++++++--- test/test_knowledge.py | 9 ++- test/test_orm.py | 45 ----------- test/test_suggest.py | 39 ++++----- 18 files changed, 421 insertions(+), 301 deletions(-) delete mode 100644 requirements.txt delete mode 100644 test/test_orm.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a8de671..107b4e4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: Python CI on: push: - branches: [ demo_orphan ] + branches: [ main, dev ] pull_request_target: - branches: [ demo_orphan ] + branches: [ main, dev ] concurrency: group: ${{ format('ci-{0}', github.head_ref && format('pr-{0}', github.event.pull_request.number) || github.sha) }} @@ -32,7 +32,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -e .[dev] - name: Lint with flake8 run: flake8 @@ -61,7 +61,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt + pip install -e .[dev] - name: Test with pytest run: | pytest \ No newline at end of file diff --git a/README.md b/README.md index 581bd59..130d4bb 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,7 @@ MLCopilot is a tool to help you find the best models/hyperparametes for your tas 0. Clone this repo: `git clone REPO_URL; cd mlcopilot` 1. Put assets/mlcopilot.db in your home directory: `cp assets/mlcopilot.db ~/.mlcopilot/mlcopilot.db` 2. Install Python 3.8 or higher -3. Build: `hatch build`. (May need to install [hatch](https://hatch.pypa.io/latest/install/) first) -4. Install: `pip install ./dist/*.whl` +3. Install: `pip install .`. If you want to develop, use `pip install -e .[dev]` instead. ### Run diff --git a/assets/mlcopilot.db b/assets/mlcopilot.db index 02609083688352a1c9fb970541c2ca9069e413e9..a303cfd49f81b5be3e278bed1b3df808f65ae688 100644 GIT binary patch literal 5976064 zcmeEv2YeJo`+wS{-yJQ4&Jj8!;d>?;4p#Sge-K9XV{NCdGf8WgqW}cg!-JN}QcIGMHXQs#VozfjPMM_Ry zmerxq_*C%m_4R3`Q26-xg!}kNb|d%x1%n9i{y^?qihddy?h~8xg;f5tPgtnTM_x<* zj;xRDjtSG$F=Smr|8h7U0h#9l()FJUX-G1my+TrMUyD+ z6{4EH5E>X6Q``5YvUhM;?HQ$R5cZTCc&a^leeZ#$g^u8$$lA4iCvXmHVy4a8G;bc_ zxp2?cZ92u3ZDO)5DMVGGl->XxnXhoFw!cQD1exfIiWZ4{f=}A!v@TgPi zOXcMZOGwJePDxLF@;0{Njy!9E_squ{(XD%j&aJ!mQ*?~$SDKZjS_iV!)tg2u+IQ_9 z*CD=(=M?WlG)0@*$90eE(l)M#`_b6DDL(r6X~mCN_B^4u>`QIhA|x=fZEas4_gfv3 z37sOrn(xRFekXXjpP=$mp@~80X?49lEh9T;Sf(vG)t2CCOmRVQ*`82b?yczV3dO}` zp2rp#6sf{Jf`nES*D3C%#U(*cyGbzzgbrO?zwE;oR+I!q)~)MXG(_lf#kGowz_OjL zxSU-kJvj=WQO~&exbDS$MN#U7r-49T}@H@9#6OV(A_qhch#U|(AaX$DfXD%n z10n}R4u~8OIUsW2KaT?u%&T4n4t$@=VMr>4BB>aHq(Tr9c_0!QbLC6@*pW)|K|b=I z<=@F~LhC;(A0j^`e^b6+zC*r2zFfXQK3!fUA1NOsPeoPnPvn5e0g(eD2Sg5t91uAm zazNyO$N`Z9A_qhch#V+~1A+b(ebaI)h4@GMX5}Q?GAjoANBHWDMpFew81zb$+(Q^- zjL;iQTB)1R8^S$=K8z7MlR*{gA(SC*Laz&U6H0xMKp0J`KsTXRO5B9b6d(|Kjox1% zbZUcNkbgvo)?m^z|Nl7Y7^J?npY;FBPs>lpkH`m9l`oXfmCuk* zhFxH+e1zN~&yf$7+h8FWAnz@IUfxL_FK;8a$c=K1JXYRR-auYgUPB%suOJVT2g-eA z4`ugcKghn3-I86GeIh$AJ1cuvc2ag!c1Xs{UYG5ZZINw|t(Gm5Et1WX&6G`%jhDS7 z8!5||<;pT-DKe{UAYur*%R0-Rm9>?%l$m5&SqoV+SwopZR#R40R#6r%3zGRsA4z|f z-jRMQ{X%*}`l%ol85ByAxM78MRGR>$q(5` zzRg1Nbq10z2P3(aj^t(S??2Y75FC+(gBB4EyIJ+a+_dJr01v;Zv~Q`b&+(hgQQz+ zBwcDD=~NTRb2X4WTOCRJYDn5eB56|Ai4Gv?86!hHRW7{I?l5Rxe7@K-QZzYg>I*J9}Z8t8XE)qPp*qQLq-@;9KT z$A%vbe=mHcEKs%`dNGmaNUKQ)$Og#brJqYLNxMP&UIsmRrmVhnZ}|7&pFt~khA)u* z5dJ`3L;5iMqC7;pR=!y7C!a67BHsyZyS}u!yq|2Lyk+?2a9%2rO^|OC5C)^5N?sW{ehulHvWoINc{(fy-DEo1QRx}^3|XvfqwEXmHrWu_N8!7qa``@KJZua) z`Ongkvis5$*Z^7ru`ph$3AKe>4h;=F8u)pTG_Zc?k1>XsK1KEXuRc5K^dntU}*TF{t1(uAYCqZ+ezB)fSVKX|>zaQ__<>_vp~Rb@v#R@}-9UUCXc0 zHF2;l$NJ zc&?8|sT>vQuPIa7)9T2vC0a+u_$QQCuIO2EX+_&KYi6b`JJmLVHO+x2;m4n_b(VEx zPPU@EZD_h}*zo593L?v`);T}Zk=~_K2SpQsiix*n=h+lptd9IVYo?-;H9IvQkEZC5 zonp(gWhdF14r4qm?}j~x*=&xuq1McyJpu}vmRlvhQ?EEh6Wqd9=Hk5&Q=Vr>Z(%{4 z3U-Nnfq{xd+%XM3Bco}43;(2YFWI(Jhc0bndf>!NTTHv0ENgmpOxLbG947x^<<{@f zqi3m}*~5m=$}A*ct_{4;&S+|H?mwX1Gd&FLmY0*9pX3mFa2~#ZuyQMR?%u5}Sel-c zmt)UKao{!BZPvV`G(|USo;5cu51rrBeVY!6^6yaYHR7^z+U8_4R_8YG@BFkX?33w~ zo{fO^X8!%kt<<4&yF^9j9P|TEZ?I=a`@i7JqOh^FQ+iU2J=bPSN{h*}CFP`MrwiX; zr}V_U^d3TaR{wwk!|&Jb*`ad>u+^QlPRhy0h>#T1EiEV81_BQXE@$VXSmC~;_t7(1 zSO491w9U!NLND%}lb&SDQ2N_^^K2{(HZRG>-dap1n3~OooYoGEABuQgpo?RPHJHP@Eix`QHqL{@5j|Bk+noD5s`_W-;%yPyT^6Lcik3!-5!I!h#@1> zJz8PVDE&R>q$q?@Ceu?VE;|XX&pdn6L5=*MFaP3QvmH5IavV0hLKmw{2=%Y-tJ3O? zO3jOIQfE|ZUtq3OjcOosVVX=T4Tiq#xJql%>-%`F zr_}3u7n9mvp5r=oPw$mXJvg+Dl z>N|PM4C;7Zq3RxMw7mk&_0#v8KJZ+={&(#^K z30G69)sX^WQfaDs2tx#`sy0>eR8?1YA5dy4c?h+D1kmVIY5@sgG8vU}c8*phV-K&E z3ir~5vrOX#0BE!(b*OM%mEiuzSD+1cH%@8N2>ySKPOB08|EQ}K{Qn>{g8v`y$t3vy zH9Ec9|8Fws4a_K@kTd&=WOo_=|0Vf{^7rH?;r%}#-zVQC-vZzNN@(}<