From 6b0203168b2da133b2b91870e5d4e7e816ceb35a Mon Sep 17 00:00:00 2001 From: ych <1821947036@qq.com> Date: Sat, 23 Nov 2024 19:56:32 +0800 Subject: [PATCH 1/3] add merge tree function based on embedding --- README.md | 6 +- img.png => images/load_data.png | Bin images/show_columns.png | Bin 0 -> 6498 bytes .../zeppelin/iginx/IginxInterpreter8.java | 18 ++- .../zeppelin/iginx/util/EmbeddingUtils.java | 109 ++++++++++++++ .../apache/zeppelin/iginx/util/LLMUtils.java | 140 ++++++++++++++++++ .../zeppelin/iginx/util/MultiwayTree.java | 17 ++- .../apache/zeppelin/iginx/util/TreeNode.java | 28 ++-- .../algorithm/mergeforest/ChunkMerge.java | 103 +++++++++++++ .../mergeforest/MergeForestStrategy.java | 7 + .../algorithm/mergeforest/RandomMerge.java | 101 +++++++++++++ 11 files changed, 512 insertions(+), 17 deletions(-) rename img.png => images/load_data.png (100%) create mode 100644 images/show_columns.png create mode 100644 v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java create mode 100644 v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java create mode 100644 v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java create mode 100644 v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java create mode 100644 v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java diff --git a/README.md b/README.md index 8d14394..423cf41 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,11 @@ IGinX Zeppelin 解释器是需要连接 IGinX 的,如果我们重启了 IGinX #### 特别说明 ##### LOAD DATA 命令 输入LOAD DATA 命令后,需要先点击执行按钮调出文件选择控件,命令中的文件名(1处)与选择文件控件中的名称(2处,选择文件后自动填充,不可手动修改)需要保持一致。 -![img.png](img.png) +![img](./images/load_data.png) + +##### SHOW COLUMNS 命令 +展示可视化资产图会涉及到embedding,请前往 https://nlp.stanford.edu/projects/glove/ 下载相关模型,解压后放入 /resources/model 文件夹下,并根据模型的embedding维度修改embeddingUtils类中的EMBEDDING_DIMENSION参数。 +![img](./images/show_columns.png) ### 使用RESTful语句 diff --git a/img.png b/images/load_data.png similarity index 100% rename from img.png rename to images/load_data.png diff --git a/images/show_columns.png b/images/show_columns.png new file mode 100644 index 0000000000000000000000000000000000000000..9fe68cc84ce6169cbe818087e24eed5eb853eb55 GIT binary patch literal 6498 zcmcJUcTf|~m%u|21dWJvBcjrK54}h)p#(xFB1jDcq<0adccn=SO?m(+p*N+2^Z*i? zNbg850>}3^zx(Ir=H}*R?w{SAEpO-T+j*aTZzHrdm4T#>NdW)=P*p`i2LQlB-L^gM z;@_5@?<@!a06KP6g_kg|S38+@29#5c?}zu;eN5*R$qLsNPb6ss!wM`i0OTX#&H$jO zP-GBfo;EByqE9TCF;dVg`ZXko|8W>Wq*eOS`MasBJDvu|vhD_+^3UFmcQs8OzC4-V zGikCf^O@E4yvkQ$!vg?*(FH70-`14k*uwPc0C7Y*(n4gEdxJL*5~t$Yfu+nWtI>KLX0GKG-_@L!aZ5xx#jQ@& z#t2uRzx8s1+;JaDv`9kF-db3jehFrl0~Y?z zJUyP|GB6>j!s2YkUDu_NE_L`sy)@RvVSQ`HlV=x9+WCj0p)_ab%!WqUDca9D$Cs>_ zg0+9?wSdI5wAbzdrgW;C#~(LGJ8~}}FyifL?a5LWFu>K{@(g0RK%FjBmfWyg9D@sy zgS^J!4kx;FXO1DLe7EiNfetBT`rtuskLmB? zC>{)O360de@1fP4J1^~hw13@;S1lV&Vi{YYKJHps=gM2Kt2ul1?!;Z_#AVa1OdtmU zdHDdMT)bW}+?I8=nr=e&?mQ_Mm9dca`7rnG)X)Igngw=dTKSZP)%w-2@|%jM_eFkM zZcJrfboT3~OB_5?(l`~R8(rQu6rAJV9Ng}Rh%nLRwa%-YS_o);nXvt{-{%3q&jN$h zPaJd{I5ZwblbFmNj(jl`|D zngtnr-JRA6J|vQSOf^wLDZUvF12O$Uq0wuAc2<2TQgOq>q;@J&Ul$2;yFTPTfRP+@ zh0$0SPgfSyOuuMd{Hq?>h-1pTCz27Mp;#zkXz__nXsn)#uu7xDmGb!Oc-MmzLd|Ox zMKhpYbwyJ+uF7ru&?qu(*Ns{2vmC)RKh_gH`aOC>f9%#ubfu-WFOuhycQCokMT4(v z*eAq8)CTT!t%{Ei&{{T*cY6ei2vFtio+O?Jy`r{)Q^1Z2cL8%c<43~=wrDq}v)6CR zk<~W9H5~yBl|N5d)Kzec@tZ5&A#>wOe^@?w5b>rojXs z`jDo+%BD7tZ+rkvE^?OVYCn}}o=&Pwl7R;iS38#Yoqd&`tq6xHPYoc+p{gZmgdaYyAh z3z?87Plr6<*7Ye`NJvlj45ipaZ7#}{0$=uJwGzVQxjHl*8physRgu78Q@BAOC4vNh zpv5`%6s6{je9H2GIM}!Skwy_yte?(>Lx)4{$HL?e8V7iDD?esNNa4*a62p#+oY=P- zePcAimuOF;FH05<~JS zwp}3>p3+8tbD*92gsyYcZFi_{i4|Yf+Z^9$XF3IT^sXtDhri}kR*cRRnU7-eTO&S! zbCi!5E?YA({PFcQ%VS?5So*AA{AKKHR?Mm{NUqM_4}5Vw`MgauX}aF`$SS&V%F0UZ zU$&zN$*DCpi?b=+Axn*?9u>dqrrDfd6tE!a>ZW{^%*~*9p=1*ufqFNQTR^Nl_6e1G zzFrZhL`){+1az{+iM+jJK@P}XZz}NN$p`+*azVLe~@d5iA z3A9xH62dM8oMV>@2G%HzcUN(HG*A!==~!xBwdatc$E2`>jOfuZ7PL8cs1EF7bm!a> z1MElKdgO8jVKwLxv4+a6E89C6r-)m7r@_@ReC6j})b%lg#Nves{AQ40?UDV73bUuo z9-BsN?>do-+2Qo1mYHDb@#une7UaHAy^egDXje|Hy6%zPD~^lnzDz99$ZHN>x;Hfqr!X(iO zqJjtb#rYRHBf#?I78F$1|1j>^o&PiG{zoou!I37QwB5OY&wpH(7c{T->-LFfyKnhQ zx@C+67oPq}Si7NZN&fRFmD9lK>g+6~f0Ivy5b!?eHM|trqwG6EJGcY1ZoU{^A&H4u zxHez4Q#GOz*&HSaHE9WalDL!FfM3kFQZHvDTv{5Ymlz#qGnQM$ z?*$UMnZ=p@ISZN%w!g3MdSWExVO1Aul9Xh+2OtDo;~@sh!NF5zIPYcA)6Bx2#>-tz zO${l(zXoN2>8{*(7(7`XQdOj89D_ko!;14isNo)#e}`GgK{pfK0Xx_)u9~XLA_4JH zaym0KW-%y+YvbZcms0KLC33)eJ#WKW+(Wj9A+D}zpfH|SnA#As%^J}BSlZ?pxV>V?IQxzrRF=EaP4o!t)Zzyi+~sLMzNP<}&L~&j$Z)ZK(7p*5b6E zDv#%HH2j=tpeY)h+p{rKe#2X+ImJOigO9=63X{_@fr6X3XY3VAi)w#vImwhwj)78W zH5FLcU;;Ak1;+9s?F!wYyeQpK8iFK}uDegd_F7|F(w>morxM{8J0DIHyH-yB0n6(I zga3^-s9dM z%Lf*hvCTPh{xKfqDB0OrGg&TvG&n!AH-BT}5XYkzG=U0u8WZCWnV|z>z`$WBwnK9& z$a!Q|@a=&z#|5HD{F-(C?K!vNm3>>toYJ_EK%79&MCgH^mw9p;z5Q6#%rDST+n1tJ z4FPS?*H$)33{~ZJ?V}2+vRvF#4;;uhA6izLYU@J8mO4)Mt$aZ?Kk~ ziBGbbabB7bvS}Ksy~JU^X(4i?$DnyvEW4ZqG9+ z2xbBu`?o&a8Vu36dHBx^b>u(LTWv#-TNeQr{Fxv7dU!C7ThHWsV+)MpM?>kZm?xCh zUKT9LDF!q(XTGXkIspRC)N)G7!YgOgn)Q!tpeBPO)wII32B;iv`JNry?WTl$Ncz4V zb`^Jh_Q{CUyklXF2HRcJ@Sn6R^_p$P1DLVD#*}5qgyD`Z9XpF`I)pbC0zOT_-;u|} zY{LYQ(K~e;;Kd{GHHAzS%a>3v*4mmavAnJ3B;%j8c$h1hQlPnQi!>;|oZ%=^r{|^R zx87yn%@x}Z;!q36bT9h+ZnAn^Bq44mB_PKpyO>^Q36)+wdOoUef|!qtc+$iZAaT8t zBOnI8Gs0gb8LsSWn;*L1$yw0Ar6Dnf6q@whue(P_G7>9TO zvcmqkg6Z_(obTpb=~iK!kW<-A?Hab1412WWYqy5`?fqY1G_UYRs|^9|SR{L$Sy?`j zyrmpNeJ3PamU3vYZ=xp`k*B`gx>|jj)q2g{x+!yMX_5VndC1UT*T6ki@}1%+qwwb3 z*5#T^RSMbAmjDo6f_GgO%1iRp-h#kl@Qq2KURR8p`uc(H}V0<_D;p4A2IojR!+W$sk!`KhW^L8?}4 zN%x1jh)HR*K>_U|K7B9ajFQeXDx~D1_^}#(e6XNr#1Uzl>7zB;QQpYKZwkl~4bk5& z!TXXG2XepC{kDq~9OCr4BEJ(pY9515q%`{ke67*sCSO%oSDxmc_@I!@b$PFyFVGOJ z_eog#8hHRB2hWYT-SE z?cL0X8p2qjzHiyQY)dQ<1(MFRJfBDBme=hLjb>G~@IHRbYG7w9Wk%O_Q_R zW%NuOk0)k7!i^6Uu`G2@n0gtlz2}(%R4DhA@eoU70Yn&c{8VSB@aUh|GdU5a6-}HK zD2f6}iy_yuqwe^^G0Y#dE!6K44zL2(QCN1GX5`_`DWOl?C@}OalIQw%5VolR%4fuq z*6u@WjdVpZp6_w`A&OJDnH zm^;UmZfrBt>n~-Q%HU7Cb?o@U!fO!mjRQq^6i5k!vZ zRA|`m8VS3^6=?d1`94@;K=Ywxg~*7*AvhiD!q3717REKZ8`IgBa8`k&2EplZtyxBS z^?iX81lU>*u15I_-&6U~go?Wor57Xzu7Nl59KzGs5V#8Y*@xOq}6Kd10D42%$#7Rq!G>;%fOl1%qQRV&)ib9ydJ+UgaV zsd`o0WF=p^P1_}zSMAjIrss-)s3r`o~1X>S~9GG?_;-&Y;d~pXB8G<4oVi& zPJXnD2>?D9KQDrdfD<+>Jj|-Iz%S12PV!*Y@@KDUt#CXsunc1>kv^J zTCDec5k`*{tEw3ueC%;}yxJCwhR!axZ{i_1?-R1Po&%sQ3twKNOf1E{Bl4a%o#AzpFgui$BH#n&gX8jfu6C?$RzJHiA5o?s9D{*70|fS+fi=sCCAY)5lbO zOgdSZ=By$F%nIJ3gFWxl^$;$v*~w$>bs!JsFplZ}9Rj99$cu%LhUD~_%&eV0vXfT< zf1Wf9MdKDZRFT?8=lBd<_n$+b`g6WK`T4Vn6w;7^>SM`g=li@aF>P&HpOEk2#@hD; z5oPmBlbfI1apzg@5*U*R&TJ-1<@t5YeI{;C`PKS8juhZiRV%1+(xidzV|nbKEq4@b zqX%kPpTE7Z&D8V~3O{N{M0Zr_X@9JIP&hAG=AcbT*mhl~h-)#dmQpzpKSGleZ$|q_ z^ly(;uGSWF_~>|lrx2$kWvHFusCnErG`3#+^cf=VGgaS-GiCxsI(fu*G#137Rga#` z@V)Dn*k$O$MJv5};%Cc8D13pau1Gk6Dz*r;-iaW2yTz(vUJ&ifFjoqXDCW7S^H2`o zKDC@xy(+JfNcG-$@+tSGZG*L6gp;tVQI;3lqMLvl>7}@X24>glnU9jNr%39;AleQx zYVpVjc0JPPi56G8a<_Rr{O@TfRs39JGU^FJqB&4hCoLc2C?v4v}mk?q% z`l*d+iZ7`hhh-S|(x_NYl<&VTN`%ZP70ueXI|>UL z&PteC&Gm}5$|^ehNjb7-i)9}JA`gs@|IGLQAl?}Mt_PAz+rkvvVHQ1vc92?-I^DbK zPX@F>EoxsU!W8I!F0pEfx^L(XnL^uf0E7t@^_E|0CvZt;I;@ z(ebF~qXFBcudBxP1HJJ^%oAR|>v0g6jb{F#0-imfs0`3fNb2MI)=Un_PKc*mjsa8t z7T1ht*b*3uvLGc0WN}V6oC)a?&&?`A8|3n*n0tmx-sY~*uh=)3cC3>$ikF|+(C?|_ zAnxZ%8)E2OBDAO|O5sKmN+~AXdT8<$opx-u&g5=k)qGsJquSGK^%`;xyGj91-aip1 z>w=C^?M^leefJMk)V8Z{AH@jb@V8BYEhcF)q=p2t$UA^fvCh0NPeg0bEiyv8+Tb(S z3e2dh{6RhUO0lS177*x)&wOwA?U*e!bNRg&JxJ&mB_Z2L8-&R+n-Nj2j-`nZLY=cy z0pY~EGcB11Zj{_*PxhD&-zHkqPXJ$erL`;WIS?gyKD*uscM~Jc(6{^{@QBw!`Xilx zu45_3RqZd(#-12g0eFZ51=|r*!al9E?SCTt;bpGc(9`>?k@%BD+9u}D4su8Fprh!I z0>e!HHQ&v7tgNGqCPL#AM>$Ou+ zXwDE3SXJFwAXy2e-_?$K`3+rQE>}F8^NCfbo&z_6Q$1ccakY&Zd=UR!6 z%%p~bvdGf!x%NnJ{mH5z6vdziuYg?r00`kuNOp#}xI&5lb79043?0NCG7;lWQ#b#|a6{A1R=7$h|A0 zbV=2y`)u6qusjkRv@iin$8o<%YHsHPi#GxVT2T7=rqW4h+(AGBg#0G|qq6Dh0$al) zSgl>)lz(Y9kEl$v2MV(|v8MeERqL(EB(`gRHLZnJ$|wi=yunNl_*n+Qz6KX%lV#^< z!JI}ev5lnFH@07!2M1dd=u(s@old%Vk&MEkhcAP{u3Z4=M8;1Sa@!bYe}%!Gh!4UI zx$lMl&Brtw86lb@5_I+`rbBOR@wsssbs;4FT`1}jNeBtr0!ir=%KZts+2>CgZ{ssI zx6v0Sswxi|$5@z}?UiijCyaRe_;X2PIg-`akYOkDhQal}IaZbZm-e9CJ@<;NH zYemrh1FP-}0ce*wBzERni0_Q_&g`IXR2j8nr1u^U-PFU146 z3!e%D`g}xJg7zqdYqler&4c0n1QAxh{jOgx`SSq2YDoeJ)5moxmykqGO#X+Q%4q$- ztrZJrYbqT&5kVhpNS8~4K8xu$H!ff6M%J#X-&}fjm7+WbtM-iXz02Je?3$j!Cdg#N zKp2bO_3li4Gu!KsTTl@|3iyPLkIB$dSpKbpky>?`uoc$K z#}!uBdaax7*=d-};sKsLC@SO2PKn> queryList = sqlResult.getResultInList(true, FormatUtils.DEFAULT_TIME_FORMAT, timePrecision); @@ -345,6 +345,15 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) { row -> { MultiwayTree.addTreeNodeFromString(tree, row.get(0)); }); + try { + logger.info("before merge, the size of forest is: ", tree.getRoot().getChildren().size()); + // ChunkMerge or RandomMerge + MultiwayTree.mergeTree(tree, "ChunkMerge"); + logger.info("after merge, the size of forest is: ", tree.getRoot().getChildren().size()); + } catch (Exception e) { + logger.error("merge tree error"); + e.printStackTrace(); + } try (InputStream inputStream = IginxInterpreter8.class.getClassLoader().getResourceAsStream("static/vis/network.html")) { BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); @@ -359,24 +368,31 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) { // 写入vis等库文件,只在新环境执行一次 String targetPath = outfileDir + "/graphs/lib/"; if (!FileUtil.isDirectoryLoaded(targetPath)) { + logger.info("buildNetworkForShowColumns: upload static resources start"); String sourcePath = "static/vis/lib/"; String jarUrl = Objects.requireNonNull(IginxInterpreter8.class.getClassLoader().getResource(sourcePath)) .toString(); String jarPath = jarUrl.substring(jarUrl.indexOf("file:") + 5, jarUrl.indexOf(".jar") + 4); FileUtil.extractDirectoryFromJar(jarPath, sourcePath, targetPath); + logger.info("buildNetworkForShowColumns: upload static resources finish"); } // 写入network html File networkHtml = new File(outfileDir + "/graphs/network.html"); + logger.info( + "buildNetworkForShowColumns: the absolute path of the output network.html is {}", + networkHtml.getAbsolutePath()); OutputStream outputStream = Files.newOutputStream(networkHtml.toPath()); outputStream.write(html.getBytes()); outputStream.close(); + logger.info("buildNetworkForShowColumns: html(string) write to network.html finish"); InputStream inputStreamMain = IginxInterpreter8.class.getClassLoader().getResourceAsStream("static/vis/main.html"); BufferedReader br = new BufferedReader(new InputStreamReader(inputStreamMain)); while ((line = br.readLine()) != null) { mainHtml.append(line).append("\n"); } + logger.info("buildNetworkForShowColumns: mainHtml is: {}", mainHtml); return mainHtml .toString() .replace("FILE_HOST", fileHttpHost) diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java new file mode 100644 index 0000000..6f30a27 --- /dev/null +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java @@ -0,0 +1,109 @@ +package org.apache.zeppelin.iginx.util; + +import java.io.BufferedReader; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EmbeddingUtils { + private static final Logger logger = LoggerFactory.getLogger(EmbeddingUtils.class); + private static final Map> embeddings = new HashMap<>(); + private static final int EMBEDDING_DIMENSION = 50; // 嵌入向量维度 + private static final Random RANDOM = new Random(42); // 固定随机种子 + + static { + try { + InputStream inputStream = + EmbeddingUtils.class.getClassLoader().getResourceAsStream("model/glove.6B.50d.txt"); + if (inputStream == null) { + throw new IllegalArgumentException("fail to find GloVe model"); + } + + try (BufferedReader br = new BufferedReader(new InputStreamReader(inputStream))) { + String line; + while ((line = br.readLine()) != null) { + String[] tokens = line.split(" "); + String word = tokens[0]; + List vector = new ArrayList<>(tokens.length - 1); + for (int i = 1; i < tokens.length; i++) { + vector.add(Double.parseDouble(tokens[i])); + } + embeddings.put(word, vector); + } + } + logger.info("load GloVe success"); + } catch (Exception e) { + throw new RuntimeException("load GloVe fail", e); + } + } + + public static List getEmbedding(String word) { + List embedding = embeddings.get(word); + + if (embedding == null) { + // 按照"-"," ","_"分割取embedding的平均值 + String[] parts = word.split("[-\\s_]+"); + if (parts.length > 1) { + List sumEmbedding = new ArrayList<>(); + double sum = 0; + for (String part : parts) { + List partEmbedding = embeddings.get(part); + if (partEmbedding != null) { + sumEmbedding.addAll(partEmbedding); + sum += 1; // 用于计算平均值 + } + } + // 如果至少找到一个部分的向量,计算平均向量 + for (int i = 0; i < EMBEDDING_DIMENSION; i++) { + sumEmbedding.set(i, sumEmbedding.get(i) / sum); + } + return sumEmbedding; + } else { + // 如果所有部分都未找到且不可拆分,生成随机向量 + embedding = generateRandomVector(EMBEDDING_DIMENSION); + System.out.println("单词 '" + word + "' 未找到,生成随机向量替代。"); + } + } + return embedding; + } + + private static List generateRandomVector(int dimension) { + List vector = new ArrayList<>(dimension); + for (int i = 0; i < dimension; i++) { + vector.add(RANDOM.nextDouble() * 2 - 1); // 随机值范围 [-1, 1] + } + return vector; + } + + public static double calculateSimilarity(List embedding1, List embedding2) { + if (embedding1 == null || embedding2 == null || embedding1.size() != embedding2.size()) { + throw new IllegalArgumentException( + "Embeddings must not be null and must have the same length"); + } + + double dotProduct = 0.0; + double normA = 0.0; + double normB = 0.0; + + for (int i = 0; i < embedding1.size(); i++) { + double valueA = embedding1.get(i); + double valueB = embedding2.get(i); + dotProduct += valueA * valueB; + normA += Math.pow(valueA, 2); + normB += Math.pow(valueB, 2); + } + + return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); + } + + public static void main(String[] args) { + // 测试:获取两个句子的嵌入向量并计算相似度 + List embedding1 = getEmbedding("weather"); + List embedding2 = getEmbedding("climate"); + + double similarity = calculateSimilarity(embedding1, embedding2); + System.out.println("相似度: " + similarity); + } +} diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java new file mode 100644 index 0000000..55aeb32 --- /dev/null +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java @@ -0,0 +1,140 @@ +package org.apache.zeppelin.iginx.util; + +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; +import java.util.List; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.util.EntityUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class LLMUtils { + private static final Logger logger = LoggerFactory.getLogger(LLMUtils.class); + private static final String API_KEY = "204a3ea9bf39f18dd9bf32c71ecbb607.mITgz6pgV7Hzj27A"; + private static final String API_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"; + + public static String getResponse(String prompt) { + try (CloseableHttpClient httpClient = HttpClients.createDefault()) { + // 构建请求体 + JSONObject requestBody = new JSONObject(); + requestBody.put("model", "glm-4"); // 添加模型参数 + + // 构建消息数组 + JSONArray messages = new JSONArray(); + JSONObject userMessage = new JSONObject(); + userMessage.put("role", "user"); + userMessage.put("content", prompt); + messages.add(userMessage); + requestBody.put("messages", messages); + + // 创建 HTTP POST 请求 + HttpPost postRequest = new HttpPost(API_URL); + postRequest.setHeader("Content-Type", "application/json"); + postRequest.setHeader("Authorization", "Bearer " + API_KEY); + + // 设置请求体 + StringEntity entity = new StringEntity(requestBody.toString(), "UTF-8"); + postRequest.setEntity(entity); + + // 发送请求并获取响应 + try (CloseableHttpResponse response = httpClient.execute(postRequest)) { + int statusCode = response.getStatusLine().getStatusCode(); + String responseBody = EntityUtils.toString(response.getEntity(), "UTF-8"); + + if (statusCode == 200) { + if (responseBody == null || responseBody.isEmpty()) { + return "Error: Empty response body"; + } + JSONObject jsonResponse = JSONObject.parseObject(responseBody); + if (!jsonResponse.containsKey("choices")) { + return "Error: Missing 'choices' in response"; + } + JSONArray choices = jsonResponse.getJSONArray("choices"); + if (choices.isEmpty()) { + return "Error: Empty 'choices' in response"; + } + String message = choices.getJSONObject(0).getJSONObject("message").getString("content"); + logger.info("get LLM response: {}", message); + return message; + } else { + return "Error: " + statusCode + " - " + responseBody; + } + } + } catch (Exception e) { + e.printStackTrace(); + return "Error occurred: " + e.getMessage(); + } + } + + /** + * 获取LLM回答中的标准的json字符串 + * + * @param json + * @return + */ + private static String getStandardJson(String json) { + + int start = json.indexOf('{'); + int end = json.lastIndexOf('}'); + String res = null; + if (start != -1 && end != -1 && start < end) { + res = json.substring(start, end + 1); + } + // res = res.replace("json", "").replaceAll("```", ""); + return res; + } + + public static String getConcept(List nodes) { + StringBuilder conceptName = new StringBuilder(); + for (TreeNode node : nodes) { + conceptName.append(node.getValue()).append(";"); + } + // String res = getResponse("你是一个概括大师,我将给你多个短语,中间用';'来分隔,请给概括这些短语,返回一个短语。" + + // "注意只需返回这个短语即可,短语使用大括号进行包裹。需要你合并概括的多和短语是:" + conceptName ); + String res = + getResponse( + "You are a summarization master. I will provide multiple words or phases separated by ';'. Please summarize them into a single word or a phase if possible. If merging is not feasible, return a hyphen-separated string like 'apple-desk'." + + "Return only the result wrapped in curly braces. Words: " + + conceptName); + res = getStandardConcept(res); + if (res == null) { + logger.info("getConcept: the form of LLM response is wrong"); + return getConcept(nodes); + } + return res; + } + + public static String getRelation(String name, String str) { + // String res = + // getResponse("你是一个概念大师,我将给你一个待匹配短语和若干个目标短语,提供的若干个目标短语中间用';'来分隔,请分析在概念上和待匹配短语具有非常非常强烈关联的目标短语,如有则返回该目标短语。" + + // "注意只需返回这个短语即可,短语使用大括号进行包裹,如果没有一个目标短语符合要求则直接返回一个大括号。\n待匹配短语为:" + name + + // "\n若干目标短语为:" + str); + String res = + getResponse( + "You are a concept master. I will provide you with a phrase to match and several target phrases, separated by ';'. Please analyze and identify the target phrase that is conceptually very, very strongly related to the phrase to match. If such a phrase exists, return that target phrase." + + "Note that you should only return the phrase, wrapped in curly braces. If no target phrase meets the requirement, return an empty pair of curly braces.\nPhrase to match: " + + name + + "\nTarget phrases: " + + str); + + res = getStandardConcept(res); + return res; + } + + private static String getStandardConcept(String text) { + + int start = text.indexOf('{'); + int end = text.lastIndexOf('}'); + String res = null; + if (start != -1 && end != -1 && start < end) { + res = text.substring(start + 1, end); + res = res.replaceAll("\"", ""); + return res; + } + return null; + } +} diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java index a4afffa..5006dd7 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java @@ -1,6 +1,9 @@ package org.apache.zeppelin.iginx.util; +import java.lang.reflect.Constructor; +import java.util.List; import org.apache.commons.lang3.StringUtils; +import org.apache.zeppelin.iginx.util.algorithm.mergeforest.MergeForestStrategy; public class MultiwayTree { public static final String ROOT_NODE_NAME = "数据资产"; @@ -51,7 +54,7 @@ public void traversePreorder(TreeNode node) { public static MultiwayTree getMultiwayTree() { MultiwayTree tree = new MultiwayTree(); - tree.root = new TreeNode(ROOT_NODE_PATH, ROOT_NODE_NAME); // 初始化 + tree.root = new TreeNode(ROOT_NODE_PATH, ROOT_NODE_NAME, null); // 初始化 return tree; } @@ -59,9 +62,19 @@ public static void addTreeNodeFromString(MultiwayTree tree, String nodeString) { String[] nodes = nodeString.split("\\."); TreeNode newNode = tree.root; for (int i = 0; i < nodes.length; i++) { + List embedding = EmbeddingUtils.getEmbedding(nodes[i]); newNode = tree.insert( - newNode, new TreeNode(StringUtils.join(newNode.path, ".", nodes[i]), nodes[i])); + newNode, + new TreeNode(StringUtils.join(newNode.path, ".", nodes[i]), nodes[i], embedding)); } } + + public static void mergeTree(MultiwayTree tree, String strategy) throws Exception { + Class mergeClass = + Class.forName("org.apache.zeppelin.iginx.util.algorithm.mergeforest." + strategy); + Constructor constructor = mergeClass.getConstructor(); + MergeForestStrategy mergeForestStrategy = (MergeForestStrategy) constructor.newInstance(); + mergeForestStrategy.mergeForest(tree.root); + } } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/TreeNode.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/TreeNode.java index 630906e..2cab1e8 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/TreeNode.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/TreeNode.java @@ -7,19 +7,13 @@ public class TreeNode { String path; String value; List children; - List columns; + List embedding; - public TreeNode(String path, String value, List columns) { - this.path = path; - this.value = value; - this.children = new ArrayList<>(); - this.columns = columns; - } - - public TreeNode(String path, String value) { + public TreeNode(String path, String value, List embedding) { this.path = path; this.value = value; this.children = new ArrayList<>(); + this.embedding = embedding; } public String getValue() { @@ -38,11 +32,19 @@ public void setChildren(List children) { this.children = children; } - public List getColumns() { - return columns; + public String getPath() { + return path; + } + + public void setPath(String path) { + this.path = path; + } + + public List getEmbedding() { + return embedding; } - public void setColumns(List columns) { - this.columns = columns; + public void setEmbedding(List embedding) { + this.embedding = embedding; } } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java new file mode 100644 index 0000000..8c9a787 --- /dev/null +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java @@ -0,0 +1,103 @@ +package org.apache.zeppelin.iginx.util.algorithm.mergeforest; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import org.apache.zeppelin.iginx.util.EmbeddingUtils; +import org.apache.zeppelin.iginx.util.LLMUtils; +import org.apache.zeppelin.iginx.util.TreeNode; + +public class ChunkMerge implements MergeForestStrategy { + private static final int THREAD_POOL_SIZE = 8; // 可根据机器配置调整线程数 + + public void mergeForest(TreeNode root) throws ExecutionException, InterruptedException { + List forest = root.getChildren(); + int targetTreeCount = (int) Math.ceil(Math.sqrt(forest.size())); + System.out.println("targetTreeCount: " + targetTreeCount); + ExecutorService executor = Executors.newFixedThreadPool(THREAD_POOL_SIZE); + + while (forest.size() > targetTreeCount) { + System.out.println("开始新一轮合并,当前森林大小: " + forest.size()); + List>> futureResults = new ArrayList<>(); + + int chunkSize = Math.max(2, (int) Math.ceil((double) forest.size() / THREAD_POOL_SIZE)); + System.out.println("chunkSize: " + chunkSize); + for (int i = 0; i < forest.size(); i += chunkSize) { + List chunk = + new ArrayList<>(forest.subList(i, Math.min(i + chunkSize, forest.size()))); + futureResults.add(executor.submit(() -> mergeChunk(chunk))); + } + + List mergedForest = new ArrayList<>(); + for (Future> future : futureResults) { + mergedForest.addAll(future.get()); + } + forest.clear(); + forest.addAll(mergedForest); + + if (forest.size() <= targetTreeCount) { + break; + } + } + executor.shutdown(); + + root.setChildren(forest); + } + + private List mergeChunk(List chunk) { + System.out.println("mergeChunk"); + if (chunk.size() == 1) return chunk; + List mergedChunk = new ArrayList<>(chunk); + + List nodesToMerge = findBestMergeGroup(mergedChunk); + + if (!nodesToMerge.isEmpty()) { + String newConceptName = LLMUtils.getConcept(nodesToMerge); + System.out.println("newConceptName: " + newConceptName); + + List newEmbedding = EmbeddingUtils.getEmbedding(newConceptName); + TreeNode newParent = new TreeNode(newConceptName, newConceptName, newEmbedding); + + for (TreeNode node : nodesToMerge) { + newParent.getChildren().add(node); + changePath(node, newConceptName); + } + + mergedChunk.removeAll(nodesToMerge); + mergedChunk.add(newParent); + } + + return mergedChunk; + } + + private List findBestMergeGroup(List nodes) { + System.out.println("findBestMergeGroup"); + List bestGroup = new ArrayList<>(); + double bestSimilarity = -1; + + for (int i = 0; i < nodes.size(); i++) { + for (int j = i + 1; j < nodes.size(); j++) { + double similarity = + EmbeddingUtils.calculateSimilarity( + nodes.get(i).getEmbedding(), nodes.get(j).getEmbedding()); + if (similarity > bestSimilarity) { + bestSimilarity = similarity; + bestGroup = Arrays.asList(nodes.get(i), nodes.get(j)); // 使用 Arrays.asList 替代 List.of + } + } + } + + return bestGroup; + } + + private void changePath(TreeNode node, String prePath) { + node.setPath(prePath + "." + node.getPath()); + for (TreeNode child : node.getChildren()) { + changePath(child, prePath); + } + } +} diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java new file mode 100644 index 0000000..fd9e5e2 --- /dev/null +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java @@ -0,0 +1,7 @@ +package org.apache.zeppelin.iginx.util.algorithm.mergeforest; + +import org.apache.zeppelin.iginx.util.TreeNode; + +public interface MergeForestStrategy { + void mergeForest(TreeNode root) throws Exception; +} diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java new file mode 100644 index 0000000..5a01cd6 --- /dev/null +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java @@ -0,0 +1,101 @@ +package org.apache.zeppelin.iginx.util.algorithm.mergeforest; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import org.apache.zeppelin.iginx.util.EmbeddingUtils; +import org.apache.zeppelin.iginx.util.LLMUtils; +import org.apache.zeppelin.iginx.util.TreeNode; + +public class RandomMerge implements MergeForestStrategy { + private static final Double SIMILARITY_THRESHOLD = 0.7; + private static final Integer MAX_CONTINUOUS_FAILURE_COUNT = 5; + + @Override + public void mergeForest(TreeNode root) { + List forest = root.getChildren(); + int targetTreeCount = (int) Math.ceil(Math.sqrt(forest.size())); + Random random = new Random(); + int failureCount = 0; // 记录连续合并失败次数 + + while (forest.size() > targetTreeCount && failureCount < MAX_CONTINUOUS_FAILURE_COUNT) { + System.out.println("当前森林的大小为: " + forest.size()); + + // 随机选择一个根节点 + TreeNode referenceNode = forest.get(random.nextInt(forest.size())); + List selectedNodes = new ArrayList<>(); + selectedNodes.add(referenceNode); + + // 计算其他根节点与 referenceNode 的相似性并排序 + List otherNodes = new ArrayList<>(forest); + otherNodes.remove(referenceNode); + + otherNodes.sort( + (n1, n2) -> { + try { + double similarity1 = + EmbeddingUtils.calculateSimilarity( + referenceNode.getEmbedding(), n1.getEmbedding()); + double similarity2 = + EmbeddingUtils.calculateSimilarity( + referenceNode.getEmbedding(), n2.getEmbedding()); + return Double.compare(similarity2, similarity1); // 按相似性降序排列 + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + + // 根据相似性逐步添加节点并计算累计相似性 + double cumulativeSimilarity = 1.0; + for (TreeNode node : otherNodes) { + double similarity = + EmbeddingUtils.calculateSimilarity(referenceNode.getEmbedding(), node.getEmbedding()); + cumulativeSimilarity *= similarity; + System.out.println("当前累计相似性: " + cumulativeSimilarity); + + if (cumulativeSimilarity > SIMILARITY_THRESHOLD) { + selectedNodes.add(node); + } else { + break; // 达到阈值,停止继续合并 + } + } + + // 合并满足条件的节点 + if (selectedNodes.size() > 1) { + String newConceptName = LLMUtils.getConcept(selectedNodes); + List newEmbedding = EmbeddingUtils.getEmbedding(newConceptName); + TreeNode newParent = new TreeNode(newConceptName, newConceptName, newEmbedding); + + for (TreeNode node : selectedNodes) { + newParent.getChildren().add(node); + changePath(node, newConceptName); + } + + // 从森林中移除已合并的节点,并加入新合并的节点 + forest.removeAll(selectedNodes); + forest.add(newParent); + + // 重置连续失败计数器 + failureCount = 0; + System.out.println("合并成功,新森林大小为: " + forest.size()); + } else { + // 增加失败计数 + failureCount++; + System.out.println("合并失败,连续失败次数: " + failureCount); + } + } + + if (failureCount >= MAX_CONTINUOUS_FAILURE_COUNT) { + System.out.println("连续合并失败次数已达到上限,停止合并"); + } + + root.setChildren(forest); + } + + private void changePath(TreeNode node, String prePath) { + node.setPath(prePath + "." + node.getPath()); + for (TreeNode child : node.getChildren()) { + changePath(child, prePath); + } + } +} From 539a1700f6964efd3a4b72d173df4a5988018bb9 Mon Sep 17 00:00:00 2001 From: ych <1821947036@qq.com> Date: Sun, 24 Nov 2024 11:22:43 +0800 Subject: [PATCH 2/3] add comment and optimize algorithm for merge forest --- README.md | 1 + .../zeppelin/iginx/IginxInterpreter8.java | 25 +++++------- .../zeppelin/iginx/util/EmbeddingUtils.java | 18 +++++++-- .../apache/zeppelin/iginx/util/LLMUtils.java | 40 ++++++++++--------- .../zeppelin/iginx/util/MultiwayTree.java | 24 ++++++++--- .../algorithm/mergeforest/ChunkMerge.java | 19 ++++++--- .../mergeforest/MergeForestStrategy.java | 6 +++ .../algorithm/mergeforest/RandomMerge.java | 29 +++++++++----- 8 files changed, 104 insertions(+), 58 deletions(-) diff --git a/README.md b/README.md index 423cf41..f330640 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,7 @@ IGinX Zeppelin 解释器是需要连接 IGinX 的,如果我们重启了 IGinX ##### SHOW COLUMNS 命令 展示可视化资产图会涉及到embedding,请前往 https://nlp.stanford.edu/projects/glove/ 下载相关模型,解压后放入 /resources/model 文件夹下,并根据模型的embedding维度修改embeddingUtils类中的EMBEDDING_DIMENSION参数。 ![img](./images/show_columns.png) +此外,当前选用的模型只支持英文,暂不支持中文。 ### 使用RESTful语句 diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java b/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java index b341614..25160e7 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java @@ -37,12 +37,8 @@ import org.apache.zeppelin.iginx.util.MultiwayTree; import org.apache.zeppelin.iginx.util.SqlCmdUtil; import org.apache.zeppelin.interpreter.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; public class IginxInterpreter8 extends Interpreter { - private static final Logger LOGGER = LoggerFactory.getLogger(IginxInterpreter8.class); - private static final String IGINX_HOST = "iginx.host"; private static final String IGINX_PORT = "iginx.port"; private static final String IGINX_USERNAME = "iginx.username"; @@ -392,13 +388,12 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) thro while ((line = br.readLine()) != null) { mainHtml.append(line).append("\n"); } - logger.info("buildNetworkForShowColumns: mainHtml is: {}", mainHtml); return mainHtml .toString() .replace("FILE_HOST", fileHttpHost) .replace("FILE_PORT", String.valueOf(fileHttpPort)); } catch (IOException e) { - LOGGER.warn("load show columns to tree error", e); + logger.warn("load show columns to tree error", e); } return ""; } @@ -424,7 +419,7 @@ private InterpreterResult processLoadCsv(String sql, InterpreterContext context) InterpreterResult interpreterResult; String uploadParagraphKey = context.getParagraphId() + "_UPLOAD_FILE"; /* response upload file form, user will rerun paragraph when upload finished. */ - LOGGER.info("+++++++Id={}, paragraphId={}", context.getNoteId(), uploadParagraphKey); + logger.info("+++++++Id={}, paragraphId={}", context.getNoteId(), uploadParagraphKey); if (!uploadParagraphSet.contains(uploadParagraphKey)) { try (InputStream inputStream = IginxInterpreter8.class.getClassLoader().getResourceAsStream("uploadForm.html"); @@ -445,13 +440,13 @@ private InterpreterResult processLoadCsv(String sql, InterpreterContext context) uploadParagraphSet.add(uploadParagraphKey); return interpreterResult; } catch (IOException e) { - LOGGER.error("load html error", e); + logger.error("load html error", e); } return new InterpreterResult(InterpreterResult.Code.ERROR); } try { - LOGGER.info("load data sql execute, sql={}", sql); + logger.info("load data sql execute, sql={}", sql); SessionExecuteSqlResult res = session.executeSql(sql); String parseErrorMsg = res.getParseErrorMsg(); if (parseErrorMsg != null && !parseErrorMsg.isEmpty()) { @@ -510,11 +505,11 @@ private String convertPath(String inputPath, String sql) throws IOException { String path; Path pathObj; if (SystemUtils.IS_OS_WINDOWS) { - LOGGER.info("current os is Windows"); + logger.info("current os is Windows"); path = inputPath.replace("/", "\\"); pathObj = Paths.get(path); } else { - LOGGER.info("current os is Linux or Mac"); + logger.info("current os is Linux or Mac"); path = inputPath.replace("\\", "/"); pathObj = Paths.get(path); } @@ -523,7 +518,7 @@ private String convertPath(String inputPath, String sql) throws IOException { HttpUtil.getCurrentPath(DEFAULT_UPLOAD_DIR) + File.separator + pathObj.getFileName().toString(); - LOGGER.info("converted path is {}", path); + logger.info("converted path is {}", path); return path; } @@ -1057,7 +1052,7 @@ public InterpreterResult tuneFontSize( } else { hTagNumber = 6; } - LOGGER.info( + logger.info( "NoteId={},ParagraphId={},fontSizeEnable={},fontSize={}", context.getNoteId(), context.getParagraphId(), @@ -1069,7 +1064,7 @@ public InterpreterResult tuneFontSize( message.stream() .map( item -> { - LOGGER.debug("type={},data={}", item.getType(), item.getData()); + logger.debug("type={},data={}", item.getType(), item.getData()); if (item.getType().equals(InterpreterResult.Type.TABLE)) { String collect = Arrays.stream(item.getData().split(NEWLINE)) @@ -1098,7 +1093,7 @@ public InterpreterResult tuneFontSize( item.getType(), String.format("%s", hTagNumber, item.getData(), hTagNumber)); } else { - LOGGER.warn("unexpected result type {}", item.getType()); + logger.warn("unexpected result type {}", item.getType()); } return new InterpreterResultMessage(item.getType(), item.getData()); }) diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java index 6f30a27..cfa9ba5 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java @@ -39,6 +39,15 @@ public class EmbeddingUtils { } } + /** + * 获取输入内容的 embedding + * 1. 直接通过 Map 查找,若找到,直接返回结果 + * 2. 若找不到,则对输入内容按照 "-", " ", "_" 进行划分,再分别获取 embeddig 并取平均 + * 3. 若输入的内容找不到且已无法再划分,则随机一个 embedding (固定了随机种子,为了让每次执行的结果一致) + * + * @param word + * @return + */ public static List getEmbedding(String word) { List embedding = embeddings.get(word); @@ -51,11 +60,12 @@ public static List getEmbedding(String word) { for (String part : parts) { List partEmbedding = embeddings.get(part); if (partEmbedding != null) { - sumEmbedding.addAll(partEmbedding); - sum += 1; // 用于计算平均值 + for (int i = 0; i < EMBEDDING_DIMENSION; i++) { + sumEmbedding.set(i, sumEmbedding.get(i) + partEmbedding.get(i)); // 按维度累加 + } + sum += 1; } } - // 如果至少找到一个部分的向量,计算平均向量 for (int i = 0; i < EMBEDDING_DIMENSION; i++) { sumEmbedding.set(i, sumEmbedding.get(i) / sum); } @@ -63,7 +73,7 @@ public static List getEmbedding(String word) { } else { // 如果所有部分都未找到且不可拆分,生成随机向量 embedding = generateRandomVector(EMBEDDING_DIMENSION); - System.out.println("单词 '" + word + "' 未找到,生成随机向量替代。"); + logger.info("fail to find the embedding of '" + word + "', generating Random Vector to replace"); } } return embedding; diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java index 55aeb32..ec57a57 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java @@ -17,6 +17,13 @@ public class LLMUtils { private static final String API_KEY = "204a3ea9bf39f18dd9bf32c71ecbb607.mITgz6pgV7Hzj27A"; private static final String API_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"; + /** + * 输入 prompt ,LLM 返回响应结果 + * 目前使用的是智谱 GLM-4 模型 + * + * @param prompt + * @return + */ public static String getResponse(String prompt) { try (CloseableHttpClient httpClient = HttpClients.createDefault()) { // 构建请求体 @@ -66,41 +73,30 @@ public static String getResponse(String prompt) { } } catch (Exception e) { e.printStackTrace(); + logger.info("get LLM response error"); return "Error occurred: " + e.getMessage(); } } /** - * 获取LLM回答中的标准的json字符串 + * 获取概括的单词,由于有些情况确实不适合合并,允许返回 wordA-wordB 的样式 + * 用于 【顶层合并】 * - * @param json + * @param nodes * @return */ - private static String getStandardJson(String json) { - - int start = json.indexOf('{'); - int end = json.lastIndexOf('}'); - String res = null; - if (start != -1 && end != -1 && start < end) { - res = json.substring(start, end + 1); - } - // res = res.replace("json", "").replaceAll("```", ""); - return res; - } - public static String getConcept(List nodes) { StringBuilder conceptName = new StringBuilder(); for (TreeNode node : nodes) { conceptName.append(node.getValue()).append(";"); } - // String res = getResponse("你是一个概括大师,我将给你多个短语,中间用';'来分隔,请给概括这些短语,返回一个短语。" + - // "注意只需返回这个短语即可,短语使用大括号进行包裹。需要你合并概括的多和短语是:" + conceptName ); + //你是一个概括大师,我将给你多个短语,中间用';'来分隔,请给概括这些短语,返回一个短语。注意只需返回这个短语即可,短语使用大括号进行包裹。需要你合并概括的多和短语是:" + conceptName String res = getResponse( "You are a summarization master. I will provide multiple words or phases separated by ';'. Please summarize them into a single word or a phase if possible. If merging is not feasible, return a hyphen-separated string like 'apple-desk'." + "Return only the result wrapped in curly braces. Words: " + conceptName); - res = getStandardConcept(res); + res = getStandardResponse(res); if (res == null) { logger.info("getConcept: the form of LLM response is wrong"); return getConcept(nodes); @@ -121,11 +117,17 @@ public static String getRelation(String name, String str) { + "\nTarget phrases: " + str); - res = getStandardConcept(res); + res = getStandardResponse(res); return res; } - private static String getStandardConcept(String text) { + /** + * 将返回要求包裹的的 {} 除去,获取真正有用的内容 + * + * @param text + * @return + */ + private static String getStandardResponse(String text) { int start = text.indexOf('{'); int end = text.lastIndexOf('}'); diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java index 5006dd7..d9580c1 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java @@ -1,11 +1,15 @@ package org.apache.zeppelin.iginx.util; import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; import java.util.List; import org.apache.commons.lang3.StringUtils; import org.apache.zeppelin.iginx.util.algorithm.mergeforest.MergeForestStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class MultiwayTree { + private static final Logger logger = LoggerFactory.getLogger(MultiwayTree.class); public static final String ROOT_NODE_NAME = "数据资产"; public static final String ROOT_NODE_PATH = "root"; @@ -70,11 +74,19 @@ public static void addTreeNodeFromString(MultiwayTree tree, String nodeString) { } } - public static void mergeTree(MultiwayTree tree, String strategy) throws Exception { - Class mergeClass = - Class.forName("org.apache.zeppelin.iginx.util.algorithm.mergeforest." + strategy); - Constructor constructor = mergeClass.getConstructor(); - MergeForestStrategy mergeForestStrategy = (MergeForestStrategy) constructor.newInstance(); - mergeForestStrategy.mergeForest(tree.root); + public static void mergeTree(MultiwayTree tree, String strategy) { + try { + Class mergeClass = Class.forName("org.apache.zeppelin.iginx.util.algorithm.mergeforest." + strategy); + Constructor constructor = mergeClass.getConstructor(); + MergeForestStrategy mergeForestStrategy = (MergeForestStrategy) constructor.newInstance(); + logger.info("begin --" + strategy + "--"); + mergeForestStrategy.mergeForest(tree.root); + } catch (ClassNotFoundException e) { + logger.error("error: Could not find class '{}'. Please check if the class name is correct.", strategy, e); + } catch (NoSuchMethodException e) { + logger.error("error: Could not find a default constructor in class '{}'. Please ensure it has a no-args constructor.", strategy, e); + } catch (Exception e) { + logger.error("error: Reflection failed. Could not create an instance of class '{}'.", strategy, e); + } } } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java index 8c9a787..e8223d0 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java @@ -14,6 +14,17 @@ public class ChunkMerge implements MergeForestStrategy { private static final int THREAD_POOL_SIZE = 8; // 可根据机器配置调整线程数 + /** + * 分块合并算法 + * 1. 循环多轮进行,当某轮结束后的森林大小满足要求,则退出 + * 2. 每轮按照线程数平均分块,每块中的树分别进行合并,块间互不干扰 + * 3. 通过 findBestMergeGroup 查找最合适的合并组:块中的所有树通过循环,两两比较,找到 embedding 最相似的两个合并 + * (第3步的 findBestMergeGroup 有较大的优化空间,因为有时候多棵树一起合并可能效果更好) + * + * @param root + * @throws ExecutionException + * @throws InterruptedException + */ public void mergeForest(TreeNode root) throws ExecutionException, InterruptedException { List forest = root.getChildren(); int targetTreeCount = (int) Math.ceil(Math.sqrt(forest.size())); @@ -21,7 +32,7 @@ public void mergeForest(TreeNode root) throws ExecutionException, InterruptedExc ExecutorService executor = Executors.newFixedThreadPool(THREAD_POOL_SIZE); while (forest.size() > targetTreeCount) { - System.out.println("开始新一轮合并,当前森林大小: " + forest.size()); + System.out.println("beginning new merge, the recent size of forest is: " + forest.size()); List>> futureResults = new ArrayList<>(); int chunkSize = Math.max(2, (int) Math.ceil((double) forest.size() / THREAD_POOL_SIZE)); @@ -49,7 +60,6 @@ public void mergeForest(TreeNode root) throws ExecutionException, InterruptedExc } private List mergeChunk(List chunk) { - System.out.println("mergeChunk"); if (chunk.size() == 1) return chunk; List mergedChunk = new ArrayList<>(chunk); @@ -57,7 +67,7 @@ private List mergeChunk(List chunk) { if (!nodesToMerge.isEmpty()) { String newConceptName = LLMUtils.getConcept(nodesToMerge); - System.out.println("newConceptName: " + newConceptName); + System.out.println("get new concept name: " + newConceptName); List newEmbedding = EmbeddingUtils.getEmbedding(newConceptName); TreeNode newParent = new TreeNode(newConceptName, newConceptName, newEmbedding); @@ -75,7 +85,6 @@ private List mergeChunk(List chunk) { } private List findBestMergeGroup(List nodes) { - System.out.println("findBestMergeGroup"); List bestGroup = new ArrayList<>(); double bestSimilarity = -1; @@ -86,7 +95,7 @@ private List findBestMergeGroup(List nodes) { nodes.get(i).getEmbedding(), nodes.get(j).getEmbedding()); if (similarity > bestSimilarity) { bestSimilarity = similarity; - bestGroup = Arrays.asList(nodes.get(i), nodes.get(j)); // 使用 Arrays.asList 替代 List.of + bestGroup = Arrays.asList(nodes.get(i), nodes.get(j)); } } } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java index fd9e5e2..114e7a9 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/MergeForestStrategy.java @@ -3,5 +3,11 @@ import org.apache.zeppelin.iginx.util.TreeNode; public interface MergeForestStrategy { + /** + * 合并森林(root下的 children)算法,若初始有 n 棵树,则最终合并成 sqrt(n) 棵树 + * + * @param root + * @throws Exception + */ void mergeForest(TreeNode root) throws Exception; } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java index 5a01cd6..2961c77 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java @@ -11,6 +11,15 @@ public class RandomMerge implements MergeForestStrategy { private static final Double SIMILARITY_THRESHOLD = 0.7; private static final Integer MAX_CONTINUOUS_FAILURE_COUNT = 5; + /** + * 随机合并算法 + * 1. 循环多轮进行,当某轮结束后的森林大小满足要求或者已经连续 MAX_CONTINUOUS_FAILURE_COUNT 轮合并失败,则退出 + * 2. 每轮随机选择森林中的一棵树,计算其他所有树与它的 embedding 相似度,并从高到底排序 + * 3. 按照顺序对相似度进行累乘:假设相似度是 x,每次乘实际上是乘 1-(1-x)/2 ,乘积如果不小于 SIMILARITY_THRESHOLD 则继续乘,反之则结束 + * 4. 如果累计相似度满足要求的树不少于2棵,则进行合并,反之则计为一次合并失败 + * + * @param root + */ @Override public void mergeForest(TreeNode root) { List forest = root.getChildren(); @@ -19,7 +28,7 @@ public void mergeForest(TreeNode root) { int failureCount = 0; // 记录连续合并失败次数 while (forest.size() > targetTreeCount && failureCount < MAX_CONTINUOUS_FAILURE_COUNT) { - System.out.println("当前森林的大小为: " + forest.size()); + System.out.println("beginning new merge, the recent size of forest is: " + forest.size()); // 随机选择一个根节点 TreeNode referenceNode = forest.get(random.nextInt(forest.size())); @@ -50,10 +59,10 @@ public void mergeForest(TreeNode root) { for (TreeNode node : otherNodes) { double similarity = EmbeddingUtils.calculateSimilarity(referenceNode.getEmbedding(), node.getEmbedding()); - cumulativeSimilarity *= similarity; - System.out.println("当前累计相似性: " + cumulativeSimilarity); + cumulativeSimilarity *= (0.5 + similarity / 2); + System.out.println("now cumulative similarity: " + cumulativeSimilarity); - if (cumulativeSimilarity > SIMILARITY_THRESHOLD) { + if (cumulativeSimilarity >= SIMILARITY_THRESHOLD) { selectedNodes.add(node); } else { break; // 达到阈值,停止继续合并 @@ -77,16 +86,18 @@ public void mergeForest(TreeNode root) { // 重置连续失败计数器 failureCount = 0; - System.out.println("合并成功,新森林大小为: " + forest.size()); + System.out.println("merge success, now forest size is: " + forest.size()); } else { // 增加失败计数 failureCount++; - System.out.println("合并失败,连续失败次数: " + failureCount); + System.out.println("merge failure, now consecutive failure count is: " + failureCount); } - } - if (failureCount >= MAX_CONTINUOUS_FAILURE_COUNT) { - System.out.println("连续合并失败次数已达到上限,停止合并"); + // 判断是否已经到达连续合并失败的最大上限 + if (failureCount >= MAX_CONTINUOUS_FAILURE_COUNT) { + System.out.println("The count of consecutive merge failures has reached the limit, stop merge"); + break; + } } root.setChildren(forest); From e112c897977ce2680a43082bf4284a300321d748 Mon Sep 17 00:00:00 2001 From: ych <1821947036@qq.com> Date: Sun, 24 Nov 2024 13:23:10 +0800 Subject: [PATCH 3/3] fix an embedding bug and add some logs --- .../zeppelin/iginx/IginxInterpreter8.java | 2 +- .../zeppelin/iginx/util/EmbeddingUtils.java | 20 +++++++++++------- .../apache/zeppelin/iginx/util/LLMUtils.java | 9 ++++---- .../zeppelin/iginx/util/MultiwayTree.java | 17 ++++++++++----- .../algorithm/mergeforest/ChunkMerge.java | 17 ++++++++------- .../algorithm/mergeforest/RandomMerge.java | 21 ++++++++++--------- 6 files changed, 49 insertions(+), 37 deletions(-) diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java b/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java index 25160e7..f78eaa8 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java @@ -330,7 +330,7 @@ private InterpreterResult processSql(String sql, InterpreterContext context) { * * @param sqlResult */ - public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) throws Exception { + public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) { StringBuilder mainHtml = new StringBuilder(); List> queryList = sqlResult.getResultInList(true, FormatUtils.DEFAULT_TIME_FORMAT, timePrecision); diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java index cfa9ba5..c44b977 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java @@ -40,25 +40,28 @@ public class EmbeddingUtils { } /** - * 获取输入内容的 embedding - * 1. 直接通过 Map 查找,若找到,直接返回结果 - * 2. 若找不到,则对输入内容按照 "-", " ", "_" 进行划分,再分别获取 embeddig 并取平均 - * 3. 若输入的内容找不到且已无法再划分,则随机一个 embedding (固定了随机种子,为了让每次执行的结果一致) + * 获取输入内容的 embedding 1. 直接通过 Map 查找,若找到,直接返回结果 2. 若找不到,则对输入内容按照 "-", " ", "_" 进行划分,再分别获取 embeddig + * 并取平均 3. 若输入的内容找不到且已无法再划分,则随机一个 embedding (固定了随机种子,为了让每次执行的结果一致) * * @param word * @return */ public static List getEmbedding(String word) { + System.out.println("getEmbedding: input word is: " + word); + logger.info("getEmbedding: input word is: {}", word); List embedding = embeddings.get(word); if (embedding == null) { // 按照"-"," ","_"分割取embedding的平均值 String[] parts = word.split("[-\\s_]+"); if (parts.length > 1) { - List sumEmbedding = new ArrayList<>(); + List sumEmbedding = new ArrayList<>(EMBEDDING_DIMENSION); + for (int i = 0; i < EMBEDDING_DIMENSION; i++) { + sumEmbedding.add(0.0); // 初始化为零向量 + } double sum = 0; for (String part : parts) { - List partEmbedding = embeddings.get(part); + List partEmbedding = getEmbedding(part); if (partEmbedding != null) { for (int i = 0; i < EMBEDDING_DIMENSION; i++) { sumEmbedding.set(i, sumEmbedding.get(i) + partEmbedding.get(i)); // 按维度累加 @@ -73,7 +76,8 @@ public static List getEmbedding(String word) { } else { // 如果所有部分都未找到且不可拆分,生成随机向量 embedding = generateRandomVector(EMBEDDING_DIMENSION); - logger.info("fail to find the embedding of '" + word + "', generating Random Vector to replace"); + logger.info( + "fail to find the embedding of '" + word + "', generating Random Vector to replace"); } } return embedding; @@ -110,7 +114,7 @@ public static double calculateSimilarity(List embedding1, List e public static void main(String[] args) { // 测试:获取两个句子的嵌入向量并计算相似度 - List embedding1 = getEmbedding("weather"); + List embedding1 = getEmbedding("summer_2024"); List embedding2 = getEmbedding("climate"); double similarity = calculateSimilarity(embedding1, embedding2); diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java index ec57a57..df06fef 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/LLMUtils.java @@ -18,8 +18,7 @@ public class LLMUtils { private static final String API_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"; /** - * 输入 prompt ,LLM 返回响应结果 - * 目前使用的是智谱 GLM-4 模型 + * 输入 prompt ,LLM 返回响应结果 目前使用的是智谱 GLM-4 模型 * * @param prompt * @return @@ -79,8 +78,7 @@ public static String getResponse(String prompt) { } /** - * 获取概括的单词,由于有些情况确实不适合合并,允许返回 wordA-wordB 的样式 - * 用于 【顶层合并】 + * 获取概括的单词,由于有些情况确实不适合合并,允许返回 wordA-wordB 的样式 用于 【顶层合并】 * * @param nodes * @return @@ -90,7 +88,8 @@ public static String getConcept(List nodes) { for (TreeNode node : nodes) { conceptName.append(node.getValue()).append(";"); } - //你是一个概括大师,我将给你多个短语,中间用';'来分隔,请给概括这些短语,返回一个短语。注意只需返回这个短语即可,短语使用大括号进行包裹。需要你合并概括的多和短语是:" + conceptName + // 你是一个概括大师,我将给你多个短语,中间用';'来分隔,请给概括这些短语,返回一个短语。注意只需返回这个短语即可,短语使用大括号进行包裹。需要你合并概括的多和短语是:" + + // conceptName String res = getResponse( "You are a summarization master. I will provide multiple words or phases separated by ';'. Please summarize them into a single word or a phase if possible. If merging is not feasible, return a hyphen-separated string like 'apple-desk'." diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java index d9580c1..676a973 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/MultiwayTree.java @@ -1,7 +1,6 @@ package org.apache.zeppelin.iginx.util; import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; import java.util.List; import org.apache.commons.lang3.StringUtils; import org.apache.zeppelin.iginx.util.algorithm.mergeforest.MergeForestStrategy; @@ -76,17 +75,25 @@ public static void addTreeNodeFromString(MultiwayTree tree, String nodeString) { public static void mergeTree(MultiwayTree tree, String strategy) { try { - Class mergeClass = Class.forName("org.apache.zeppelin.iginx.util.algorithm.mergeforest." + strategy); + Class mergeClass = + Class.forName("org.apache.zeppelin.iginx.util.algorithm.mergeforest." + strategy); Constructor constructor = mergeClass.getConstructor(); MergeForestStrategy mergeForestStrategy = (MergeForestStrategy) constructor.newInstance(); logger.info("begin --" + strategy + "--"); mergeForestStrategy.mergeForest(tree.root); } catch (ClassNotFoundException e) { - logger.error("error: Could not find class '{}'. Please check if the class name is correct.", strategy, e); + logger.error( + "error: Could not find class '{}'. Please check if the class name is correct.", + strategy, + e); } catch (NoSuchMethodException e) { - logger.error("error: Could not find a default constructor in class '{}'. Please ensure it has a no-args constructor.", strategy, e); + logger.error( + "error: Could not find a default constructor in class '{}'. Please ensure it has a no-args constructor.", + strategy, + e); } catch (Exception e) { - logger.error("error: Reflection failed. Could not create an instance of class '{}'.", strategy, e); + logger.error( + "error: Reflection failed. Could not create an instance of class '{}'.", strategy, e); } } } diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java index e8223d0..768b523 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/ChunkMerge.java @@ -10,16 +10,17 @@ import org.apache.zeppelin.iginx.util.EmbeddingUtils; import org.apache.zeppelin.iginx.util.LLMUtils; import org.apache.zeppelin.iginx.util.TreeNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class ChunkMerge implements MergeForestStrategy { + private static final Logger logger = LoggerFactory.getLogger(ChunkMerge.class); private static final int THREAD_POOL_SIZE = 8; // 可根据机器配置调整线程数 /** - * 分块合并算法 - * 1. 循环多轮进行,当某轮结束后的森林大小满足要求,则退出 - * 2. 每轮按照线程数平均分块,每块中的树分别进行合并,块间互不干扰 - * 3. 通过 findBestMergeGroup 查找最合适的合并组:块中的所有树通过循环,两两比较,找到 embedding 最相似的两个合并 - * (第3步的 findBestMergeGroup 有较大的优化空间,因为有时候多棵树一起合并可能效果更好) + * 分块合并算法 1. 循环多轮进行,当某轮结束后的森林大小满足要求,则退出 2. 每轮按照线程数平均分块,每块中的树分别进行合并,块间互不干扰 3. 通过 findBestMergeGroup + * 查找最合适的合并组:块中的所有树通过循环,两两比较,找到 embedding 最相似的两个合并 (第3步的 findBestMergeGroup + * 有较大的优化空间,因为有时候多棵树一起合并可能效果更好) * * @param root * @throws ExecutionException @@ -28,11 +29,11 @@ public class ChunkMerge implements MergeForestStrategy { public void mergeForest(TreeNode root) throws ExecutionException, InterruptedException { List forest = root.getChildren(); int targetTreeCount = (int) Math.ceil(Math.sqrt(forest.size())); - System.out.println("targetTreeCount: " + targetTreeCount); + logger.info("targetTreeCount: {}", targetTreeCount); ExecutorService executor = Executors.newFixedThreadPool(THREAD_POOL_SIZE); while (forest.size() > targetTreeCount) { - System.out.println("beginning new merge, the recent size of forest is: " + forest.size()); + logger.info("beginning new merge, the recent size of forest is: {}", forest.size()); List>> futureResults = new ArrayList<>(); int chunkSize = Math.max(2, (int) Math.ceil((double) forest.size() / THREAD_POOL_SIZE)); @@ -67,7 +68,7 @@ private List mergeChunk(List chunk) { if (!nodesToMerge.isEmpty()) { String newConceptName = LLMUtils.getConcept(nodesToMerge); - System.out.println("get new concept name: " + newConceptName); + logger.info("get new concept name: {}", newConceptName); List newEmbedding = EmbeddingUtils.getEmbedding(newConceptName); TreeNode newParent = new TreeNode(newConceptName, newConceptName, newEmbedding); diff --git a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java index 2961c77..a8824f0 100644 --- a/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java +++ b/v8/src/main/java/org/apache/zeppelin/iginx/util/algorithm/mergeforest/RandomMerge.java @@ -6,17 +6,18 @@ import org.apache.zeppelin.iginx.util.EmbeddingUtils; import org.apache.zeppelin.iginx.util.LLMUtils; import org.apache.zeppelin.iginx.util.TreeNode; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; public class RandomMerge implements MergeForestStrategy { + private static final Logger logger = LoggerFactory.getLogger(RandomMerge.class); private static final Double SIMILARITY_THRESHOLD = 0.7; private static final Integer MAX_CONTINUOUS_FAILURE_COUNT = 5; /** - * 随机合并算法 - * 1. 循环多轮进行,当某轮结束后的森林大小满足要求或者已经连续 MAX_CONTINUOUS_FAILURE_COUNT 轮合并失败,则退出 - * 2. 每轮随机选择森林中的一棵树,计算其他所有树与它的 embedding 相似度,并从高到底排序 - * 3. 按照顺序对相似度进行累乘:假设相似度是 x,每次乘实际上是乘 1-(1-x)/2 ,乘积如果不小于 SIMILARITY_THRESHOLD 则继续乘,反之则结束 - * 4. 如果累计相似度满足要求的树不少于2棵,则进行合并,反之则计为一次合并失败 + * 随机合并算法 1. 循环多轮进行,当某轮结束后的森林大小满足要求或者已经连续 MAX_CONTINUOUS_FAILURE_COUNT 轮合并失败,则退出 2. + * 每轮随机选择森林中的一棵树,计算其他所有树与它的 embedding 相似度,并从高到底排序 3. 按照顺序对相似度进行累乘:假设相似度是 x,每次乘实际上是乘 1-(1-x)/2 + * ,乘积如果不小于 SIMILARITY_THRESHOLD 则继续乘,反之则结束 4. 如果累计相似度满足要求的树不少于2棵,则进行合并,反之则计为一次合并失败 * * @param root */ @@ -28,7 +29,7 @@ public void mergeForest(TreeNode root) { int failureCount = 0; // 记录连续合并失败次数 while (forest.size() > targetTreeCount && failureCount < MAX_CONTINUOUS_FAILURE_COUNT) { - System.out.println("beginning new merge, the recent size of forest is: " + forest.size()); + logger.info("beginning new merge, the recent size of forest is: {}", forest.size()); // 随机选择一个根节点 TreeNode referenceNode = forest.get(random.nextInt(forest.size())); @@ -60,7 +61,7 @@ public void mergeForest(TreeNode root) { double similarity = EmbeddingUtils.calculateSimilarity(referenceNode.getEmbedding(), node.getEmbedding()); cumulativeSimilarity *= (0.5 + similarity / 2); - System.out.println("now cumulative similarity: " + cumulativeSimilarity); + logger.info("now cumulative similarity: {}", cumulativeSimilarity); if (cumulativeSimilarity >= SIMILARITY_THRESHOLD) { selectedNodes.add(node); @@ -86,16 +87,16 @@ public void mergeForest(TreeNode root) { // 重置连续失败计数器 failureCount = 0; - System.out.println("merge success, now forest size is: " + forest.size()); + logger.info("merge success, now forest size is: {}", forest.size()); } else { // 增加失败计数 failureCount++; - System.out.println("merge failure, now consecutive failure count is: " + failureCount); + logger.info("merge failure, now consecutive failure count is: {}", failureCount); } // 判断是否已经到达连续合并失败的最大上限 if (failureCount >= MAX_CONTINUOUS_FAILURE_COUNT) { - System.out.println("The count of consecutive merge failures has reached the limit, stop merge"); + logger.info("The count of consecutive merge failures has reached the limit, stop merge"); break; } }