diff --git a/build.gradle b/build.gradle index 81c2902a..92778114 100644 --- a/build.gradle +++ b/build.gradle @@ -124,6 +124,7 @@ dependencies { compileOnly group: 'org.apache.flink', name: 'flink-json', version: flinkVersion compileOnly group: 'org.apache.flink', name: 'flink-avro', version: flinkVersion + compileOnly group: 'org.apache.flink', name: 'flink-protobuf', version: flinkVersion testImplementation (group: 'io.pravega', name: 'pravega-standalone', version: pravegaVersion) { exclude group: 'org.slf4j', module: 'slf4j-api' @@ -146,6 +147,7 @@ dependencies { testImplementation group: 'org.apache.flink', name: 'flink-table-planner_' + flinkScalaVersion, classifier: 'tests', version: flinkVersion testImplementation group: 'org.apache.flink', name: 'flink-json', version: flinkVersion testImplementation group: 'org.apache.flink', name: 'flink-avro', version: flinkVersion + testImplementation group: 'org.apache.flink', name: 'flink-protobuf', version: flinkVersion testImplementation group: 'org.hamcrest', name: 'hamcrest', version: hamcrestVersion testImplementation group: 'org.testcontainers', name: 'testcontainers', version: testcontainersVersion testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter', version: junit5Version diff --git a/src/main/java/io/pravega/connectors/flink/formats/registry/PravegaRegistryRowDataDeserializationSchema.java b/src/main/java/io/pravega/connectors/flink/formats/registry/PravegaRegistryRowDataDeserializationSchema.java index c0bb4db0..b0853859 100644 --- a/src/main/java/io/pravega/connectors/flink/formats/registry/PravegaRegistryRowDataDeserializationSchema.java +++ b/src/main/java/io/pravega/connectors/flink/formats/registry/PravegaRegistryRowDataDeserializationSchema.java @@ -18,6 +18,7 @@ import io.pravega.client.stream.Serializer; import io.pravega.connectors.flink.PravegaConfig; +import io.pravega.connectors.flink.util.MessageToRowConverter; import io.pravega.connectors.flink.util.SchemaRegistryUtils; import io.pravega.schemaregistry.client.SchemaRegistryClient; import io.pravega.schemaregistry.client.SchemaRegistryClientConfig; @@ -25,6 +26,7 @@ import io.pravega.schemaregistry.contract.data.SchemaInfo; import io.pravega.schemaregistry.contract.data.SerializationFormat; import io.pravega.schemaregistry.serializer.avro.schemas.AvroSchema; +import io.pravega.schemaregistry.serializer.protobuf.schemas.ProtobufSchema; import io.pravega.schemaregistry.serializer.shared.impl.AbstractDeserializer; import io.pravega.schemaregistry.serializer.shared.impl.EncodingCache; import io.pravega.schemaregistry.serializer.shared.impl.SerializerConfig; @@ -36,6 +38,7 @@ import org.apache.flink.formats.avro.typeutils.AvroSchemaConverter; import org.apache.flink.formats.common.TimestampFormat; import org.apache.flink.formats.json.JsonToRowDataConverters; +import org.apache.flink.formats.protobuf.PbFormatConfig.PbFormatConfigBuilder; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.DeserializationFeature; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; @@ -43,6 +46,7 @@ import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.RowType; import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; +import com.google.protobuf.GeneratedMessageV3; import javax.annotation.Nullable; import java.io.IOException; @@ -51,12 +55,17 @@ import java.util.Objects; import static org.apache.flink.util.Preconditions.checkNotNull; + /** - * Deserialization schema from Pravega Schema Registry to Flink Table/SQL internal data structure {@link RowData}. + * Deserialization schema from Pravega Schema Registry to Flink Table/SQL + * internal data structure {@link RowData}. * - *

Deserializes a byte[] message as a Pravega Schema Registry and reads the specified fields. + *

+ * Deserializes a byte[] message as a Pravega Schema Registry and + * reads the specified fields. * - *

Failures during deserialization are forwarded as wrapped IOExceptions. + *

+ * Failures during deserialization are forwarded as wrapped IOExceptions. */ public class PravegaRegistryRowDataDeserializationSchema implements DeserializationSchema { private static final long serialVersionUID = 1L; @@ -103,12 +112,25 @@ public class PravegaRegistryRowDataDeserializationSchema implements Deserializat /** Flag indicating whether to fail if a field is missing. */ private final boolean failOnMissingField; - /** Flag indicating whether to ignore invalid fields/rows (default: throw an exception). */ + /** + * Flag indicating whether to ignore invalid fields/rows (default: throw an + * exception). + */ private final boolean ignoreParseErrors; /** Timestamp format specification which is used to parse timestamp. */ private final TimestampFormat timestampFormat; + // -------------------------------------------------------------------------------------------- + // Protobuf fields + // -------------------------------------------------------------------------------------------- + + /** Protobuf serialization schema. */ + private transient ProtobufSchema pbSchema; + + /** Protobuf Message Class generated from static .proto file. */ + private GeneratedMessageV3 pbMessage; + public PravegaRegistryRowDataDeserializationSchema( RowType rowType, TypeInformation typeInfo, @@ -116,8 +138,7 @@ public PravegaRegistryRowDataDeserializationSchema( PravegaConfig pravegaConfig, boolean failOnMissingField, boolean ignoreParseErrors, - TimestampFormat timestampFormat - ) { + TimestampFormat timestampFormat) { if (ignoreParseErrors && failOnMissingField) { throw new IllegalArgumentException( "JSON format doesn't support failOnMissingField and ignoreParseErrors are both enabled."); @@ -135,8 +156,8 @@ public PravegaRegistryRowDataDeserializationSchema( @SuppressWarnings("unchecked") @Override public void open(InitializationContext context) throws Exception { - SchemaRegistryClientConfig schemaRegistryClientConfig = - SchemaRegistryUtils.getSchemaRegistryClientConfig(pravegaConfig); + SchemaRegistryClientConfig schemaRegistryClientConfig = SchemaRegistryUtils + .getSchemaRegistryClientConfig(pravegaConfig); SchemaRegistryClient schemaRegistryClient = SchemaRegistryClientFactory.withNamespace(namespace, schemaRegistryClientConfig); SerializerConfig config = SerializerConfig.builder() @@ -153,8 +174,7 @@ public void open(InitializationContext context) throws Exception { break; case Json: ObjectMapper objectMapper = new ObjectMapper(); - boolean hasDecimalType = - LogicalTypeChecks.hasNested(rowType, t -> t instanceof DecimalType); + boolean hasDecimalType = LogicalTypeChecks.hasNested(rowType, t -> t instanceof DecimalType); if (hasDecimalType) { objectMapper.enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS); } @@ -166,6 +186,10 @@ public void open(InitializationContext context) throws Exception { config.isWriteEncodingHeader(), objectMapper); break; + case Protobuf: + pbSchema = ProtobufSchema.of(pbMessage.getClass()); + deserializer = SerializerFactory.protobufDeserializer(config, pbSchema); + break; default: throw new NotImplementedException("Not supporting deserialization format"); } @@ -190,20 +214,26 @@ public Object deserializeToObject(byte[] message) { return deserializer.deserialize(ByteBuffer.wrap(message)); } - public RowData convertToRowData(Object message) { + public RowData convertToRowData(Object message) throws Exception { Object o; switch (serializationFormat) { case Avro: - AvroToRowDataConverters.AvroToRowDataConverter avroConverter = - AvroToRowDataConverters.createRowConverter(rowType); + AvroToRowDataConverters.AvroToRowDataConverter avroConverter = AvroToRowDataConverters + .createRowConverter(rowType); o = avroConverter.convert(message); break; case Json: - JsonToRowDataConverters.JsonToRowDataConverter jsonConverter = - new JsonToRowDataConverters(failOnMissingField, ignoreParseErrors, timestampFormat) - .createConverter(checkNotNull(rowType)); + JsonToRowDataConverters.JsonToRowDataConverter jsonConverter = new JsonToRowDataConverters( + failOnMissingField, ignoreParseErrors, timestampFormat) + .createConverter(checkNotNull(rowType)); o = jsonConverter.convert((JsonNode) message); break; + case Protobuf: + PbFormatConfigBuilder pbConfigBuilder = new PbFormatConfigBuilder() + .messageClassName(pbMessage.getClass().getName()); + MessageToRowConverter pbMessageConverter = new MessageToRowConverter(rowType, pbConfigBuilder.build()); + o = pbMessageConverter.convertMessageToRow(message); + break; default: throw new NotImplementedException("Not supporting deserialization format"); } @@ -214,16 +244,16 @@ private static class FlinkJsonGenericDeserializer extends AbstractDeserializerSerializes the input Flink object into GenericRecord and converts it into byte[]. + *

+ * Serializes the input Flink object into GenericRecord and converts it into + * byte[]. * - *

Result byte[] messages can be deserialized using {@link + *

+ * Result byte[] messages can be deserialized using {@link * PravegaRegistryRowDataDeserializationSchema}. */ public class PravegaRegistryRowDataSerializationSchema implements SerializationSchema { @@ -113,6 +124,16 @@ public class PravegaRegistryRowDataSerializationSchema implements SerializationS /** Flag indicating whether to serialize all decimals as plain numbers. */ private final boolean encodeDecimalAsPlainNumber; + // -------------------------------------------------------------------------------------------- + // Protobuf fields + // -------------------------------------------------------------------------------------------- + + /** Protobuf serialization schema. */ + private transient ProtobufSchema pbSchema; + + /** Protobuf Message Class generated from static .proto file. */ + private GeneratedMessageV3 pbMessage; + public PravegaRegistryRowDataSerializationSchema( RowType rowType, String groupId, @@ -137,8 +158,8 @@ public PravegaRegistryRowDataSerializationSchema( @SuppressWarnings("unchecked") @Override public void open(InitializationContext context) throws Exception { - SchemaRegistryClientConfig schemaRegistryClientConfig = - SchemaRegistryUtils.getSchemaRegistryClientConfig(pravegaConfig); + SchemaRegistryClientConfig schemaRegistryClientConfig = SchemaRegistryUtils + .getSchemaRegistryClientConfig(pravegaConfig); SchemaRegistryClient schemaRegistryClient = SchemaRegistryClientFactory.withNamespace(namespace, schemaRegistryClientConfig); SerializerConfig config = SerializerConfig.builder() @@ -162,6 +183,10 @@ public void open(InitializationContext context) throws Exception { config.isRegisterSchema(), config.isWriteEncodingHeader()); break; + case Protobuf: + pbSchema = ProtobufSchema.of(pbMessage.getClass()); + serializer = SerializerFactory.protobufSerializer(config, pbSchema); + break; default: throw new NotImplementedException("Not supporting deserialization format"); } @@ -176,6 +201,8 @@ public byte[] serialize(RowData row) { return convertToByteArray(serializeToGenericRecord(row)); case Json: return convertToByteArray(serializaToJsonNode(row)); + case Protobuf: + return convertToByteArray(serializeToMessage(row)); default: throw new NotImplementedException("Not supporting deserialization format"); } @@ -185,8 +212,8 @@ public byte[] serialize(RowData row) { } public GenericRecord serializeToGenericRecord(RowData row) { - RowDataToAvroConverters.RowDataToAvroConverter runtimeConverter = - RowDataToAvroConverters.createConverter(rowType); + RowDataToAvroConverters.RowDataToAvroConverter runtimeConverter = RowDataToAvroConverters + .createConverter(rowType); return (GenericRecord) runtimeConverter.convert(avroSchema, row); } @@ -200,6 +227,13 @@ public JsonNode serializaToJsonNode(RowData row) { return runtimeConverter.convert(mapper, node, row); } + public AbstractMessage serializeToMessage(RowData row) throws Exception { + PbFormatConfigBuilder pbConfigBuilder = new PbFormatConfigBuilder() + .messageClassName(pbMessage.getClass().getName()); + RowToMessageConverter runtimeConverter = new RowToMessageConverter(rowType, pbConfigBuilder.build()); + return runtimeConverter.convertRowToProtoMessage(row); + } + @SuppressWarnings("unchecked") public byte[] convertToByteArray(Object message) { return serializer.serialize(message).array(); @@ -208,14 +242,16 @@ public byte[] convertToByteArray(Object message) { @VisibleForTesting protected static class FlinkJsonSerializer extends AbstractSerializer { private final ObjectMapper objectMapper; + public FlinkJsonSerializer(String groupId, SchemaRegistryClient client, JSONSchema schema, - Encoder encoder, boolean registerSchema, boolean encodeHeader) { + Encoder encoder, boolean registerSchema, boolean encodeHeader) { super(groupId, client, schema, encoder, registerSchema, encodeHeader); objectMapper = new ObjectMapper(); } @Override - protected void serialize(JsonNode jsonNode, SchemaInfo schemaInfo, OutputStream outputStream) throws IOException { + protected void serialize(JsonNode jsonNode, SchemaInfo schemaInfo, OutputStream outputStream) + throws IOException { objectMapper.writeValue(outputStream, jsonNode); outputStream.flush(); } diff --git a/src/main/java/io/pravega/connectors/flink/util/MessageToRowConverter.java b/src/main/java/io/pravega/connectors/flink/util/MessageToRowConverter.java new file mode 100644 index 00000000..fac5f8c0 --- /dev/null +++ b/src/main/java/io/pravega/connectors/flink/util/MessageToRowConverter.java @@ -0,0 +1,110 @@ +package io.pravega.connectors.flink.util; + +import org.apache.flink.formats.protobuf.PbCodegenException; +import org.apache.flink.formats.protobuf.PbConstant; +import org.apache.flink.formats.protobuf.PbFormatConfig; +import org.apache.flink.formats.protobuf.PbFormatContext; +import org.apache.flink.formats.protobuf.deserialize.PbCodegenDeserializeFactory; +import org.apache.flink.formats.protobuf.deserialize.PbCodegenDeserializer; +import org.apache.flink.formats.protobuf.util.PbCodegenAppender; +import org.apache.flink.formats.protobuf.util.PbCodegenUtils; +import org.apache.flink.formats.protobuf.util.PbFormatUtils; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.GenericArrayData; +import org.apache.flink.table.data.GenericMapData; +import org.apache.flink.table.data.GenericRowData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.binary.BinaryStringData; +import org.apache.flink.table.types.logical.RowType; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import com.google.protobuf.Descriptors.FileDescriptor.Syntax; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * {@link MessageToRowConverter} can convert protobuf message data to flink row + * data by codegen + * process. + */ +public class MessageToRowConverter { + private static final Logger LOG = LoggerFactory.getLogger(MessageToRowConverter.class); + private final Method parseFromMethod; + private final Method decodeMethod; + + public MessageToRowConverter(RowType rowType, PbFormatConfig formatConfig) + throws PbCodegenException { + try { + Descriptors.Descriptor descriptor = PbFormatUtils.getDescriptor(formatConfig.getMessageClassName()); + Class messageClass = Class.forName( + formatConfig.getMessageClassName(), + true, + Thread.currentThread().getContextClassLoader()); + String fullMessageClassName = PbFormatUtils.getFullJavaName(descriptor, ""); + if (descriptor.getFile().getSyntax() == Syntax.PROTO3) { + // pb3 always read default values + formatConfig = new PbFormatConfig( + formatConfig.getMessageClassName(), + formatConfig.isIgnoreParseErrors(), + true, + formatConfig.getWriteNullStringLiterals()); + } + PbCodegenAppender codegenAppender = new PbCodegenAppender(); + PbFormatContext pbFormatContext = new PbFormatContext("", formatConfig); + String uuid = UUID.randomUUID().toString().replaceAll("\\-", ""); + String generatedClassName = "GeneratedProtoToRow_" + uuid; + String generatedPackageName = MessageToRowConverter.class.getPackage().getName(); + codegenAppender.appendLine("package " + generatedPackageName); + codegenAppender.appendLine("import " + RowData.class.getName()); + codegenAppender.appendLine("import " + ArrayData.class.getName()); + codegenAppender.appendLine("import " + BinaryStringData.class.getName()); + codegenAppender.appendLine("import " + GenericRowData.class.getName()); + codegenAppender.appendLine("import " + GenericMapData.class.getName()); + codegenAppender.appendLine("import " + GenericArrayData.class.getName()); + codegenAppender.appendLine("import " + ArrayList.class.getName()); + codegenAppender.appendLine("import " + List.class.getName()); + codegenAppender.appendLine("import " + Map.class.getName()); + codegenAppender.appendLine("import " + HashMap.class.getName()); + codegenAppender.appendLine("import " + ByteString.class.getName()); + + codegenAppender.appendSegment("public class " + generatedClassName + "{"); + codegenAppender.appendSegment( + "public static RowData " + + PbConstant.GENERATED_DECODE_METHOD + + "(" + + fullMessageClassName + + " message){"); + codegenAppender.appendLine("RowData rowData=null"); + PbCodegenDeserializer codegenDes = PbCodegenDeserializeFactory.getPbCodegenTopRowDes( + descriptor, rowType, pbFormatContext); + String genCode = codegenDes.codegen("rowData", "message", 0); + codegenAppender.appendSegment(genCode); + codegenAppender.appendLine("return rowData"); + codegenAppender.appendSegment("}"); + codegenAppender.appendSegment("}"); + + String printCode = codegenAppender.printWithLineNumber(); + LOG.debug("Protobuf decode codegen: \n" + printCode); + Class generatedClass = PbCodegenUtils.compileClass( + Thread.currentThread().getContextClassLoader(), + generatedPackageName + "." + generatedClassName, + codegenAppender.code()); + decodeMethod = generatedClass.getMethod(PbConstant.GENERATED_DECODE_METHOD, messageClass); + parseFromMethod = messageClass.getMethod(PbConstant.PB_METHOD_PARSE_FROM, byte[].class); + } catch (Exception ex) { + throw new PbCodegenException(ex); + } + } + + public RowData convertMessageToRow(Object messageObj) throws Exception { + return (RowData) decodeMethod.invoke(null, messageObj); + } +} diff --git a/src/main/java/io/pravega/connectors/flink/util/RowToMessageConverter.java b/src/main/java/io/pravega/connectors/flink/util/RowToMessageConverter.java new file mode 100644 index 00000000..d3ce823d --- /dev/null +++ b/src/main/java/io/pravega/connectors/flink/util/RowToMessageConverter.java @@ -0,0 +1,93 @@ +package io.pravega.connectors.flink.util; + +import org.apache.flink.formats.protobuf.PbCodegenException; +import org.apache.flink.formats.protobuf.PbConstant; +import org.apache.flink.formats.protobuf.PbFormatConfig; +import org.apache.flink.formats.protobuf.PbFormatContext; +import org.apache.flink.formats.protobuf.deserialize.ProtoToRowConverter; +import org.apache.flink.formats.protobuf.serialize.PbCodegenSerializeFactory; +import org.apache.flink.formats.protobuf.serialize.PbCodegenSerializer; +import org.apache.flink.formats.protobuf.util.PbCodegenAppender; +import org.apache.flink.formats.protobuf.util.PbCodegenUtils; +import org.apache.flink.formats.protobuf.util.PbFormatUtils; +import org.apache.flink.table.data.ArrayData; +import org.apache.flink.table.data.RowData; +import org.apache.flink.table.data.StringData; +import org.apache.flink.table.types.logical.RowType; + +import com.google.protobuf.AbstractMessage; +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * {@link RowToMessageConverter} can convert flink row data to binary protobuf + * message data by codegen + * process. + */ +public class RowToMessageConverter { + private static final Logger LOG = LoggerFactory.getLogger(ProtoToRowConverter.class); + private final Method encodeMethod; + + public RowToMessageConverter(RowType rowType, PbFormatConfig formatConfig) + throws PbCodegenException { + try { + Descriptors.Descriptor descriptor = PbFormatUtils + .getDescriptor(formatConfig.getMessageClassName()); + PbFormatContext formatContext = new PbFormatContext("", formatConfig); + + PbCodegenAppender codegenAppender = new PbCodegenAppender(0); + String uuid = UUID.randomUUID().toString().replaceAll("\\-", ""); + String generatedClassName = "GeneratedRowToProto_" + uuid; + String generatedPackageName = RowToMessageConverter.class.getPackage().getName(); + codegenAppender.appendLine("package " + generatedPackageName); + codegenAppender.appendLine("import " + AbstractMessage.class.getName()); + codegenAppender.appendLine("import " + Descriptors.class.getName()); + codegenAppender.appendLine("import " + RowData.class.getName()); + codegenAppender.appendLine("import " + ArrayData.class.getName()); + codegenAppender.appendLine("import " + StringData.class.getName()); + codegenAppender.appendLine("import " + ByteString.class.getName()); + codegenAppender.appendLine("import " + List.class.getName()); + codegenAppender.appendLine("import " + ArrayList.class.getName()); + codegenAppender.appendLine("import " + Map.class.getName()); + codegenAppender.appendLine("import " + HashMap.class.getName()); + + codegenAppender.begin("public class " + generatedClassName + "{"); + codegenAppender.begin( + "public static AbstractMessage " + + PbConstant.GENERATED_ENCODE_METHOD + + "(RowData rowData){"); + codegenAppender.appendLine("AbstractMessage message = null"); + PbCodegenSerializer codegenSer = PbCodegenSerializeFactory.getPbCodegenTopRowSer( + descriptor, rowType, formatContext); + String genCode = codegenSer.codegen("message", "rowData", codegenAppender.currentIndent()); + codegenAppender.appendSegment(genCode); + codegenAppender.appendLine("return message"); + codegenAppender.end("}"); + codegenAppender.end("}"); + + String printCode = codegenAppender.printWithLineNumber(); + LOG.debug("Protobuf encode codegen: \n" + printCode); + Class generatedClass = PbCodegenUtils.compileClass( + Thread.currentThread().getContextClassLoader(), + generatedPackageName + "." + generatedClassName, + codegenAppender.code()); + encodeMethod = generatedClass.getMethod(PbConstant.GENERATED_ENCODE_METHOD, RowData.class); + } catch (Exception ex) { + throw new PbCodegenException(ex); + } + } + + public AbstractMessage convertRowToProtoMessage(RowData rowData) throws Exception { + AbstractMessage message = (AbstractMessage) encodeMethod.invoke(null, rowData); + return message; + } +}