
[Coral-Service] Add rewriteType field to translation endpoint #455

Merged
Changes from 2 commits
README.md (5 additions, 4 deletions)

@@ -98,9 +98,10 @@ Please see the [Contribution Agreement](CONTRIBUTING.md).

#### /api/translations/translate
A **POST** API which takes JSON request body containing following parameters and returns the translated query:
- - `fromLanguage`: Input dialect (e.g., spark, trino, hive -- see below for supported inputs)
- - `toLanguage`: Output dialect (e.g., spark, trino, hive -- see below for supported outputs)
+ - `sourceLanguage`: Input dialect (e.g., spark, trino, hive -- see below for supported inputs)
+ - `targetLanguage`: Output dialect (e.g., spark, trino, hive -- see below for supported outputs)
  - `query`: SQL query to translate between two dialects
+ - [Optional] `rewriteType`: Type of Coral IR rewrite (e.g., incremental)
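An illustrative request body using the new optional field (the values mirror the curl example later in this README; `incremental` follows the `e.g.` above):

```json
{
  "sourceLanguage": "hive",
  "targetLanguage": "trino",
  "query": "SELECT * FROM db1.airport",
  "rewriteType": "incremental"
}
```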

#### /api/catalog-ops/execute
A **POST** API which takes a SQL statement to create a database/table/view in the local metastore
@@ -195,8 +196,8 @@ Creation successful
curl --header "Content-Type: application/json" \
--request POST \
--data '{
"fromLanguage":"hive",
"toLanguage":"trino",
"sourceLanguage":"hive",
"targetLanguage":"trino",
"query":"SELECT * FROM db1.airport"
}' \
http://localhost:8080/api/translations/translate
@@ -25,7 +25,9 @@
import com.linkedin.coral.coralservice.entity.IncrementalRequestBody;
import com.linkedin.coral.coralservice.entity.IncrementalResponseBody;
import com.linkedin.coral.coralservice.entity.TranslateRequestBody;
import com.linkedin.coral.coralservice.utils.RewriteType;

import static com.linkedin.coral.coralservice.utils.CommonUtils.*;
import static com.linkedin.coral.coralservice.utils.CoralProvider.*;
import static com.linkedin.coral.coralservice.utils.IncrementalUtils.*;
import static com.linkedin.coral.coralservice.utils.TranslationUtils.*;
@@ -59,32 +61,35 @@ public ResponseEntity translate(@RequestBody TranslateRequestBody translateRequestBody)
final String sourceLanguage = translateRequestBody.getSourceLanguage();
final String targetLanguage = translateRequestBody.getTargetLanguage();
final String query = translateRequestBody.getQuery();
final RewriteType rewriteType = translateRequestBody.getRewriteType();

// TODO: Allow translations between the same language
if (sourceLanguage.equalsIgnoreCase(targetLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Please choose different languages to translate between.\n");
}

if (!isValidSourceLanguage(sourceLanguage)) {
return ResponseEntity.status(HttpStatus.BAD_REQUEST)
.body("Currently, only Hive, and Trino are supported as source/input languages.\n");
> **Contributor:** no need for /input.
}

String translatedSql = null;

try {
-      // TODO: add more translations once n-to-one-to-n is completed
-      // From Trino
-      if (sourceLanguage.equalsIgnoreCase("trino")) {
-        // To Spark
-        if (targetLanguage.equalsIgnoreCase("spark")) {
-          translatedSql = translateTrinoToSpark(query);
-        }
-      }
-      // From Hive
-      else if (sourceLanguage.equalsIgnoreCase("hive")) {
-        // To Spark
-        if (targetLanguage.equalsIgnoreCase("spark")) {
-          translatedSql = translateHiveToSpark(query);
-        }
-        // To Trino
-        else if (targetLanguage.equalsIgnoreCase("trino")) {
-          translatedSql = translateHiveToTrino(query);
-        }
-      }
+      if (rewriteType == null) {
+        // Invalid rewriteType values are deserialized as null
+        translatedSql = translateQuery(query, sourceLanguage, targetLanguage);
+      } else {
+        switch (rewriteType) {
+          case INCREMENTAL:
+            translatedSql = getIncrementalQuery(query, sourceLanguage, targetLanguage);
+            break;
+          case DATAMASKING:
+          case NONE:
+          default:
+            translatedSql = translateQuery(query, sourceLanguage, targetLanguage);
+            break;
+        }
+      }
} catch (Throwable t) {
@@ -117,6 +122,7 @@ public ResponseEntity getIncrementalInfo(@RequestBody IncrementalRequestBody incrementalRequestBody)
incrementalResponseBody.setIncrementalQuery(null);
try {
if (language.equalsIgnoreCase("spark")) {
// TODO: replace getSparkIncrementalQueryFromUserSql with getIncrementalQuery
String incrementalQuery = getSparkIncrementalQueryFromUserSql(query);
> **Contributor:** How does this work when the source/target specified by the user is not Spark?

> **Contributor Author:** `getIncrementalQuery(String query, String sourceLanguage, String targetLanguage)` takes in a `sourceLanguage` param.

> **Contributor:** Still not quite clear. Why are we not using `getIncrementalQuery` here and leaving it as a TODO?

> **Contributor Author:** In order to use `getIncrementalQuery`, we would need `/api/incremental/rewrite` to take in a `targetLanguage` field in the request. Since changes to this endpoint could break the existing DBT integration, it might be worth making this its own separate effort, especially since this is a more complicated endpoint to test.

> **Contributor:** Can we hardcode "spark" and get rid of `getSparkIncrementalQueryFromUserSql`? The TODO will look different as well.

for (String tableName : tableNames) {
/* Generate underscore delimited and incremental table names
@@ -28,6 +28,7 @@
import com.linkedin.coral.coralservice.utils.RewriteType;
import com.linkedin.coral.coralservice.utils.VisualizationUtils;

import static com.linkedin.coral.coralservice.utils.CommonUtils.*;
import static com.linkedin.coral.coralservice.utils.VisualizationUtils.*;


@@ -43,9 +44,9 @@ public ResponseEntity getIRVisualizations(@RequestBody VisualizationRequestBody visualizationRequestBody)
final String query = visualizationRequestBody.getQuery();
final RewriteType rewriteType = visualizationRequestBody.getRewriteType();

-    if (!visualizationUtils.isValidSourceLanguage(sourceLanguage)) {
+    if (!isValidSourceLanguage(sourceLanguage)) {
       return ResponseEntity.status(HttpStatus.BAD_REQUEST)
-          .body("Currently, only Hive, Spark, and Trino are supported as engines to generate graphs using.\n");
+          .body("Currently, only Hive, and Trino are supported as engines to generate graphs using.\n");
> **Contributor:** Let us allow Spark and route it to Hive. Also you can say: "Currently, only Hive, Spark, and Trino SQL are supported as source languages for Coral IR visualization."

> **Contributor Author:** Do we also do this for the translation endpoint?
}

// A list of UUIDs in this order of:
TranslateRequestBody.java

@@ -5,11 +5,16 @@
*/
package com.linkedin.coral.coralservice.entity;

import com.linkedin.coral.coralservice.utils.RewriteType;


public class TranslateRequestBody {
private String sourceLanguage;
private String targetLanguage;
private String query;

private RewriteType rewriteType;

public String getSourceLanguage() {
return sourceLanguage;
}
@@ -21,4 +26,8 @@ public String getTargetLanguage() {
public String getQuery() {
return query;
}

public RewriteType getRewriteType() {
return rewriteType;
}
}
CommonUtils.java (new file)
@@ -0,0 +1,13 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.coralservice.utils;

public class CommonUtils {

public static boolean isValidSourceLanguage(String sourceLanguage) {
return sourceLanguage.equalsIgnoreCase("trino") || sourceLanguage.equalsIgnoreCase("hive");
}
}
IncrementalUtils.java

@@ -10,6 +10,8 @@
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.incremental.RelNodeIncrementalTransformer;
import com.linkedin.coral.spark.CoralSpark;
import com.linkedin.coral.trino.rel2trino.RelToTrinoConverter;
import com.linkedin.coral.trino.trino2rel.TrinoToRelConverter;

import static com.linkedin.coral.coralservice.utils.CoralProvider.*;

@@ -23,4 +25,29 @@ public static String getSparkIncrementalQueryFromUserSql(String query) {
return coralSpark.getSparkSql();
}

public static String getIncrementalQuery(String query, String sourceLanguage, String targetLanguage) {
RelNode originalNode;

switch (sourceLanguage.toLowerCase()) {
case "trino":
originalNode = new TrinoToRelConverter(hiveMetastoreClient).convertSql(query);
break;
case "hive":
default:
originalNode = new HiveToRelConverter(hiveMetastoreClient).convertSql(query);
break;
}

RelNode incrementalRelNode = RelNodeIncrementalTransformer.convertRelIncremental(originalNode);

switch (targetLanguage.toLowerCase()) {
case "trino":
default:
return new RelToTrinoConverter(hiveMetastoreClient).convert(incrementalRelNode);
case "spark":
CoralSpark coralSpark = CoralSpark.create(incrementalRelNode, hiveMetastoreClient);
return coralSpark.getSparkSql();
}
}

}
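The two-switch structure in `getIncrementalQuery` above (parse the source dialect into Coral IR, apply the incremental transformation, then emit the target dialect) can be sketched independently of Coral. The class and method names below are illustrative stand-ins, not Coral APIs, and plain strings stand in for `RelNode`:

```java
// Illustrative sketch of the source -> IR -> target dispatch pattern used by
// getIncrementalQuery. Strings stand in for RelNode and the Coral converter
// classes; none of these names are real Coral APIs.
public class IncrementalDispatchSketch {

  // Parse the query into a (stand-in) intermediate representation.
  public static String toIr(String query, String sourceLanguage) {
    switch (sourceLanguage.toLowerCase()) {
      case "trino":
        return "IR[trino](" + query + ")";
      case "hive":
      default: // Hive is the fallback parser, as in getIncrementalQuery
        return "IR[hive](" + query + ")";
    }
  }

  // Emit the (stand-in) IR in the requested target dialect.
  public static String fromIr(String ir, String targetLanguage) {
    switch (targetLanguage.toLowerCase()) {
      case "spark":
        return "spark:" + ir;
      case "trino":
      default: // Trino is the fallback emitter, as in getIncrementalQuery
        return "trino:" + ir;
    }
  }

  public static void main(String[] args) {
    String ir = toIr("SELECT * FROM db1.airport", "hive");
    System.out.println(fromIr(ir, "spark")); // spark:IR[hive](SELECT * FROM db1.airport)
  }
}
```

Making the `default` branches explicit mirrors the PR's choice: an unrecognized source falls back to the Hive parser, and an unrecognized target falls back to the Trino emitter.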
RewriteType.java

@@ -25,7 +25,7 @@ public String toString() {
}

@JsonCreator
-  public static RewriteType getDepartmentFromCode(String value) {
+  public static RewriteType getRewriteTypeFromCode(String value) {
for (RewriteType type : RewriteType.values()) {
if (type.toString().equals(value)) {
return type;
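The lookup-or-null behavior the controller relies on ("invalid rewriteType values are deserialized as null") can be reproduced in a standalone sketch. This mirrors the `getRewriteTypeFromCode` loop above; the enum constants and codes are assumptions modeled on the controller's `switch`, not the actual Coral class:

```java
// Standalone sketch of the RewriteType code lookup: a linear scan over the
// enum values that returns null for unrecognized codes, which is why the
// controller treats a null rewriteType as "no rewrite requested".
// Not the actual Coral enum; constants mirror those used in the controller.
public class RewriteTypeLookupSketch {

  public enum RewriteType {
    NONE("none"),
    INCREMENTAL("incremental"),
    DATAMASKING("datamasking");

    private final String code;

    RewriteType(String code) {
      this.code = code;
    }

    @Override
    public String toString() {
      return code;
    }

    // Mirrors getRewriteTypeFromCode: returns null when no constant matches.
    public static RewriteType fromCode(String value) {
      for (RewriteType type : RewriteType.values()) {
        if (type.toString().equals(value)) {
          return type;
        }
      }
      return null;
    }
  }

  public static void main(String[] args) {
    System.out.println(RewriteType.fromCode("incremental")); // prints "incremental"
    System.out.println(RewriteType.fromCode("bogus"));       // prints "null"
  }
}
```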
TranslationUtils.java

@@ -33,4 +33,30 @@ public static String translateHiveToSpark(String query) {
CoralSpark coralSpark = CoralSpark.create(relNode, hiveMetastoreClient);
return coralSpark.getSparkSql();
}

public static String translateQuery(String query, String sourceLanguage, String targetLanguage) {
String translatedSql = null;

// TODO: add more translations once n-to-one-to-n is completed
// From Trino
if (sourceLanguage.equalsIgnoreCase("trino")) {
// To Spark
if (targetLanguage.equalsIgnoreCase("spark")) {
translatedSql = translateTrinoToSpark(query);
}
}
// From Hive
else if (sourceLanguage.equalsIgnoreCase("hive")) {
// To Spark
if (targetLanguage.equalsIgnoreCase("spark")) {
translatedSql = translateHiveToSpark(query);
}
// To Trino
else if (targetLanguage.equalsIgnoreCase("trino")) {
translatedSql = translateHiveToTrino(query);
}
}

return translatedSql;
}
}
VisualizationUtils.java

@@ -27,10 +27,6 @@ public static File getImageDir() {
return new File(System.getProperty("java.io.tmpdir") + "/images" + UUID.randomUUID());
}

-  public boolean isValidSourceLanguage(String sourceLanguage) {
-    return sourceLanguage.equalsIgnoreCase("trino") || sourceLanguage.equalsIgnoreCase("hive");
-  }

public ArrayList<UUID> generateIRVisualizations(String query, String sourceLanguage, File imageDir,
RewriteType rewriteType) {
ArrayList<UUID> imageIDList = new ArrayList<>();