diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/TranslationController.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/TranslationController.java index a58877632..e52c79bbd 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/TranslationController.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/TranslationController.java @@ -89,6 +89,7 @@ else if (toLanguage.equalsIgnoreCase("trino")) { } } } catch (Throwable t) { + // TODO: use logger t.printStackTrace(); return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage()); } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java index be87cab4b..2088210f4 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java @@ -5,10 +5,12 @@ */ package com.linkedin.coral.coralservice.controller; +import com.linkedin.coral.coralservice.utils.VisualizationUtils; import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; import java.util.UUID; import org.springframework.core.io.FileSystemResource; @@ -33,6 +35,7 @@ @RequestMapping("/api/visualizations") public class VisualizationController { private File imageDir = createImageDir(); + private VisualizationUtils visualizationUtils = new VisualizationUtils(); @PostMapping("/generategraphs") public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody visualizationRequestBody) { @@ -40,34 +43,37 @@ public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody final String query = visualizationRequestBody.getQuery(); final RewriteType rewriteType = visualizationRequestBody.getRewriteType(); - UUID sqlNodeImageID, relNodeImageID; - UUID postRewriteSqlNodeImageID = null; - UUID postRewriteRelNodeImageID = null; + if (!visualizationUtils.isValidFromLanguage(fromLanguage)) { + return ResponseEntity.status(HttpStatus.BAD_REQUEST) + .body("Currently, only Hive and Trino are supported as engines to generate graphs using.\n"); + } + // A list of UUIDs in this order of: + // 1. Image ID of pre/no rewrite relNode + // 2. Image ID of pre/no rewrite sqlNode + // If a rewrite was requested: + // 3. Image ID of post rewrite relNode + // 4. Image ID of post rewrite sqlNode + ArrayList imageIdList; try { - // Always generate the pre/no rewrite images first - sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE); - relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE); - assert !sqlNodeImageID.equals(relNodeImageID); - - if (rewriteType != RewriteType.NONE && rewriteType != null) { - // A rewrite was requested - postRewriteRelNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, rewriteType); - postRewriteSqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, rewriteType); - assert !postRewriteSqlNodeImageID.equals(postRewriteRelNodeImageID); - } - + imageIdList = visualizationUtils.generateIRVisualizations(query, fromLanguage, imageDir, rewriteType); } catch (Throwable t) { + // TODO: use logger t.printStackTrace(); return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage()); } + assert imageIdList.size() > 0; + // Build response body VisualizationResponseBody responseBody = new VisualizationResponseBody(); - responseBody.setSqlNodeImageID(sqlNodeImageID); - responseBody.setRelNodeImageID(relNodeImageID); - responseBody.setPostRewriteSqlNodeImageID(postRewriteSqlNodeImageID); - responseBody.setPostRewriteRelNodeImageID(postRewriteRelNodeImageID); + responseBody.setRelNodeImageID(imageIdList.get(0)); + responseBody.setSqlNodeImageID(imageIdList.get(1)); + if (imageIdList.size() >= 4) { + // Rewrite was requested + responseBody.setPostRewriteRelNodeImageID(imageIdList.get(2)); + responseBody.setPostRewriteSqlNodeImageID(imageIdList.get(3)); + } return ResponseEntity.status(HttpStatus.OK).body(responseBody); } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java index 23932ef47..088ca08c2 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java @@ -6,6 +6,7 @@ package com.linkedin.coral.coralservice.utils; import java.io.File; +import java.util.ArrayList; import java.util.UUID; import org.apache.calcite.rel.RelNode; @@ -18,7 +19,6 @@ import com.linkedin.coral.vis.VisualizationUtil; import static com.linkedin.coral.coralservice.utils.CoralProvider.*; -import static com.linkedin.coral.coralservice.utils.RewriteType.*; public class VisualizationUtils { @@ -27,35 +27,64 @@ public static File createImageDir() { return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID()); } - public static RelNode incrementalRewrittenRelNode = null; + public boolean isValidFromLanguage(String fromLanguage) { + return fromLanguage.equalsIgnoreCase("trino") || fromLanguage.equalsIgnoreCase("hive"); + } - public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir, + public ArrayList generateIRVisualizations(String query, String fromLanguage, File imageDir, RewriteType rewriteType) { - - SqlNode sqlNode = null; - if (fromLanguage.equalsIgnoreCase("trino")) { - sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query); - } else if (fromLanguage.equalsIgnoreCase("hive")) { - sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query); - } - - if (incrementalRewrittenRelNode != null && rewriteType == INCREMENTAL) { - // We want to instead generate the visualization of SqlNode2 of the RHS of Coral's translation - sqlNode = new CoralRelToSqlNodeConverter().convert(incrementalRewrittenRelNode); + ArrayList imageIDList = new ArrayList<>(); + + // Always generate the pre/no rewrite images first + RelNodeAndID relNodeAndID = generateRelNodeVisualization(getRelNode(query, fromLanguage), imageDir); + imageIDList.add(relNodeAndID.id); + + SqlNodeAndID sqlNodeAndID = generateSqlNodeVisualization(getSqlNode(query, fromLanguage), imageDir); + imageIDList.add(sqlNodeAndID.id); + + // Generate rewritten IR images if requested, otherwise, simply return the non rewritten image ids + RelNode preRewriteRelNode = relNodeAndID.relNode; + RelNode postRewriteRelNode; + + if (rewriteType != RewriteType.NONE && rewriteType != null) { + switch (rewriteType) { + case INCREMENTAL: + postRewriteRelNode = RelNodeIncrementalTransformer.convertRelIncremental(preRewriteRelNode); + break; + case DATAMASKING: + default: + return imageIDList; + } + RelNodeAndID postRewroteRelNodeAndID = generateRelNodeVisualization(postRewriteRelNode, imageDir); + imageIDList.add(postRewroteRelNodeAndID.id); + + SqlNode postRewriteSqlNode = new CoralRelToSqlNodeConverter().convert(postRewriteRelNode); + SqlNodeAndID postRewriteSqlNodeAndID = generateSqlNodeVisualization(postRewriteSqlNode, imageDir); + imageIDList.add(postRewriteSqlNodeAndID.id); } - assert sqlNode != null; + return imageIDList; + } + private SqlNodeAndID generateSqlNodeVisualization(SqlNode sqlNode, File imageDir) { // Generate graphviz svg using sqlNode VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir); UUID sqlNodeId = UUID.randomUUID(); visualizationUtil.visualizeSqlNodeToFile(sqlNode, "/" + sqlNodeId + ".svg"); - return sqlNodeId; + return new SqlNodeAndID(sqlNodeId, sqlNode); } - public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir, - RewriteType rewriteType) { + private RelNodeAndID generateRelNodeVisualization(RelNode relNode, File imageDir) { + // Generate graphviz svg using relNode + VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir); + UUID relNodeID = UUID.randomUUID(); + visualizationUtil.visualizeRelNodeToFile(relNode, "/" + relNodeID + ".svg"); + + return new RelNodeAndID(relNodeID, relNode); + } + + private RelNode getRelNode(String query, String fromLanguage) { RelNode relNode = null; if (fromLanguage.equalsIgnoreCase("trino")) { relNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query); @@ -63,24 +92,37 @@ public static UUID generateRelNodeVisualization(String query, String fromLanguag relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query); } - switch (rewriteType) { - case INCREMENTAL: - relNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode); - incrementalRewrittenRelNode = relNode; - break; - case DATAMASKING: - case NONE: - default: - break; + return relNode; + } + + private SqlNode getSqlNode(String query, String fromLanguage) { + SqlNode sqlNode = null; + if (fromLanguage.equalsIgnoreCase("trino")) { + sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query); + } else if (fromLanguage.equalsIgnoreCase("hive")) { + sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query); } - assert relNode != null; + return sqlNode; + } - // Generate graphviz svg using relNode - VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir); - UUID relNodeID = UUID.randomUUID(); - visualizationUtil.visualizeRelNodeToFile(relNode, "/" + relNodeID + ".svg"); + private class RelNodeAndID { + private UUID id; + private RelNode relNode; + + public RelNodeAndID(UUID id, RelNode relNode) { + this.id = id; + this.relNode = relNode; + } + } - return relNodeID; + private class SqlNodeAndID { + private UUID id; + private SqlNode sqlNode; + + public SqlNodeAndID(UUID id, SqlNode sqlNode) { + this.id = id; + this.sqlNode = sqlNode; + } } }