diff --git a/avro-converter/src/main/java/io/confluent/connect/avro/AvroData.java b/avro-converter/src/main/java/io/confluent/connect/avro/AvroData.java index b12caaf8fb4..3cad0e29e9a 100644 --- a/avro-converter/src/main/java/io/confluent/connect/avro/AvroData.java +++ b/avro-converter/src/main/java/io/confluent/connect/avro/AvroData.java @@ -245,6 +245,7 @@ public Object convert(Schema schema, Object value) { static final String AVRO_LOGICAL_DECIMAL_SCALE_PROP = "scale"; static final String AVRO_LOGICAL_DECIMAL_PRECISION_PROP = "precision"; static final String CONNECT_AVRO_DECIMAL_PRECISION_PROP = "connect.decimal.precision"; + static final String CONNECT_AVRO_FIXED_SIZE = "connect.fixed.size"; static final Integer CONNECT_AVRO_DECIMAL_PRECISION_DEFAULT = 64; private static final HashMap TO_AVRO_LOGICAL_CONVERTERS @@ -451,11 +452,32 @@ private static Object fromConnectData( requireContainer); case BYTES: { - ByteBuffer bytesValue = value instanceof byte[] ? ByteBuffer.wrap((byte[]) value) : - (ByteBuffer) value; + value = value instanceof byte[] ? ByteBuffer.wrap((byte[]) value) : + (ByteBuffer) value; + if (schema != null && schema.parameters() != null + && schema.parameters().containsKey(CONNECT_AVRO_FIXED_SIZE)) { + int size = Integer.parseInt(schema.parameters().get(CONNECT_AVRO_FIXED_SIZE)); + org.apache.avro.Schema fixedSchema = null; + if (avroSchema.getType() == org.apache.avro.Schema.Type.UNION) { + for (org.apache.avro.Schema memberSchema : avroSchema.getTypes()) { + if (memberSchema.getType() == org.apache.avro.Schema.Type.FIXED + && memberSchema.getFixedSize() == size + && unionMemberFieldName(enhancedSchemaSupport, memberSchema) + .equals(schema.name())) { + fixedSchema = memberSchema; + } + } + if (fixedSchema == null) { + throw new DataException("Fixed size " + size + " not in union " + avroSchema); + } + } else { + fixedSchema = avroSchema; + } + value = new GenericData.Fixed(fixedSchema, ((ByteBuffer)value).array()); + } return maybeAddContainer( avroSchema, - maybeWrapSchemaless(schema, bytesValue, ANYTHING_SCHEMA_BYTES_FIELD), + maybeWrapSchemaless(schema, value, ANYTHING_SCHEMA_BYTES_FIELD), requireContainer); } @@ -768,7 +790,18 @@ public org.apache.avro.Schema fromConnectSchema(Schema schema, } break; case BYTES: - baseSchema = org.apache.avro.SchemaBuilder.builder().bytesType(); + if (schema.parameters() != null + && schema.parameters().containsKey(CONNECT_AVRO_FIXED_SIZE)) { + String doc = schema.parameters() != null + ? schema.parameters().get(CONNECT_RECORD_DOC_PROP) + : null; + baseSchema = org.apache.avro.SchemaBuilder.builder().fixed(name) + .namespace(namespace) + .doc(doc) + .size(Integer.parseInt(schema.parameters().get(CONNECT_AVRO_FIXED_SIZE))); + } else { + baseSchema = org.apache.avro.SchemaBuilder.builder().bytesType(); + } if (Decimal.LOGICAL_NAME.equalsIgnoreCase(schema.name())) { int scale = Integer.parseInt(schema.parameters().get(Decimal.SCALE_FIELD)); baseSchema.addProp(AVRO_LOGICAL_DECIMAL_SCALE_PROP, new IntNode(scale)); @@ -1269,9 +1302,10 @@ private Object toConnectData(Schema schema, Object value) { for (Field field : schema.fields()) { Schema fieldSchema = field.schema(); - if (isInstanceOfAvroSchemaTypeForSimpleSchema(fieldSchema, value) - || (valueRecordSchema != null && valueRecordSchema.equals(fieldSchema))) { - converted = new Struct(schema).put(unionMemberFieldName(fieldSchema), + if (isInstanceOfAvroSchemaTypeForSimpleSchema(enhancedSchemaSupport, fieldSchema, + value) || (fieldSchema.equals(valueRecordSchema))) { + converted = new Struct(schema).put(unionMemberFieldName(enhancedSchemaSupport, + fieldSchema), toConnectData(fieldSchema, value)); break; } @@ -1374,6 +1408,9 @@ private Schema toConnectSchema(org.apache.avro.Schema schema, boolean forceOptio } else { builder = SchemaBuilder.bytes(); } + if (schema.getType() == org.apache.avro.Schema.Type.FIXED) { + builder.parameter(CONNECT_AVRO_FIXED_SIZE, String.valueOf(schema.getFixedSize())); + } break; case DOUBLE: builder = SchemaBuilder.float64(); @@ -1475,7 +1512,7 @@ private Schema toConnectSchema(org.apache.avro.Schema schema, boolean forceOptio if (memberSchema.getType() == org.apache.avro.Schema.Type.NULL) { builder.optional(); } else { - String fieldName = unionMemberFieldName(memberSchema); + String fieldName = unionMemberFieldName(enhancedSchemaSupport, memberSchema); if (fieldNames.contains(fieldName)) { throw new DataException("Multiple union schemas map to the Connect union field name"); } @@ -1559,7 +1596,8 @@ private Schema toConnectSchema(org.apache.avro.Schema schema, boolean forceOptio fieldDefaultVal = schema.getJsonProp(CONNECT_DEFAULT_VALUE_PROP); } if (fieldDefaultVal != null) { - builder.defaultValue(defaultValueFromAvro(builder, schema, fieldDefaultVal)); + Object value = defaultValueFromAvro(builder, schema, fieldDefaultVal); + builder.defaultValue(value); } JsonNode connectNameJson = schema.getJsonProp(CONNECT_NAME_PROP); @@ -1571,6 +1609,7 @@ private Schema toConnectSchema(org.apache.avro.Schema schema, boolean forceOptio name = connectNameJson.asText(); } else if (schema.getType() == org.apache.avro.Schema.Type.RECORD + || schema.getType() == org.apache.avro.Schema.Type.FIXED || schema.getType() == org.apache.avro.Schema.Type.ENUM) { name = schema.getFullName(); } @@ -1691,8 +1730,8 @@ private Object defaultValueFromAvro(Schema schema, if (memberAvroSchema.getType() == org.apache.avro.Schema.Type.NULL) { return null; } else { - return defaultValueFromAvro(schema.field(unionMemberFieldName(memberAvroSchema)).schema(), - memberAvroSchema, value); + return defaultValueFromAvro(schema.field(unionMemberFieldName(enhancedSchemaSupport, + memberAvroSchema)).schema(), memberAvroSchema, value); } } default: { @@ -1703,8 +1742,10 @@ private Object defaultValueFromAvro(Schema schema, } - private String unionMemberFieldName(org.apache.avro.Schema schema) { + private static String unionMemberFieldName(boolean enhancedSchemaSupport, + org.apache.avro.Schema schema) { if (schema.getType() == org.apache.avro.Schema.Type.RECORD + || schema.getType() == org.apache.avro.Schema.Type.FIXED || schema.getType() == org.apache.avro.Schema.Type.ENUM) { if (enhancedSchemaSupport) { return schema.getFullName(); @@ -1715,8 +1756,9 @@ private String unionMemberFieldName(org.apache.avro.Schema schema) { return schema.getType().getName(); } - private String unionMemberFieldName(Schema schema) { - if (schema.type() == Schema.Type.STRUCT || isEnumSchema(schema)) { + private static String unionMemberFieldName(boolean enhancedSchemaSupport, Schema schema) { + if (schema.type() == Schema.Type.STRUCT || isEnumSchema(schema) + || (schema.type() == Schema.Type.BYTES && schema.name() != null)) { if (enhancedSchemaSupport) { return schema.name(); } else { @@ -1732,7 +1774,8 @@ private static boolean isEnumSchema(Schema schema) { && schema.name().equals(AVRO_TYPE_ENUM); } - private static boolean isInstanceOfAvroSchemaTypeForSimpleSchema(Schema fieldSchema, + private static boolean isInstanceOfAvroSchemaTypeForSimpleSchema(boolean enhancedSchemaSupport, + Schema fieldSchema, Object value) { List classes = SIMPLE_AVRO_SCHEMA_TYPES.get(fieldSchema.type()); if (classes == null) { @@ -1740,12 +1783,39 @@ private static boolean isInstanceOfAvroSchemaTypeForSimpleSchema(Schema fieldSch } for (Class type : classes) { if (type.isInstance(value)) { - return true; + if (fieldSchema.type() == Schema.Type.BYTES + && fieldSchema.parameters() != null + && fieldSchema.parameters().containsKey(CONNECT_AVRO_FIXED_SIZE)) { + if (fixedValueSizeOf(enhancedSchemaSupport, fieldSchema, value, + Integer.parseInt(fieldSchema.parameters().get(CONNECT_AVRO_FIXED_SIZE)))) { + return true; + } + } else { + return true; + } } } return false; } + /** + * Get size of bytes value tagged as fixed + */ + private static boolean fixedValueSizeOf(boolean enhancedSchemaSupport, Schema fieldSchema, + Object value, int size) { + if (value instanceof byte[]) { + return ((byte[]) value).length == size; + } else if (value instanceof ByteBuffer) { + return ((ByteBuffer)value).remaining() == size; + } else if (value instanceof GenericFixed) { + return unionMemberFieldName(enhancedSchemaSupport, ((GenericFixed) value).getSchema()) + .equals(fieldSchema.name()); + } else { + throw new DataException("Invalid class for fixed, expecting GenericFixed, byte[]" + + " or ByteBuffer but found " + value.getClass()); + } + } + /** * Split a full dotted-syntax name into a namespace and a single-component name. */ diff --git a/avro-converter/src/test/java/io/confluent/connect/avro/AvroDataTest.java b/avro-converter/src/test/java/io/confluent/connect/avro/AvroDataTest.java index 91bf176ef4a..63d523c4f2b 100644 --- a/avro-converter/src/test/java/io/confluent/connect/avro/AvroDataTest.java +++ b/avro-converter/src/test/java/io/confluent/connect/avro/AvroDataTest.java @@ -61,6 +61,7 @@ import io.confluent.kafka.serializers.NonRecordContainer; import static io.confluent.connect.avro.AvroData.AVRO_TYPE_ENUM; +import static io.confluent.connect.avro.AvroData.CONNECT_AVRO_FIXED_SIZE; import static io.confluent.connect.avro.AvroData.CONNECT_ENUM_DOC_PROP; import static io.confluent.connect.avro.AvroData.CONNECT_RECORD_DOC_PROP; import static org.junit.Assert.*; @@ -166,6 +167,42 @@ public void testFromConnectString() { checkNonRecordConversionNull(Schema.OPTIONAL_STRING_SCHEMA); } + @Test + public void testFromConnectBytesFixed() { + org.apache.avro.Schema avroSchema = org.apache.avro.SchemaBuilder.builder().fixed("sample").size(4); + GenericData.Fixed avroObj = new GenericData.Fixed(avroSchema, "foob".getBytes()); + avroSchema.addProp("connect.parameters", ImmutableMap.of("connect.fixed.size", "4")); + avroSchema.addProp("connect.name", "sample"); + SchemaAndValue schemaAndValue = avroData.toConnectData(avroSchema, avroObj); + checkNonRecordConversion(avroSchema, avroObj, schemaAndValue.schema(), schemaAndValue.value(), + avroData); + } + + @Test + public void testFromConnectFixedUnion() { + org.apache.avro.Schema unionSchema = org.apache.avro.SchemaBuilder.builder().unionOf() + .type(avroFixed("sample", 4)).and() + .type(avroFixed("other", 6)).and() + .type(avroFixed("sameOther", 6)).endUnion(); + Schema union = SchemaBuilder.struct() + .name("io.confluent.connect.avro.Union") + .field("sample", connectFixedOptional("sample", 4)) + .field("other", connectFixedOptional("other", 6)) + .field("sameOther", connectFixedOptional("sameOther", 6)) + .build(); + Struct unionSample = new Struct(union).put("sample", ByteBuffer.wrap("foob".getBytes())); + Struct unionOther = new Struct(union).put("other", ByteBuffer.wrap("foobar".getBytes())); + Struct unionSameOther = new Struct(union).put("sameOther", ByteBuffer.wrap("foobar".getBytes())); + + GenericData genericData = GenericData.get(); + assertEquals(0, + genericData.resolveUnion(unionSchema, avroData.fromConnectData(union, unionSample))); + assertEquals(1, + genericData.resolveUnion(unionSchema, avroData.fromConnectData(union, unionOther))); + assertEquals(2, + genericData.resolveUnion(unionSchema, avroData.fromConnectData(union, unionSameOther))); + } + @Test public void testFromConnectEnum() { AvroDataConfig avroDataConfig = new AvroDataConfig.Builder() @@ -1134,18 +1171,42 @@ public void testToConnectNull() { @Test public void testToConnectFixed() { - // Our conversion simply loses the fixed size information. - org.apache.avro.Schema avroSchema = org.apache.avro.SchemaBuilder.builder() - .fixed("sample").size(4); - assertEquals(new SchemaAndValue(Schema.BYTES_SCHEMA, ByteBuffer.wrap("foob".getBytes())), - avroData.toConnectData(avroSchema, "foob".getBytes())); + org.apache.avro.Schema sampleSchema = avroFixed("sample", 4); + Schema sample = connectFixed("sample", 4); + + // Fixed size is preserved + assertEquals(new SchemaAndValue(sample, ByteBuffer.wrap("foob".getBytes())).schema(), + avroData.toConnectData(sampleSchema, "foob".getBytes()).schema()); - assertEquals(new SchemaAndValue(Schema.BYTES_SCHEMA, ByteBuffer.wrap("foob".getBytes())), - avroData.toConnectData(avroSchema, ByteBuffer.wrap("foob".getBytes()))); + // byte[], ByteBuffer and avro Generic Fixed types are valid inputs + assertEquals(new SchemaAndValue(sample, ByteBuffer.wrap("foob".getBytes())), + avroData.toConnectData(sampleSchema, "foob".getBytes())); + assertEquals(new SchemaAndValue(sample, ByteBuffer.wrap("foob".getBytes())), + avroData.toConnectData(sampleSchema, ByteBuffer.wrap("foob".getBytes()))); + GenericData.Fixed valueFixed4 = new GenericData.Fixed(sampleSchema, "foob".getBytes()); + assertEquals(new SchemaAndValue(sample, ByteBuffer.wrap("foob".getBytes())), + avroData.toConnectData(sampleSchema, valueFixed4)); + } - // test with actual fixed type - assertEquals(new SchemaAndValue(Schema.BYTES_SCHEMA, ByteBuffer.wrap("foob".getBytes())), - avroData.toConnectData(avroSchema, new GenericData.Fixed(avroSchema, "foob".getBytes()))); + private static Schema connectFixed(String name, int size) { + return connectFixed(name, size, false); + } + + private static Schema connectFixedOptional(String name, int size) { + return connectFixed(name, size, true); + } + + private static Schema connectFixed(String name, int size, boolean optional) { + SchemaBuilder builder = SchemaBuilder.bytes().name(name) + .parameter(CONNECT_AVRO_FIXED_SIZE, String.valueOf(size)); + if (optional) { + builder = builder.optional(); + } + return builder.build(); + } + + private static org.apache.avro.Schema avroFixed(String name, int size) { + return org.apache.avro.SchemaBuilder.builder().fixed(name).size(size); } @Test @@ -1250,7 +1311,38 @@ public void testToConnectUnionRecordConflict() { GenericRecord recordTest = new GenericRecordBuilder(avroRecordSchema1).set("test", 12).build(); avroData.toConnectData(avroSchema, recordTest); } - + + @Test + public void testToConnectFixedUnion() { + org.apache.avro.Schema sampleSchema = avroFixed("sample", 4); + org.apache.avro.Schema otherSchema = avroFixed("other", 6); + org.apache.avro.Schema sameOtherSchema = avroFixed("sameOther", 6); + org.apache.avro.Schema unionSchema = org.apache.avro.SchemaBuilder.builder() + .unionOf().type(sampleSchema).and().type(otherSchema).and().type(sameOtherSchema) + .endUnion(); + Schema union = SchemaBuilder.struct() + .name("io.confluent.connect.avro.Union") + .field("sample", connectFixedOptional("sample", 4)) + .field("other", connectFixedOptional("other", 6)) + .field("sameOther", connectFixedOptional("sameOther", 6)) + .build(); + GenericData.Fixed valueSample = new GenericData.Fixed(sampleSchema, "foob".getBytes()); + GenericData.Fixed valueOther = new GenericData.Fixed(otherSchema, "foobar".getBytes()); + GenericData.Fixed valueSameOther = new GenericData.Fixed(sameOtherSchema, "foobar".getBytes()); + Schema connectSchema = avroData.toConnectSchema(unionSchema); + assertEquals(union, connectSchema); + Struct unionSame = new Struct(union).put("sample", ByteBuffer.wrap("foob".getBytes())); + assertEquals(new SchemaAndValue(union, unionSame), + avroData.toConnectData(unionSchema, valueSample)); + Struct unionOther = new Struct(union).put("other", ByteBuffer.wrap("foobar".getBytes())); + assertEquals(new SchemaAndValue(union, unionOther), + avroData.toConnectData(unionSchema, valueOther)); + Struct unionSameOther = new Struct(union).put("sameOther", + ByteBuffer.wrap("foobar".getBytes())); + assertEquals(new SchemaAndValue(union, unionSameOther), + avroData.toConnectData(unionSchema, valueSameOther)); + } + @Test public void testToConnectUnionRecordConflictWithEnhanced() { // If the records have the same name but are in different namespaces,