Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Coral-Service] Add rewriteType field to translation endpoint #455

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,10 @@ Please see the [Contribution Agreement](CONTRIBUTING.md).

#### /api/translations/translate
A **POST** API which takes JSON request body containing following parameters and returns the translated query:
- `fromLanguage`: Input dialect (e.g., spark, trino, hive -- see below for supported inputs)
- `toLanguage`: Output dialect (e.g., spark, trino, hive -- see below for supported outputs)
- `sourceLanguage`: Input dialect (e.g., spark, trino, hive -- see below for supported inputs)
- `targetLanguage`: Output dialect (e.g., spark, trino, hive -- see below for supported outputs)
- `query`: SQL query to translate between two dialects
- [Optional] `rewriteType`: Type of Coral IR rewrite (e.g., incremental)

#### /api/catalog-ops/execute
A **POST** API which takes a SQL statement to create a database/table/view in the local metastore
Expand Down Expand Up @@ -195,8 +196,8 @@ Creation successful
curl --header "Content-Type: application/json" \
--request POST \
--data '{
"fromLanguage":"hive",
"toLanguage":"trino",
"sourceLanguage":"hive",
"targetLanguage":"trino",
"query":"SELECT * FROM db1.airport"
}' \
http://localhost:8080/api/translations/translate
Expand All @@ -216,3 +217,4 @@ FROM "db1"."airport"
2. Hive to Spark
3. Trino to Spark
Note: During Trino to Spark translations, views referenced in queries are considered to be defined in HiveQL and hence cannot be used when translating a view from Trino. Currently, only referencing base tables is supported in Trino queries. This translation path is currently a POC and may need further improvements.
4. Spark to Trino
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import com.linkedin.coral.coralservice.entity.IncrementalRequestBody;
import com.linkedin.coral.coralservice.entity.IncrementalResponseBody;
import com.linkedin.coral.coralservice.entity.TranslateRequestBody;
import com.linkedin.coral.coralservice.utils.RewriteType;

