Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add merge tree function based on embedding #17

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,12 @@ IGinX Zeppelin 解释器是需要连接 IGinX 的,如果我们重启了 IGinX
#### 特别说明
##### LOAD DATA 命令
输入LOAD DATA 命令后,需要先点击执行按钮调出文件选择控件,命令中的文件名(1处)与选择文件控件中的名称(2处,选择文件后自动填充,不可手动修改)需要保持一致。
![img.png](img.png)
![img](./images/load_data.png)

##### SHOW COLUMNS 命令
展示可视化资产图会涉及到embedding,请前往 https://nlp.stanford.edu/projects/glove/ 下载相关模型,解压后放入 /resources/model 文件夹下,并根据模型的embedding维度修改embeddingUtils类中的EMBEDDING_DIMENSION参数。
![img](./images/show_columns.png)
此外,当前选用的模型只支持英文,暂不支持中文。

### 使用RESTful语句

Expand Down
File renamed without changes
Binary file added images/show_columns.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 25 additions & 14 deletions v8/src/main/java/org/apache/zeppelin/iginx/IginxInterpreter8.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,8 @@
import org.apache.zeppelin.iginx.util.MultiwayTree;
import org.apache.zeppelin.iginx.util.SqlCmdUtil;
import org.apache.zeppelin.interpreter.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class IginxInterpreter8 extends Interpreter {
private static final Logger LOGGER = LoggerFactory.getLogger(IginxInterpreter8.class);

private static final String IGINX_HOST = "iginx.host";
private static final String IGINX_PORT = "iginx.port";
private static final String IGINX_USERNAME = "iginx.username";
Expand Down Expand Up @@ -345,6 +341,15 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) {
row -> {
MultiwayTree.addTreeNodeFromString(tree, row.get(0));
});
try {
logger.info("before merge, the size of forest is: ", tree.getRoot().getChildren().size());
// ChunkMerge or RandomMerge
MultiwayTree.mergeTree(tree, "ChunkMerge");
logger.info("after merge, the size of forest is: ", tree.getRoot().getChildren().size());
} catch (Exception e) {
logger.error("merge tree error");
e.printStackTrace();
}
try (InputStream inputStream =
IginxInterpreter8.class.getClassLoader().getResourceAsStream("static/vis/network.html")) {
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
Expand All @@ -359,18 +364,24 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) {
// 写入vis等库文件,只在新环境执行一次
String targetPath = outfileDir + "/graphs/lib/";
if (!FileUtil.isDirectoryLoaded(targetPath)) {
logger.info("buildNetworkForShowColumns: upload static resources start");
String sourcePath = "static/vis/lib/";
String jarUrl =
Objects.requireNonNull(IginxInterpreter8.class.getClassLoader().getResource(sourcePath))
.toString();
String jarPath = jarUrl.substring(jarUrl.indexOf("file:") + 5, jarUrl.indexOf(".jar") + 4);
FileUtil.extractDirectoryFromJar(jarPath, sourcePath, targetPath);
logger.info("buildNetworkForShowColumns: upload static resources finish");
}
// 写入network html
File networkHtml = new File(outfileDir + "/graphs/network.html");
logger.info(
"buildNetworkForShowColumns: the absolute path of the output network.html is {}",
networkHtml.getAbsolutePath());
OutputStream outputStream = Files.newOutputStream(networkHtml.toPath());
outputStream.write(html.getBytes());
outputStream.close();
logger.info("buildNetworkForShowColumns: html(string) write to network.html finish");
InputStream inputStreamMain =
IginxInterpreter8.class.getClassLoader().getResourceAsStream("static/vis/main.html");
BufferedReader br = new BufferedReader(new InputStreamReader(inputStreamMain));
Expand All @@ -382,7 +393,7 @@ public String buildNetworkForShowColumns(SessionExecuteSqlResult sqlResult) {
.replace("FILE_HOST", fileHttpHost)
.replace("FILE_PORT", String.valueOf(fileHttpPort));
} catch (IOException e) {
LOGGER.warn("load show columns to tree error", e);
logger.warn("load show columns to tree error", e);
}
return "";
}
Expand All @@ -408,7 +419,7 @@ private InterpreterResult processLoadCsv(String sql, InterpreterContext context)
InterpreterResult interpreterResult;
String uploadParagraphKey = context.getParagraphId() + "_UPLOAD_FILE";
/* response upload file form, user will rerun paragraph when upload finished. */
LOGGER.info("+++++++Id={}, paragraphId={}", context.getNoteId(), uploadParagraphKey);
logger.info("+++++++Id={}, paragraphId={}", context.getNoteId(), uploadParagraphKey);
if (!uploadParagraphSet.contains(uploadParagraphKey)) {
try (InputStream inputStream =
IginxInterpreter8.class.getClassLoader().getResourceAsStream("uploadForm.html");
Expand All @@ -429,13 +440,13 @@ private InterpreterResult processLoadCsv(String sql, InterpreterContext context)
uploadParagraphSet.add(uploadParagraphKey);
return interpreterResult;
} catch (IOException e) {
LOGGER.error("load html error", e);
logger.error("load html error", e);
}
return new InterpreterResult(InterpreterResult.Code.ERROR);
}

