Skip to content

Commit

Permalink
clean up image generation code
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinGe00 committed Sep 16, 2023
1 parent 08633fa commit b94c23c
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ else if (toLanguage.equalsIgnoreCase("trino")) {
}
}
} catch (Throwable t) {
// TODO: use logger
t.printStackTrace();
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
*/
package com.linkedin.coral.coralservice.controller;

import com.linkedin.coral.coralservice.utils.VisualizationUtils;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.UUID;

import org.springframework.core.io.FileSystemResource;
Expand All @@ -33,41 +35,45 @@
@RequestMapping("/api/visualizations")
public class VisualizationController {
private File imageDir = createImageDir();
private VisualizationUtils visualizationUtils = new VisualizationUtils();

@PostMapping("/generategraphs")
public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody visualizationRequestBody) {
final String fromLanguage = visualizationRequestBody.getFromLanguage();
final String query = visualizationRequestBody.getQuery();
final RewriteType rewriteType = visualizationRequestBody.getRewriteType();

UUID sqlNodeImageID, relNodeImageID;
UUID postRewriteSqlNodeImageID = null;
UUID postRewriteRelNodeImageID = null;
if (!visualizationUtils.isValidFromLanguage(fromLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Currently, only Hive and Trino are supported as engines to generate graphs using.\n");
}

// A list of UUIDs in this order of:
// 1. Image ID of pre/no rewrite relNode
// 2. Image ID of pre/no rewrite sqlNode
// If a rewrite was requested:
// 3. Image ID of post rewrite relNode
// 4. Image ID of post rewrite sqlNode
ArrayList<UUID> imageIdList;
try {
// Always generate the pre/no rewrite images first
sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE);
relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE);
assert !sqlNodeImageID.equals(relNodeImageID);

if (rewriteType != RewriteType.NONE && rewriteType != null) {
// A rewrite was requested
postRewriteRelNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, rewriteType);
postRewriteSqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, rewriteType);
assert !postRewriteSqlNodeImageID.equals(postRewriteRelNodeImageID);
}

imageIdList = visualizationUtils.generateIRVisualizations(query, fromLanguage, imageDir, rewriteType);
} catch (Throwable t) {
// TODO: use logger
t.printStackTrace();
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage());
}

assert imageIdList.size() > 0;

// Build response body
VisualizationResponseBody responseBody = new VisualizationResponseBody();
responseBody.setSqlNodeImageID(sqlNodeImageID);
responseBody.setRelNodeImageID(relNodeImageID);
responseBody.setPostRewriteSqlNodeImageID(postRewriteSqlNodeImageID);
responseBody.setPostRewriteRelNodeImageID(postRewriteRelNodeImageID);
responseBody.setRelNodeImageID(imageIdList.get(0));
responseBody.setSqlNodeImageID(imageIdList.get(1));
if (imageIdList.size() >= 4) {
// Rewrite was requested
responseBody.setPostRewriteRelNodeImageID(imageIdList.get(2));
responseBody.setPostRewriteSqlNodeImageID(imageIdList.get(3));
}

return ResponseEntity.status(HttpStatus.OK).body(responseBody);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package com.linkedin.coral.coralservice.utils;

import java.io.File;
import java.util.ArrayList;
import java.util.UUID;

import org.apache.calcite.rel.RelNode;
Expand All @@ -18,7 +19,6 @@
import com.linkedin.coral.vis.VisualizationUtil;

import static com.linkedin.coral.coralservice.utils.CoralProvider.*;
import static com.linkedin.coral.coralservice.utils.RewriteType.*;


public class VisualizationUtils {
Expand All @@ -27,60 +27,102 @@ public static File createImageDir() {
return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID());
}

public static RelNode incrementalRewrittenRelNode = null;
public boolean isValidFromLanguage(String fromLanguage) {
return fromLanguage.equalsIgnoreCase("trino") || fromLanguage.equalsIgnoreCase("hive");
}

public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir,
public ArrayList<UUID> generateIRVisualizations(String query, String fromLanguage, File imageDir,
RewriteType rewriteType) {

SqlNode sqlNode = null;
if (fromLanguage.equalsIgnoreCase("trino")) {
sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query);
} else if (fromLanguage.equalsIgnoreCase("hive")) {
sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query);
}

