Skip to content

Commit

Permalink
custom transformation implementation (#2040)
Browse files Browse the repository at this point in the history
* custom transformation implementation

* adding a test

* addressing comments
  • Loading branch information
shreyakhajanchi authored Dec 4, 2024
1 parent ca317e0 commit 09f68d2
Show file tree
Hide file tree
Showing 11 changed files with 548 additions and 36 deletions.
6 changes: 6 additions & 0 deletions v2/spanner-to-sourcedb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@
<artifactId>beam-it-jdbc</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.cloud.teleport.v2</groupId>
<artifactId>spanner-custom-shard</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<!-- TODO - Remove when https://github.com/apache/beam/pull/29732 is released. -->
<dependency>
<groupId>com.google.cloud.teleport</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import com.google.cloud.teleport.v2.spanner.migrations.shard.Shard;
import com.google.cloud.teleport.v2.spanner.migrations.spanner.SpannerSchema;
import com.google.cloud.teleport.v2.spanner.migrations.transformation.CustomTransformation;
import com.google.cloud.teleport.v2.spanner.migrations.utils.SecretManagerAccessorImpl;
import com.google.cloud.teleport.v2.spanner.migrations.utils.SessionFileReader;
import com.google.cloud.teleport.v2.spanner.migrations.utils.ShardFileReader;
Expand Down Expand Up @@ -367,6 +368,53 @@ public interface Options extends PipelineOptions, StreamingOptions {
String getSourceType();

void setSourceType(String value);

// --- Custom transformation options (reverse replication) ---
// GCS path of the jar containing the user's ISpannerMigrationTransformer implementation.
@TemplateParameter.GcsReadFile(
order = 25,
optional = true,
description = "Custom transformation jar location in Cloud Storage",
helpText =
"Custom jar location in Cloud Storage that contains the custom transformation logic for processing records"
+ " in reverse replication.")
@Default.String("")
String getTransformationJarPath();

void setTransformationJarPath(String value);

// Fully qualified class name inside the jar above; required whenever
// transformationJarPath is non-empty.
@TemplateParameter.Text(
order = 26,
optional = true,
description = "Custom class name for transformation",
helpText =
"Fully qualified class name having the custom transformation logic. It is a"
+ " mandatory field in case transformationJarPath is specified")
@Default.String("")
String getTransformationClassName();

void setTransformationClassName(String value);

// Opaque parameter string forwarded verbatim to the custom transformation class.
@TemplateParameter.Text(
order = 27,
optional = true,
description = "Custom parameters for transformation",
helpText =
"String containing any custom parameters to be passed to the custom transformation class.")
@Default.String("")
String getTransformationCustomParameters();

void setTransformationCustomParameters(String value);

// Directory (under the DLQ/skip root) receiving records that the custom
// transformation filtered out of reverse replication.
// NOTE: the default is "filteredEvents" (see @Default.String below); the previous
// helpText incorrectly said "skip", copied from the skipDirectoryName parameter.
@TemplateParameter.Text(
order = 28,
optional = true,
description = "Directory name for holding filtered records",
helpText =
"Records skipped from reverse replication are written to this directory. Default"
+ " directory name is filteredEvents.")
@Default.String("filteredEvents")
String getFilterEventsDirectoryName();

void setFilterEventsDirectoryName(String value);
}

/**
Expand Down Expand Up @@ -541,6 +589,11 @@ public static PipelineResult run(Options options) {
} else {
mergedRecords = dlqRecords;
}
CustomTransformation customTransformation =
CustomTransformation.builder(
options.getTransformationJarPath(), options.getTransformationClassName())
.setCustomParameters(options.getTransformationCustomParameters())
.build();
SourceWriterTransform.Result sourceWriterOutput =
mergedRecords
.apply(
Expand Down Expand Up @@ -578,7 +631,8 @@ public static PipelineResult run(Options options) {
options.getShadowTablePrefix(),
options.getSkipDirectoryName(),
connectionPoolSizePerWorker,
options.getSourceType()));
options.getSourceType(),
customTransformation));

PCollection<FailsafeElement<String, String>> dlqPermErrorRecords =
reconsumedElements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ public class Constants {
// The Tag for skipped records
public static final TupleTag<String> SKIPPED_TAG = new TupleTag<String>() {};

// The Tag for records filtered via custom transformation.
public static final TupleTag<String> FILTERED_TAG = new TupleTag<String>() {};

// Message written to the file for skipped records
public static final String SKIPPED_TAG_MESSAGE = "Skipped record from reverse replication";

Expand All @@ -72,4 +75,8 @@ public class Constants {
// Shard id used when the migration is not sharded (single-database source).
public static final String DEFAULT_SHARD_ID = "single_shard";

// Supported source database type identifier.
public static final String SOURCE_MYSQL = "mysql";

// Message written to the file for filtered records
public static final String FILTERED_TAG_MESSAGE =
"Filtered record from custom transformation in reverse replication";
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
dmlGeneratorRequest.getSourceDbTimezoneOffset());
dmlGeneratorRequest.getSourceDbTimezoneOffset(),
dmlGeneratorRequest.getCustomTransformationResponse());
if (pkcolumnNameValues == null) {
LOG.warn(
"Cannot reverse replicate for table {} without primary key, skipping the record",
Expand Down Expand Up @@ -194,7 +195,8 @@ private static DMLGeneratorResponse generateUpsertStatement(
sourceTable,
dmlGeneratorRequest.getNewValuesJson(),
dmlGeneratorRequest.getKeyValuesJson(),
dmlGeneratorRequest.getSourceDbTimezoneOffset());
dmlGeneratorRequest.getSourceDbTimezoneOffset(),
dmlGeneratorRequest.getCustomTransformationResponse());
return getUpsertStatement(
sourceTable.getName(),
sourceTable.getPrimaryKeySet(),
Expand All @@ -207,7 +209,8 @@ private static Map<String, String> getColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
String sourceDbTimezoneOffset) {
String sourceDbTimezoneOffset,
Map<String, Object> customTransformationResponse) {
Map<String, String> response = new HashMap<>();

/*
Expand All @@ -224,13 +227,21 @@ private static Map<String, String> getColumnValues(
as the column will be stored with default/null values
*/
Set<String> sourcePKs = sourceTable.getPrimaryKeySet();
Set<String> customTransformColumns = null;
if (customTransformationResponse != null) {
customTransformColumns = customTransformationResponse.keySet();
}
for (Map.Entry<String, SourceColumnDefinition> entry : sourceTable.getColDefs().entrySet()) {
SourceColumnDefinition sourceColDef = entry.getValue();

String colName = sourceColDef.getName();
if (sourcePKs.contains(colName)) {
continue; // we only need non-primary keys
}
if (customTransformColumns != null && customTransformColumns.contains(colName)) {
response.put(colName, customTransformationResponse.get(colName).toString());
continue;
}

String colId = entry.getKey();
SpannerColumnDefinition spannerColDef = spannerTable.getColDefs().get(colId);
Expand Down Expand Up @@ -272,7 +283,8 @@ private static Map<String, String> getPkColumnValues(
SourceTable sourceTable,
JSONObject newValuesJson,
JSONObject keyValuesJson,
String sourceDbTimezoneOffset) {
String sourceDbTimezoneOffset,
Map<String, Object> customTransformationResponse) {
Map<String, String> response = new HashMap<>();
/*
Get all primary key col ids from source table
Expand All @@ -286,6 +298,10 @@ private static Map<String, String> getPkColumnValues(
if the column does not exist in any of the JSON - return null
*/
ColumnPK[] sourcePKs = sourceTable.getPrimaryKeys();
Set<String> customTransformColumns = null;
if (customTransformationResponse != null) {
customTransformColumns = customTransformationResponse.keySet();
}

for (int i = 0; i < sourcePKs.length; i++) {
ColumnPK currentSourcePK = sourcePKs[i];
Expand All @@ -298,6 +314,13 @@ private static Map<String, String> getPkColumnValues(
sourceColDef.getName());
return null;
}
if (customTransformColumns != null
&& customTransformColumns.contains(sourceColDef.getName())) {
response.put(
sourceColDef.getName(),
customTransformationResponse.get(sourceColDef.getName()).toString());
continue;
}
String spannerColumnName = spannerColDef.getName();
String columnValue = "";
if (keyValuesJson.has(spannerColumnName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@
*/
package com.google.cloud.teleport.v2.templates.dbutils.processor;

import com.google.cloud.teleport.v2.spanner.exceptions.InvalidTransformationException;
import com.google.cloud.teleport.v2.spanner.migrations.convertors.ChangeEventToMapConvertor;
import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import com.google.cloud.teleport.v2.spanner.utils.ISpannerMigrationTransformer;
import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationRequest;
import com.google.cloud.teleport.v2.spanner.utils.MigrationTransformationResponse;
import com.google.cloud.teleport.v2.templates.changestream.TrimmedShardedDataChangeRecord;
import com.google.cloud.teleport.v2.templates.dbutils.dao.source.IDao;
import com.google.cloud.teleport.v2.templates.dbutils.dml.IDMLGenerator;
import com.google.cloud.teleport.v2.templates.models.DMLGeneratorRequest;
import com.google.cloud.teleport.v2.templates.models.DMLGeneratorResponse;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Map;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.joda.time.Duration;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -35,14 +42,18 @@
public class InputRecordProcessor {

private static final Logger LOG = LoggerFactory.getLogger(InputRecordProcessor.class);
private static final Distribution applyCustomTransformationResponseTimeMetric =
Metrics.distribution(
InputRecordProcessor.class, "apply_custom_transformation_impl_latency_ms");

public static void processRecord(
public static boolean processRecord(
TrimmedShardedDataChangeRecord spannerRecord,
Schema schema,
IDao dao,
String shardId,
String sourceDbTimezoneOffset,
IDMLGenerator dmlGenerator)
IDMLGenerator dmlGenerator,
ISpannerMigrationTransformer spannerToSourceTransformer)
throws Exception {

try {
Expand All @@ -53,17 +64,43 @@ public static void processRecord(
String newValueJsonStr = spannerRecord.getMod().getNewValuesJson();
JSONObject newValuesJson = new JSONObject(newValueJsonStr);
JSONObject keysJson = new JSONObject(keysJsonStr);
Map<String, Object> customTransformationResponse = null;

if (spannerToSourceTransformer != null) {
org.joda.time.Instant startTimestamp = org.joda.time.Instant.now();
Map<String, Object> mapRequest =
ChangeEventToMapConvertor.combineJsonObjects(keysJson, newValuesJson);
MigrationTransformationRequest migrationTransformationRequest =
new MigrationTransformationRequest(tableName, mapRequest, shardId, modType);
MigrationTransformationResponse migrationTransformationResponse = null;
try {
migrationTransformationResponse =
spannerToSourceTransformer.toSourceRow(migrationTransformationRequest);
} catch (Exception e) {
throw new InvalidTransformationException(e);
}
org.joda.time.Instant endTimestamp = org.joda.time.Instant.now();
applyCustomTransformationResponseTimeMetric.update(
new Duration(startTimestamp, endTimestamp).getMillis());
if (migrationTransformationResponse.isEventFiltered()) {
Metrics.counter(InputRecordProcessor.class, "filtered_events_" + shardId).inc();
return true;
}
if (migrationTransformationResponse != null) {
customTransformationResponse = migrationTransformationResponse.getResponseRow();
}
}
DMLGeneratorRequest dmlGeneratorRequest =
new DMLGeneratorRequest.Builder(
modType, tableName, newValuesJson, keysJson, sourceDbTimezoneOffset)
.setSchema(schema)
.setCustomTransformationResponse(customTransformationResponse)
.build();

DMLGeneratorResponse dmlGeneratorResponse = dmlGenerator.getDMLStatement(dmlGeneratorRequest);
if (dmlGeneratorResponse.getDmlStatement().isEmpty()) {
LOG.warn("DML statement is empty for table: " + tableName);
return;
return false;
}
dao.write(dmlGeneratorResponse.getDmlStatement());

Expand All @@ -79,7 +116,7 @@ public static void processRecord(
long replicationLag = ChronoUnit.SECONDS.between(commitTsInst, instTime);

lagMetric.update(replicationLag); // update the lag metric

return false;
} catch (Exception e) {
LOG.error(
"The exception while processing shardId: {} is {} ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.google.cloud.teleport.v2.templates.models;

import com.google.cloud.teleport.v2.spanner.migrations.schema.Schema;
import java.util.Map;
import org.json.JSONObject;

/**
Expand Down Expand Up @@ -51,13 +52,16 @@ public class DMLGeneratorRequest {
// The timezone offset of the source database, used for handling timezone-specific data.
private final String sourceDbTimezoneOffset;

private Map<String, Object> customTransformationResponse;

/**
 * Copies every configured value from the {@code Builder}; instances are created only via
 * {@code Builder.build()}. {@code customTransformationResponse} may be null when no custom
 * transformation is configured.
 */
public DMLGeneratorRequest(Builder builder) {
this.modType = builder.modType;
this.spannerTableName = builder.spannerTableName;
this.schema = builder.schema;
this.newValuesJson = builder.newValuesJson;
this.keyValuesJson = builder.keyValuesJson;
this.sourceDbTimezoneOffset = builder.sourceDbTimezoneOffset;
this.customTransformationResponse = builder.customTransformationResponse;
}

public String getModType() {
Expand All @@ -84,13 +88,18 @@ public String getSourceDbTimezoneOffset() {
return sourceDbTimezoneOffset;
}

/**
 * Returns the column-name to value overrides produced by the custom transformation, or null
 * when no custom transformation response was set.
 */
public Map<String, Object> getCustomTransformationResponse() {
return customTransformationResponse;
}

public static class Builder {
private final String modType;
private final String spannerTableName;
private final JSONObject newValuesJson;
private final JSONObject keyValuesJson;
private final String sourceDbTimezoneOffset;
private Schema schema;
private Map<String, Object> customTransformationResponse;

public Builder(
String modType,
Expand All @@ -110,6 +119,12 @@ public Builder setSchema(Schema schema) {
return this;
}

/** Optionally attaches the custom-transformation column overrides; returns this builder. */
public Builder setCustomTransformationResponse(
Map<String, Object> customTransformationResponse) {
this.customTransformationResponse = customTransformationResponse;
return this;
}

public DMLGeneratorRequest build() {
return new DMLGeneratorRequest(this);
}
Expand Down
Loading

0 comments on commit 09f68d2

Please sign in to comment.