try {
LOGGER.info("load data sql execute, sql={}", sql);
logger.info("load data sql execute, sql={}", sql);
SessionExecuteSqlResult res = session.executeSql(sql);
String parseErrorMsg = res.getParseErrorMsg();
if (parseErrorMsg != null && !parseErrorMsg.isEmpty()) {
Expand Down Expand Up @@ -494,11 +505,11 @@ private String convertPath(String inputPath, String sql) throws IOException {
String path;
Path pathObj;
if (SystemUtils.IS_OS_WINDOWS) {
LOGGER.info("current os is Windows");
logger.info("current os is Windows");
path = inputPath.replace("/", "\\");
pathObj = Paths.get(path);
} else {
LOGGER.info("current os is Linux or Mac");
logger.info("current os is Linux or Mac");
path = inputPath.replace("\\", "/");
pathObj = Paths.get(path);
}
Expand All @@ -507,7 +518,7 @@ private String convertPath(String inputPath, String sql) throws IOException {
HttpUtil.getCurrentPath(DEFAULT_UPLOAD_DIR)
+ File.separator
+ pathObj.getFileName().toString();
LOGGER.info("converted path is {}", path);
logger.info("converted path is {}", path);
return path;
}

Expand Down Expand Up @@ -1041,7 +1052,7 @@ public InterpreterResult tuneFontSize(
} else {
hTagNumber = 6;
}
LOGGER.info(
logger.info(
"NoteId={},ParagraphId={},fontSizeEnable={},fontSize={}",
context.getNoteId(),
context.getParagraphId(),
Expand All @@ -1053,7 +1064,7 @@ public InterpreterResult tuneFontSize(
message.stream()
.map(
item -> {
LOGGER.debug("type={},data={}", item.getType(), item.getData());
logger.debug("type={},data={}", item.getType(), item.getData());
if (item.getType().equals(InterpreterResult.Type.TABLE)) {
String collect =
Arrays.stream(item.getData().split(NEWLINE))
Expand Down Expand Up @@ -1082,7 +1093,7 @@ public InterpreterResult tuneFontSize(
item.getType(),
String.format("<h%d>%s</h%d>", hTagNumber, item.getData(), hTagNumber));
} else {
LOGGER.warn("unexpected result type {}", item.getType());
logger.warn("unexpected result type {}", item.getType());
}
return new InterpreterResultMessage(item.getType(), item.getData());
})
Expand Down
123 changes: 123 additions & 0 deletions v8/src/main/java/org/apache/zeppelin/iginx/util/EmbeddingUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package org.apache.zeppelin.iginx.util;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class EmbeddingUtils {
private static final Logger logger = LoggerFactory.getLogger(EmbeddingUtils.class);
private static final Map<String, List<Double>> embeddings = new HashMap<>();
private static final int EMBEDDING_DIMENSION = 50; // 嵌入向量维度
private static final Random RANDOM = new Random(42); // 固定随机种子

static {
try {
InputStream inputStream =
EmbeddingUtils.class.getClassLoader().getResourceAsStream("model/glove.6B.50d.txt");
if (inputStream == null) {
throw new IllegalArgumentException("fail to find GloVe model");
}

try (BufferedReader br = new BufferedReader(new InputStreamReader(inputStream))) {
String line;
while ((line = br.readLine()) != null) {
String[] tokens = line.split(" ");
String word = tokens[0];
List<Double> vector = new ArrayList<>(tokens.length - 1);
for (int i = 1; i < tokens.length; i++) {
vector.add(Double.parseDouble(tokens[i]));
}
embeddings.put(word, vector);
}
}
logger.info("load GloVe success");
} catch (Exception e) {
throw new RuntimeException("load GloVe fail", e);
}
}

/**
* 获取输入内容的 embedding 1. 直接通过 Map 查找,若找到,直接返回结果 2. 若找不到,则对输入内容按照 "-", " ", "_" 进行划分,再分别获取 embeddig
* 并取平均 3. 若输入的内容找不到且已无法再划分,则随机一个 embedding (固定了随机种子,为了让每次执行的结果一致)
*
* @param word
* @return
*/
public static List<Double> getEmbedding(String word) {
System.out.println("getEmbedding: input word is: " + word);
logger.info("getEmbedding: input word is: {}", word);
List<Double> embedding = embeddings.get(word);

if (embedding == null) {
// 按照"-"," ","_"分割取embedding的平均值
String[] parts = word.split("[-\\s_]+");
if (parts.length > 1) {
List<Double> sumEmbedding = new ArrayList<>(EMBEDDING_DIMENSION);
for (int i = 0; i < EMBEDDING_DIMENSION; i++) {
sumEmbedding.add(0.0); // 初始化为零向量
}
double sum = 0;
for (String part : parts) {
List<Double> partEmbedding = getEmbedding(part);
if (partEmbedding != null) {
for (int i = 0; i < EMBEDDING_DIMENSION; i++) {
sumEmbedding.set(i, sumEmbedding.get(i) + partEmbedding.get(i)); // 按维度累加
}
sum += 1;
}
}
for (int i = 0; i < EMBEDDING_DIMENSION; i++) {
sumEmbedding.set(i, sumEmbedding.get(i) / sum);
}
return sumEmbedding;
} else {
// 如果所有部分都未找到且不可拆分,生成随机向量
embedding = generateRandomVector(EMBEDDING_DIMENSION);
logger.info(
"fail to find the embedding of '" + word + "', generating Random Vector to replace");
}
}
return embedding;
}

private static List<Double> generateRandomVector(int dimension) {
List<Double> vector = new ArrayList<>(dimension);
for (int i = 0; i < dimension; i++) {
vector.add(RANDOM.nextDouble() * 2 - 1); // 随机值范围 [-1, 1]
}
return vector;
}

public static double calculateSimilarity(List<Double> embedding1, List<Double> embedding2) {
if (embedding1 == null || embedding2 == null || embedding1.size() != embedding2.size()) {
throw new IllegalArgumentException(
"Embeddings must not be null and must have the same length");
}

double dotProduct = 0.0;
double normA = 0.0;
double normB = 0.0;

for (int i = 0; i < embedding1.size(); i++) {
double valueA = embedding1.get(i);
double valueB = embedding2.get(i);
dotProduct += valueA * valueB;
normA += Math.pow(valueA, 2);
normB += Math.pow(valueB, 2);
}

return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}

public static void main(String[] args) {
// 测试:获取两个句子的嵌入向量并计算相似度
List<Double> embedding1 = getEmbedding("summer_2024");
List<Double> embedding2 = getEmbedding("climate");

double similarity = calculateSimilarity(embedding1, embedding2);
System.out.println("相似度: " + similarity);
}
}
Loading