if (incrementalRewrittenRelNode != null && rewriteType == INCREMENTAL) {
// We want to instead generate the visualization of SqlNode2 of the RHS of Coral's translation
sqlNode = new CoralRelToSqlNodeConverter().convert(incrementalRewrittenRelNode);
ArrayList<UUID> imageIDList = new ArrayList<>();

// Always generate the pre/no rewrite images first
RelNodeAndID relNodeAndID = generateRelNodeVisualization(getRelNode(query, fromLanguage), imageDir);
imageIDList.add(relNodeAndID.id);

SqlNodeAndID sqlNodeAndID = generateSqlNodeVisualization(getSqlNode(query, fromLanguage), imageDir);
imageIDList.add(sqlNodeAndID.id);

// Generate rewritten IR images if requested, otherwise, simply return the non rewritten image ids
RelNode preRewriteRelNode = relNodeAndID.relNode;
RelNode postRewriteRelNode;

if (rewriteType != RewriteType.NONE && rewriteType != null) {
switch (rewriteType) {
case INCREMENTAL:
postRewriteRelNode = RelNodeIncrementalTransformer.convertRelIncremental(preRewriteRelNode);
break;
case DATAMASKING:
default:
return imageIDList;
}
RelNodeAndID postRewroteRelNodeAndID = generateRelNodeVisualization(postRewriteRelNode, imageDir);
imageIDList.add(postRewroteRelNodeAndID.id);

SqlNode postRewriteSqlNode = new CoralRelToSqlNodeConverter().convert(postRewriteRelNode);
SqlNodeAndID postRewriteSqlNodeAndID = generateSqlNodeVisualization(postRewriteSqlNode, imageDir);
imageIDList.add(postRewriteSqlNodeAndID.id);
}

assert sqlNode != null;
return imageIDList;
}

private SqlNodeAndID generateSqlNodeVisualization(SqlNode sqlNode, File imageDir) {
// Generate graphviz svg using sqlNode
VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);
UUID sqlNodeId = UUID.randomUUID();
visualizationUtil.visualizeSqlNodeToFile(sqlNode, "/" + sqlNodeId + ".svg");

return sqlNodeId;
return new SqlNodeAndID(sqlNodeId, sqlNode);
}

public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir,
RewriteType rewriteType) {
private RelNodeAndID generateRelNodeVisualization(RelNode relNode, File imageDir) {
// Generate graphviz svg using relNode
VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);
UUID relNodeID = UUID.randomUUID();
visualizationUtil.visualizeRelNodeToFile(relNode, "/" + relNodeID + ".svg");

return new RelNodeAndID(relNodeID, relNode);
}

private RelNode getRelNode(String query, String fromLanguage) {
RelNode relNode = null;
if (fromLanguage.equalsIgnoreCase("trino")) {
relNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query);
} else if (fromLanguage.equalsIgnoreCase("hive")) {
relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
}

switch (rewriteType) {
case INCREMENTAL:
relNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode);
incrementalRewrittenRelNode = relNode;
break;
case DATAMASKING:
case NONE:
default:
break;
return relNode;
}

private SqlNode getSqlNode(String query, String fromLanguage) {
SqlNode sqlNode = null;
if (fromLanguage.equalsIgnoreCase("trino")) {
sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query);
} else if (fromLanguage.equalsIgnoreCase("hive")) {
sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query);
}

assert relNode != null;
return sqlNode;
}

// Generate graphviz svg using relNode
VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);
UUID relNodeID = UUID.randomUUID();
visualizationUtil.visualizeRelNodeToFile(relNode, "/" + relNodeID + ".svg");
private class RelNodeAndID {
private UUID id;
private RelNode relNode;

public RelNodeAndID(UUID id, RelNode relNode) {
this.id = id;
this.relNode = relNode;
}
}

return relNodeID;
private class SqlNodeAndID {
private UUID id;
private SqlNode sqlNode;

public SqlNodeAndID(UUID id, SqlNode sqlNode) {
this.id = id;
this.sqlNode = sqlNode;
}
}
}

0 comments on commit b94c23c

Please sign in to comment.