import static com.linkedin.coral.coralservice.utils.CommonUtils.*;
import static com.linkedin.coral.coralservice.utils.CoralProvider.*;
import static com.linkedin.coral.coralservice.utils.IncrementalUtils.*;
import static com.linkedin.coral.coralservice.utils.TranslationUtils.*;
Expand Down Expand Up @@ -59,32 +61,35 @@ public ResponseEntity translate(@RequestBody TranslateRequestBody translateReque
final String sourceLanguage = translateRequestBody.getSourceLanguage();
final String targetLanguage = translateRequestBody.getTargetLanguage();
final String query = translateRequestBody.getQuery();
final RewriteType rewriteType = translateRequestBody.getRewriteType();

// TODO: Allow translations between the same language
if (sourceLanguage.equalsIgnoreCase(targetLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Please choose different languages to translate between.\n");
}

if (!isValidSourceLanguage(sourceLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Currently, only Hive, Trino and Spark are supported as source languages.\n");
}

String translatedSql = null;

try {
// TODO: add more translations once n-to-one-to-n is completed
// From Trino
if (sourceLanguage.equalsIgnoreCase("trino")) {
// To Spark
if (targetLanguage.equalsIgnoreCase("spark")) {
translatedSql = translateTrinoToSpark(query);
}
}
// From Hive
else if (sourceLanguage.equalsIgnoreCase("hive")) {
// To Spark
if (targetLanguage.equalsIgnoreCase("spark")) {
translatedSql = translateHiveToSpark(query);
}
// To Trino
else if (targetLanguage.equalsIgnoreCase("trino")) {
translatedSql = translateHiveToTrino(query);
if (rewriteType == null) {
// Invalid rewriteType values are deserialized as null
translatedSql = translateQuery(query, sourceLanguage, targetLanguage);
} else {
switch (rewriteType) {
case INCREMENTAL:
translatedSql = getIncrementalQuery(query, sourceLanguage, targetLanguage);
break;
case DATAMASKING:
case NONE:
default:
translatedSql = translateQuery(query, sourceLanguage, targetLanguage);
break;
}
}
} catch (Throwable t) {
Expand All @@ -110,14 +115,15 @@ public ResponseEntity getIncrementalInfo(@RequestBody IncrementalRequestBody inc
throws JSONException {
final String query = incrementalRequestBody.getQuery();
final List<String> tableNames = incrementalRequestBody.getTableNames();
final String language = incrementalRequestBody.getLanguage();
final String language = incrementalRequestBody.getLanguage(); // source language

// Response will contain incremental query and incremental table names
IncrementalResponseBody incrementalResponseBody = new IncrementalResponseBody();
incrementalResponseBody.setIncrementalQuery(null);
try {
if (language.equalsIgnoreCase("spark")) {
String incrementalQuery = getSparkIncrementalQueryFromUserSql(query);
// TODO: rename language to sourceLanguage and add a targetLanguage field IncrementalRequestBody to use here
String incrementalQuery = getIncrementalQuery(query, language, "spark");
for (String tableName : tableNames) {
/* Generate underscore delimited and incremental table names
Table name: db.t1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.linkedin.coral.coralservice.utils.RewriteType;
import com.linkedin.coral.coralservice.utils.VisualizationUtils;

import static com.linkedin.coral.coralservice.utils.CommonUtils.*;
import static com.linkedin.coral.coralservice.utils.VisualizationUtils.*;


Expand All @@ -43,9 +44,9 @@ public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody
final String query = visualizationRequestBody.getQuery();
final RewriteType rewriteType = visualizationRequestBody.getRewriteType();

if (!visualizationUtils.isValidSourceLanguage(sourceLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Currently, only Hive, Spark, and Trino are supported as engines to generate graphs using.\n");
if (!isValidSourceLanguage(sourceLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(
"Currently, only Hive, Spark, and Trino SQL are supported as source languages for Coral IR visualization. \n");
}

// A list of UUIDs in this order of:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
*/
package com.linkedin.coral.coralservice.entity;

import com.linkedin.coral.coralservice.utils.RewriteType;


public class TranslateRequestBody {
private String sourceLanguage;
private String targetLanguage;
private String query;

private RewriteType rewriteType;

public String getSourceLanguage() {
return sourceLanguage;
}
Expand All @@ -21,4 +26,8 @@ public String getTargetLanguage() {
public String getQuery() {
return query;
}

public RewriteType getRewriteType() {
return rewriteType;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.coralservice.utils;

public class CommonUtils {

  // Static-utility holder; never instantiated.
  private CommonUtils() {
  }

  /**
   * Checks whether the given dialect is accepted as a translation/visualization source.
   *
   * @param sourceLanguage dialect name taken from the request body; may be {@code null}
   *                       when the caller omitted the field
   * @return {@code true} iff {@code sourceLanguage} is "trino", "hive" or "spark",
   *         compared case-insensitively; {@code false} otherwise (including {@code null})
   */
  public static boolean isValidSourceLanguage(String sourceLanguage) {
    // Literals are the receivers so a null request field yields false instead of an NPE.
    return "trino".equalsIgnoreCase(sourceLanguage) || "hive".equalsIgnoreCase(sourceLanguage)
        || "spark".equalsIgnoreCase(sourceLanguage);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,36 @@
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.incremental.RelNodeIncrementalTransformer;
import com.linkedin.coral.spark.CoralSpark;
import com.linkedin.coral.trino.rel2trino.RelToTrinoConverter;
import com.linkedin.coral.trino.trino2rel.TrinoToRelConverter;

import static com.linkedin.coral.coralservice.utils.CoralProvider.*;


public class IncrementalUtils {
public static String getIncrementalQuery(String query, String sourceLanguage, String targetLanguage) {
RelNode originalNode;

switch (sourceLanguage.toLowerCase()) {
case "trino":
originalNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query);
break;
case "hive":
default:
originalNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
break;
}

public static String getSparkIncrementalQueryFromUserSql(String query) {
RelNode originalNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(originalNode);
CoralSpark coralSpark = CoralSpark.create(incrementalRelNode, hiveMetastoreClient);
return coralSpark.getSparkSql();

switch (targetLanguage.toLowerCase()) {
case "trino":
default:
return new RelToTrinoConverter(hiveMetastoreClient).convert(incrementalRelNode);
case "spark":
CoralSpark coralSpark = CoralSpark.create(incrementalRelNode, hiveMetastoreClient);
return coralSpark.getSparkSql();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ public String toString() {
}

@JsonCreator
public static RewriteType getDepartmentFromCode(String value) {
public static RewriteType getRewriteTypeFromCode(String value) {
for (RewriteType type : RewriteType.values()) {
if (type.toString().equals(value)) {
return type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,30 @@ public static String translateHiveToSpark(String query) {
CoralSpark coralSpark = CoralSpark.create(relNode, hiveMetastoreClient);
return coralSpark.getSparkSql();
}

/**
 * Translates a SQL query between dialects without applying any Coral IR rewrite.
 *
 * @param query SQL text to translate
 * @param sourceLanguage input dialect: trino, hive or spark (case-insensitive)
 * @param targetLanguage output dialect: spark or trino (case-insensitive)
 * @return the translated SQL, or {@code null} when the dialect pair is unsupported
 */
public static String translateQuery(String query, String sourceLanguage, String targetLanguage) {
  // TODO: add more translations once n-to-one-to-n is completed
  if (sourceLanguage.equalsIgnoreCase("trino")) {
    // Trino input currently only translates to Spark.
    return targetLanguage.equalsIgnoreCase("spark") ? translateTrinoToSpark(query) : null;
  }

  if (sourceLanguage.equalsIgnoreCase("hive") || sourceLanguage.equalsIgnoreCase("spark")) {
    // Hive and Spark inputs share the same translation path.
    if (targetLanguage.equalsIgnoreCase("spark")) {
      return translateHiveToSpark(query);
    }
    if (targetLanguage.equalsIgnoreCase("trino")) {
      return translateHiveToTrino(query);
    }
  }

  // Unsupported source dialect or source/target combination.
  return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ public static File getImageDir() {
return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID());
}

public boolean isValidSourceLanguage(String sourceLanguage) {
return sourceLanguage.equalsIgnoreCase("trino") || sourceLanguage.equalsIgnoreCase("hive");
}

public ArrayList<UUID> generateIRVisualizations(String query, String sourceLanguage, File imageDir,
RewriteType rewriteType) {
ArrayList<UUID> imageIDList = new ArrayList<>();
Expand Down Expand Up @@ -89,10 +85,8 @@ private RelNode getRelNode(String query, String sourceLanguage) {
RelNode relNode = null;
if (sourceLanguage.equalsIgnoreCase("trino")) {
relNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query);
} else if (sourceLanguage.equalsIgnoreCase("hive")) {
} else if (sourceLanguage.equalsIgnoreCase("hive") || sourceLanguage.equalsIgnoreCase("spark")) {
relNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
} else if (sourceLanguage.equalsIgnoreCase("spark")) {

}

return relNode;
Expand Down