Skip to content

Commit

Permalink
add rewrite logic
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinGe00 committed Sep 14, 2023
1 parent 0a5a16f commit db22621
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,50 @@

import com.linkedin.coral.coralservice.entity.VisualizationRequestBody;
import com.linkedin.coral.coralservice.entity.VisualizationResponseBody;
import com.linkedin.coral.coralservice.utils.RewriteType;

import static com.linkedin.coral.coralservice.utils.VisualizationUtils.*;


@RestController
@RequestMapping("/api/visualization")
@RequestMapping("/api/visualizations")
public class VisualizationController {
private File imageDir = createImageDir();

@PostMapping("/generategraphs")
public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody visualizationRequestBody) {
final String fromLanguage = visualizationRequestBody.getFromLanguage();
final String query = visualizationRequestBody.getQuery();
// final VisualizationRequestBody.RewriteType rewriteType = visualizationRequestBody.getRewriteType();
UUID sqlNodeImageID;
UUID relNodeImageID;
final RewriteType rewriteType = visualizationRequestBody.getRewriteType();

UUID sqlNodeImageID, relNodeImageID;
UUID postRewriteSqlNodeImageID = null;
UUID postRewriteRelNodeImageID = null;

try {
sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir);
relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir);
// Always generate the pre/no rewrite images first
sqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE);
relNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, RewriteType.NONE);
assert !sqlNodeImageID.equals(relNodeImageID);

if (rewriteType != RewriteType.NONE && rewriteType != null) {
// A rewrite was requested
postRewriteRelNodeImageID = generateRelNodeVisualization(query, fromLanguage, imageDir, rewriteType);
postRewriteSqlNodeImageID = generateSqlNodeVisualization(query, fromLanguage, imageDir, rewriteType);
assert !postRewriteSqlNodeImageID.equals(postRewriteRelNodeImageID);
}

} catch (Throwable t) {
t.printStackTrace();
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body(t.getMessage());
}

// Build response body
VisualizationResponseBody responseBody = new VisualizationResponseBody();
responseBody.setSqlNodeImageID(sqlNodeImageID);
responseBody.setRelNodeImageID(relNodeImageID);
responseBody.setPostRewriteSqlNodeImageID(postRewriteSqlNodeImageID);
responseBody.setPostRewriteRelNodeImageID(postRewriteRelNodeImageID);

return ResponseEntity.status(HttpStatus.OK).body(responseBody);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
*/
package com.linkedin.coral.coralservice.entity;

import com.linkedin.coral.coralservice.utils.RewriteType;


public class VisualizationRequestBody {
private String fromLanguage;
private String query;

private String rewriteType;
private RewriteType rewriteType;

public String getFromLanguage() {
return fromLanguage;
Expand All @@ -19,24 +22,7 @@ public String getQuery() {
return query;
}

public String getRewriteType() {
public RewriteType getRewriteType() {
return rewriteType;
}

public enum RewriteType {
NONE("none"),
INCREMENTAL("incremental"),
DATAMASKING("datamasking");

private final String type;

RewriteType(String description) {
this.type = description;
}

@Override
public String toString() {
return type;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
public class VisualizationResponseBody {

private UUID sqlNodeImageID;

private UUID relNodeImageID;

private UUID postRewriteSqlNodeImageID;
private UUID postRewriteRelNodeImageID;
public VisualizationResponseBody() {
}

Expand All @@ -25,6 +25,14 @@ public UUID getRelNodeImageID() {
return relNodeImageID;
}

public UUID getPostRewriteSqlNodeImageID() {
return postRewriteSqlNodeImageID;
}

public void setPostRewriteSqlNodeImageID(UUID postRewriteSqlNodeImageID) {
this.postRewriteSqlNodeImageID = postRewriteSqlNodeImageID;
}

public void setSqlNodeImageID(UUID sqlNodeImageID) {
this.sqlNodeImageID = sqlNodeImageID;
}
Expand All @@ -33,4 +41,12 @@ public void setRelNodeImageID(UUID relNodeImageID) {
this.relNodeImageID = relNodeImageID;
}

public UUID getPostRewriteRelNodeImageID() {
return postRewriteRelNodeImageID;
}

public void setPostRewriteRelNodeImageID(UUID postRewriteRelNodeImageID) {
this.postRewriteRelNodeImageID = postRewriteRelNodeImageID;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.coralservice.utils;

import com.fasterxml.jackson.annotation.JsonCreator;


public enum RewriteType {
NONE("none"),
INCREMENTAL("incremental"),
DATAMASKING("datamasking");

private final String type;

RewriteType(String type) {
this.type = type;
}

@Override
public String toString() {
return type;
}

@JsonCreator
public static RewriteType getDepartmentFromCode(String value) {
for (RewriteType type : RewriteType.values()) {
if (type.toString().equals(value)) {
return type;
}
}
return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.apache.calcite.sql.SqlNode;

import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.incremental.RelNodeIncrementalTransformer;
import com.linkedin.coral.transformers.CoralRelToSqlNodeConverter;
import com.linkedin.coral.trino.trino2rel.TrinoToRelConverter;
import com.linkedin.coral.vis.VisualizationUtil;

Expand All @@ -24,14 +26,23 @@ public static File createImageDir() {
return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID());
}

public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir) {
public static RelNode incrementalRewrittenRelNode = null;

public static UUID generateSqlNodeVisualization(String query, String fromLanguage, File imageDir,
RewriteType rewriteType) {

SqlNode sqlNode = null;
if (fromLanguage.equalsIgnoreCase("trino")) {
sqlNode = new TrinoToRelConverter(hiveMetastoreClient).toSqlNode(query);
} else if (fromLanguage.equalsIgnoreCase("hive")) {
sqlNode = new HiveToRelConverter(hiveMetastoreClient).toSqlNode(query);
}

if (incrementalRewrittenRelNode != null && rewriteType == RewriteType.INCREMENTAL) {
// We want to instead generate the visualization of SqlNode2 of the RHS of Coral's translation
sqlNode = new CoralRelToSqlNodeConverter().convert(incrementalRewrittenRelNode);
}

assert sqlNode != null;
VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);

Expand All @@ -41,14 +52,20 @@ public static UUID generateSqlNodeVisualization(String query, String fromLanguag
return sqlNodeId;
}

public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir) {
public static UUID generateRelNodeVisualization(String query, String fromLanguage, File imageDir,
RewriteType rewriteType) {
RelNode relNode = null;
if (fromLanguage.equalsIgnoreCase("trino")) {
relNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query);
} else if (fromLanguage.equalsIgnoreCase("hive")) {
relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
}

if (rewriteType == RewriteType.INCREMENTAL) {
relNode = RelNodeIncrementalTransformer.convertRelIncremental(relNode);
incrementalRewrittenRelNode = relNode;
}

assert relNode != null;
VisualizationUtil visualizationUtil = VisualizationUtil.create(imageDir);

Expand Down

0 comments on commit db22621

Please sign in to comment.