diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java index 3bcdc790c..93c1e716e 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/controller/VisualizationController.java @@ -24,12 +24,13 @@ import com.linkedin.coral.coralservice.entity.VisualizationRequestBody; import com.linkedin.coral.coralservice.entity.VisualizationResponseBody; +import com.linkedin.coral.coralservice.utils.RewriteType; import static com.linkedin.coral.coralservice.utils.VisualizationUtils.*; @RestController -@RequestMapping("/api/visualization") +@RequestMapping("/api/visualizations") public class VisualizationController { private File imageDir = createImageDir(); @@ -37,20 +38,36 @@ public class VisualizationController { public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody visualizationRequestBody) { final String fromLanguage = visualizationRequestBody.getFromLanguage(); final String query = visualizationRequestBody.getQuery(); - // final VisualizationRequestBody.RewriteType rewriteType = visualizationRequestBody.getRewriteType(); - UUID sqlNodeImageID; - UUID relNodeImageID; + final RewriteType rewriteType = visualizationRequestBody.getRewriteType(); + + UUID sqlNodeImageID, relNodeImageID; + UUID postRewriteSqlNodeImageID = null; + UUID postRewriteRelNodeImageID = null; try { - sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir); - relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir); + // Always generate the pre/no rewrite images first + sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE); + relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE); + assert !sqlNodeImageID.equals(relNodeImageID); + + if (rewriteType != RewriteType.NONE && rewriteType != null) { + // A rewrite was requested + postRewriteRelNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, rewriteType); + postRewriteSqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, rewriteType); + assert !postRewriteSqlNodeImageID.equals(postRewriteRelNodeImageID); + } + } catch (Throwable t) { t.printStackTrace(); return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage()); } + + // Build response body VisualizationResponseBody responseBody = new VisualizationResponseBody(); responseBody.setSqlNodeImageID(sqlNodeImageID); responseBody.setRelNodeImageID(relNodeImageID); + responseBody.setPostRewriteSqlNodeImageID(postRewriteSqlNodeImageID); + responseBody.setPostRewriteRelNodeImageID(postRewriteRelNodeImageID); return ResponseEntity.status(HttpStatus.OK).body(responseBody); } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationRequestBody.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationRequestBody.java index d4ca9e257..6a42ebf07 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationRequestBody.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationRequestBody.java @@ -5,11 +5,14 @@ */ package com.linkedin.coral.coralservice.entity; +import com.linkedin.coral.coralservice.utils.RewriteType; + + public class VisualizationRequestBody { private String fromLanguage; private String query; - private String rewriteType; + private RewriteType rewriteType; public String getFromLanguage() { return fromLanguage; @@ -19,24 +22,7 @@ public String getQuery() { return query; } - public String getRewriteType() { + public RewriteType getRewriteType() { return rewriteType; } - - public enum RewriteType { - NONE("none"), - INCREMENTAL("incremental"), - DATAMASKING("datamasking"); - - private final String type; - - RewriteType(String description) { - this.type = description; - } - - @Override - public String toString() { - return type; - } - } } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationResponseBody.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationResponseBody.java index cb5042b97..2c4131e89 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationResponseBody.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/entity/VisualizationResponseBody.java @@ -11,9 +11,9 @@ public class VisualizationResponseBody { private UUID sqlNodeImageID; - private UUID relNodeImageID; - + private UUID postRewriteSqlNodeImageID; + private UUID postRewriteRelNodeImageID; public VisualizationResponseBody() { } @@ -25,6 +25,14 @@ public UUID getRelNodeImageID() { return relNodeImageID; } + public UUID getPostRewriteSqlNodeImageID() { + return postRewriteSqlNodeImageID; + } + + public void setPostRewriteSqlNodeImageID(UUID postRewriteSqlNodeImageID) { + this.postRewriteSqlNodeImageID = postRewriteSqlNodeImageID; + } + public void setSqlNodeImageID(UUID sqlNodeImageID) { this.sqlNodeImageID = sqlNodeImageID; } @@ -33,4 +41,12 @@ public void setRelNodeImageID(UUID relNodeImageID) { this.relNodeImageID = relNodeImageID; } + public UUID getPostRewriteRelNodeImageID() { + return postRewriteRelNodeImageID; + } + + public void setPostRewriteRelNodeImageID(UUID postRewriteRelNodeImageID) { + this.postRewriteRelNodeImageID = postRewriteRelNodeImageID; + } + } diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/RewriteType.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/RewriteType.java new file mode 100644 index 000000000..5baaf9f1a --- /dev/null +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/RewriteType.java @@ -0,0 +1,37 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.coralservice.utils; + +import com.fasterxml.jackson.annotation.JsonCreator; + + +public enum RewriteType { + NONE("none"), + INCREMENTAL("incremental"), + DATAMASKING("datamasking"); + + private final String type; + + RewriteType(String type) { + this.type = type; + } + + @Override + public String toString() { + return type; + } + + @JsonCreator + public static RewriteType getDepartmentFromCode(String value) { + for (RewriteType type : RewriteType.values()) { + if (type.toString().equals(value)) { + return type; + } + } + return null; + } + +} diff --git a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java index 24c521941..fc1f79b8b 100644 --- a/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java +++ b/coral-service/src/main/java/com/linkedin/coral/coralservice/utils/VisualizationUtils.java @@ -12,6 +12,8 @@ import org.apache.calcite.sql.SqlNode; import com.linkedin.coral.hive.hive2rel.HiveToRelConverter; +import com.linkedin.coral.incremental.RelNodeIncrementalTransformer; +import com.linkedin.coral.transformers.CoralRelToSqlNodeConverter; import com.linkedin.coral.trino.trino2rel.TrinoToRelConverter; import com.linkedin.coral.vis.VisualizationUtil; @@ -24,7 +26,11 @@ public static File createImageDir() { return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID()); } - public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir) { + public static RelNode incrementalRewrittenRelNode = null; + + public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir, + RewriteType rewriteType) { + SqlNode sqlNode = null; if (fromLanguage.equalsIgnoreCase("trino")) { sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query); @@ -32,6 +38,11 @@ public static UUID generateSqlNodeVisualization(String query, String fromLanguag sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query); } + if (incrementalRewrittenRelNode != null && rewriteType == RewriteType.INCREMENTAL) { + // We want to instead generate the visualization of SqlNode2 of the RHS of Coral's translation + sqlNode = new CoralRelToSqlNodeConverter().convert(incrementalRewrittenRelNode); + } + assert sqlNode != null; VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir); @@ -41,7 +52,8 @@ public static UUID generateSqlNodeVisualization(String query, String fromLanguag return sqlNodeId; } - public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir) { + public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir, + RewriteType rewriteType) { RelNode relNode = null; if (fromLanguage.equalsIgnoreCase("trino")) { relNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query); @@ -49,6 +61,11 @@ public static UUID generateRelNodeVisualization(String query, String fromLanguag relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query); } + if (rewriteType == RewriteType.INCREMENTAL) { + relNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode); + incrementalRewrittenRelNode = relNode; + } + assert relNode != null; VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);