From 34b43bfaec59b87e071e9fcf98085d7f8519a4b7 Mon Sep 17 00:00:00 2001 From: Yoann Vernageau <6807151+yvrng@users.noreply.github.com> Date: Mon, 23 Sep 2024 19:59:37 +0200 Subject: [PATCH] fix: handle complex types when generating Avro values (#94) * fix: handle complex types when generating Avro values * chore: add TODO comment * fix: handle null value * test: add unit tests * test: add unit tests --- .../plugin/kafka/AbstractKafkaConnection.java | 97 ++---- .../java/io/kestra/plugin/kafka/Produce.java | 46 +-- .../serdes/MapToGenericRecordSerializer.java | 88 +++++ .../io/kestra/plugin/kafka/KafkaTest.java | 315 ++++++++++++++++++ 4 files changed, 449 insertions(+), 97 deletions(-) create mode 100644 src/main/java/io/kestra/plugin/kafka/serdes/MapToGenericRecordSerializer.java diff --git a/src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java b/src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java index f6dff3c..536f9bc 100644 --- a/src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java +++ b/src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java @@ -1,5 +1,6 @@ package io.kestra.plugin.kafka; +import io.confluent.kafka.schemaregistry.avro.AvroSchema; import io.confluent.kafka.serializers.KafkaAvroDeserializer; import io.confluent.kafka.serializers.KafkaJsonDeserializer; import io.confluent.kafka.serializers.KafkaJsonSerializer; @@ -7,7 +8,9 @@ import io.kestra.core.runners.RunContext; import io.kestra.plugin.kafka.serdes.GenericRecordToMapDeserializer; import io.kestra.plugin.kafka.serdes.KafkaAvroSerializer; +import io.kestra.plugin.kafka.serdes.MapToGenericRecordSerializer; import io.kestra.plugin.kafka.serdes.SerdeType; +import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotNull; import lombok.*; import lombok.experimental.SuperBuilder; @@ -50,69 +53,39 @@ protected static Properties createProperties(Map mapProperties, return properties; } - protected static Serializer getTypedSerializer(SerdeType s) throws Exception { - switch (s) { - case STRING: - return new StringSerializer(); - case INTEGER: - return new IntegerSerializer(); - case FLOAT: - return new FloatSerializer(); - case DOUBLE: - return new DoubleSerializer(); - case LONG: - return new LongSerializer(); - case SHORT: - return new ShortSerializer(); - case BYTE_ARRAY: - return new ByteArraySerializer(); - case BYTE_BUFFER: - return new ByteBufferSerializer(); - case BYTES: - return new BytesSerializer(); - case UUID: - return new UUIDSerializer(); - case VOID: - return new VoidSerializer(); - case AVRO: - return new KafkaAvroSerializer(); - case JSON: - return new KafkaJsonSerializer<>(); - default: - throw new Exception(); - } + protected static Serializer getTypedSerializer(SerdeType s, @Nullable AvroSchema avroSchema) { + return switch (s) { + case STRING -> new StringSerializer(); + case INTEGER -> new IntegerSerializer(); + case FLOAT -> new FloatSerializer(); + case DOUBLE -> new DoubleSerializer(); + case LONG -> new LongSerializer(); + case SHORT -> new ShortSerializer(); + case BYTE_ARRAY -> new ByteArraySerializer(); + case BYTE_BUFFER -> new ByteBufferSerializer(); + case BYTES -> new BytesSerializer(); + case UUID -> new UUIDSerializer(); + case VOID -> new VoidSerializer(); + case AVRO -> new MapToGenericRecordSerializer(new KafkaAvroSerializer(), avroSchema); + case JSON -> new KafkaJsonSerializer<>(); + }; } - protected static Deserializer getTypedDeserializer(SerdeType s) throws Exception { - switch (s) { - case STRING: - return new StringDeserializer(); - case INTEGER: - return new IntegerDeserializer(); - case FLOAT: - return new FloatDeserializer(); - case DOUBLE: - return new DoubleDeserializer(); - case LONG: - return new LongDeserializer(); - case SHORT: - return new ShortDeserializer(); - case BYTE_ARRAY: - return new ByteArrayDeserializer(); - case BYTE_BUFFER: - return new ByteBufferDeserializer(); - case BYTES: - return new BytesDeserializer(); - case UUID: - return new UUIDDeserializer(); - case VOID: - return new VoidDeserializer(); - case AVRO: - return new GenericRecordToMapDeserializer(new KafkaAvroDeserializer()); - case JSON: - return new KafkaJsonDeserializer<>(); - default: - throw new Exception(); - } + protected static Deserializer getTypedDeserializer(SerdeType s) { + return switch (s) { + case STRING -> new StringDeserializer(); + case INTEGER -> new IntegerDeserializer(); + case FLOAT -> new FloatDeserializer(); + case DOUBLE -> new DoubleDeserializer(); + case LONG -> new LongDeserializer(); + case SHORT -> new ShortDeserializer(); + case BYTE_ARRAY -> new ByteArrayDeserializer(); + case BYTE_BUFFER -> new ByteBufferDeserializer(); + case BYTES -> new BytesDeserializer(); + case UUID -> new UUIDDeserializer(); + case VOID -> new VoidDeserializer(); + case AVRO -> new GenericRecordToMapDeserializer(new KafkaAvroDeserializer()); + case JSON -> new KafkaJsonDeserializer<>(); + }; } } diff --git a/src/main/java/io/kestra/plugin/kafka/Produce.java b/src/main/java/io/kestra/plugin/kafka/Produce.java index 2d2c6a5..a2459f8 100644 --- a/src/main/java/io/kestra/plugin/kafka/Produce.java +++ b/src/main/java/io/kestra/plugin/kafka/Produce.java @@ -1,5 +1,7 @@ package io.kestra.plugin.kafka; +import io.confluent.kafka.schemaregistry.avro.AvroSchema; +import io.kestra.core.exceptions.IllegalVariableEvaluationException; import io.kestra.core.models.annotations.Example; import io.kestra.core.models.annotations.Plugin; import io.kestra.core.models.annotations.PluginProperty; @@ -9,6 +11,8 @@ import io.kestra.core.serializers.FileSerde; import io.kestra.core.utils.IdUtils; import io.kestra.plugin.kafka.serdes.SerdeType; +import jakarta.annotation.Nullable; +import java.util.Optional; import lombok.AccessLevel; import lombok.Builder; import lombok.EqualsAndHashCode; @@ -16,9 +20,6 @@ import lombok.NoArgsConstructor; import lombok.ToString; import lombok.experimental.SuperBuilder; -import org.apache.avro.Schema; -import org.apache.avro.generic.GenericData; -import org.apache.avro.generic.GenericRecord; import org.apache.kafka.clients.producer.KafkaProducer; import org.apache.kafka.clients.producer.ProducerConfig; import org.apache.kafka.clients.producer.ProducerRecord; @@ -177,8 +178,8 @@ public Output run(RunContext runContext) throws Exception { Properties serdesProperties = createProperties(this.serdeProperties, runContext); - Serializer keySerial = getTypedSerializer(this.keySerializer); - Serializer valSerial = getTypedSerializer(this.valueSerializer); + Serializer keySerial = getTypedSerializer(this.keySerializer, parseAvroSchema(runContext, keyAvroSchema)); + Serializer valSerial = getTypedSerializer(this.valueSerializer, parseAvroSchema(runContext, valueAvroSchema)); keySerial.configure(serdesProperties, true); valSerial.configure(serdesProperties, false); @@ -236,24 +237,9 @@ public Output run(RunContext runContext) throws Exception { } } - private GenericRecord buildAvroRecord(RunContext runContext, String dataSchema, Map map) throws Exception { - Schema.Parser parser = new Schema.Parser(); - Schema schema = parser.parse(runContext.render(dataSchema)); - return buildAvroRecord(schema, map); - } - - private GenericRecord buildAvroRecord(Schema schema, Map map) { - GenericRecord avroRecord = new GenericData.Record(schema); - for (String k : map.keySet()) { - Object value = map.get(k); - Schema fieldSchema = schema.getField(k).schema(); - if (fieldSchema.getType().equals(Schema.Type.RECORD)) { - avroRecord.put(k, buildAvroRecord(fieldSchema, (Map) value)); - } else { - avroRecord.put(k, value); - } - } - return avroRecord; + @Nullable + private static AvroSchema parseAvroSchema(RunContext runContext, @Nullable String avroSchema) throws IllegalVariableEvaluationException { + return Optional.ofNullable(avroSchema).map(throwFunction(runContext::render)).map(AvroSchema::new).orElse(null); } @SuppressWarnings("unchecked") @@ -273,18 +259,8 @@ private ProducerRecord producerRecord(RunContext runContext, Kaf map = runContext.render(map); - if (this.keySerializer == SerdeType.AVRO) { - key = buildAvroRecord(runContext, this.keyAvroSchema, (Map) map.get("key")); - } else { - key = map.get("key"); - } - - if (this.valueSerializer == SerdeType.AVRO) { - value = buildAvroRecord(runContext, this.valueAvroSchema, (Map) map.get("value")); - } else { - value = map.get("value"); - } - + key = map.get("key"); + value = map.get("value"); if (map.containsKey("topic")) { topic = runContext.render((String) map.get("topic")); diff --git a/src/main/java/io/kestra/plugin/kafka/serdes/MapToGenericRecordSerializer.java b/src/main/java/io/kestra/plugin/kafka/serdes/MapToGenericRecordSerializer.java new file mode 100644 index 0000000..ab13a4e --- /dev/null +++ b/src/main/java/io/kestra/plugin/kafka/serdes/MapToGenericRecordSerializer.java @@ -0,0 +1,88 @@ +package io.kestra.plugin.kafka.serdes; + +import io.confluent.kafka.schemaregistry.avro.AvroSchema; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericArray; +import org.apache.avro.generic.GenericEnumSymbol; +import org.apache.avro.generic.GenericFixed; +import org.apache.avro.generic.GenericRecord; +import org.apache.kafka.common.serialization.Serializer; + +public class MapToGenericRecordSerializer implements Serializer { + + private final KafkaAvroSerializer serializer; + private final AvroSchema schema; + + public MapToGenericRecordSerializer(KafkaAvroSerializer serializer, AvroSchema schema) { + this.serializer = serializer; + this.schema = schema; + } + + @Override + public void configure(Map configs, boolean isKey) { + this.serializer.configure(configs, isKey); + } + + @Override + public byte[] serialize(String topic, Object data) { + return serializer.serialize(topic, buildValue(schema.rawSchema(), data)); + } + + @Override + public void close() { + this.serializer.close(); + } + + private static Object buildValue(Schema schema, Object data) { + if (data == null) { + return null; + } + return switch (schema.getType()) { + case UNION -> buildUnionValue(schema, data); + case RECORD -> buildRecordValue(schema, (Map) data); + case MAP -> buildMapValue(schema, (Map) data); + case ARRAY -> buildArrayValue(schema, (Collection) data); + case ENUM -> buildEnumValue(schema, (String) data); + case FIXED -> buildFixedValue(schema, (byte[]) data); + case STRING, BYTES, INT, LONG, FLOAT, DOUBLE, BOOLEAN, NULL -> data; + }; + } + + private static Object buildUnionValue(Schema schema, Object value) { + // TODO using the first non-null schema allows support for optional values, but not polymorphism + for (Schema s : schema.getTypes()) { + if (!s.getType().equals(Schema.Type.NULL)) { + return buildValue(s, value); + } + } + throw new IllegalArgumentException(); + } + + private static GenericRecord buildRecordValue(Schema schema, Map data) { + final var record = new org.apache.avro.generic.GenericData.Record(schema); + data.forEach((key, value) -> record.put(key, buildValue(schema.getField(key).schema(), value))); + return record; + } + + private static Map buildMapValue(Schema schema, Map data) { + final var record = new LinkedHashMap(); + data.forEach((key, value) -> record.put(key, buildValue(schema.getValueType(), value))); + return record; + } + + private static GenericArray buildArrayValue(Schema schema, Collection data) { + final var values = data.stream().map(value -> buildValue(schema.getElementType(), value)).toList(); + return new org.apache.avro.generic.GenericData.Array<>(schema, values); + } + + private static GenericEnumSymbol buildEnumValue(Schema schema, String data) { + return new org.apache.avro.generic.GenericData.EnumSymbol(schema, data); + } + + private static GenericFixed buildFixedValue(Schema schema, byte[] data) { + return new org.apache.avro.generic.GenericData.Fixed(schema, data); + } +} diff --git a/src/test/java/io/kestra/plugin/kafka/KafkaTest.java b/src/test/java/io/kestra/plugin/kafka/KafkaTest.java index b4fa0f3..1f59b82 100644 --- a/src/test/java/io/kestra/plugin/kafka/KafkaTest.java +++ b/src/test/java/io/kestra/plugin/kafka/KafkaTest.java @@ -32,6 +32,7 @@ import java.util.function.Consumer; import java.util.stream.Stream; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -415,6 +416,320 @@ void produceComplexAvro() throws Exception { assertThat(reproduceRunOutput.getMessagesCount(), is(1)); } + @Test + void produceAvro_withUnion_andRecord() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("product", Map.of("id", "v1")); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "product", + "type": [ + "null", + {"type": "record", "name": "Version", "fields": [{"name": "id", "type": "string"}]} + ] + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + + @Test + void produceAvro_withUnion_andRecord_null() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = new LinkedHashMap<>(); + value.put("product", null); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "product", + "type": [ + "null", + {"type": "record", "name": "Version", "fields": [{"name": "id", "type": "string"}]} + ] + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withRecord() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("address", Map.of("city", "Paris", "country", "FR", "longitude", 2.3522, "latitude", 48.8566)); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "address", + "type": { + "type": "record", + "name": "Address", + "fields": [ + {"name": "city", "type": "string"}, + {"name": "country", "type": "string"}, + {"name": "longitude", "type": "float"}, + {"name": "latitude", "type": "float"} + ] + } + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withMap() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("map", Map.of("foo", 42, "bar", 17)); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "map", + "type": {"type": "map", "values": "int"} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withMap_andRecord() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("map", Map.of("foo", Map.of("id", "v1"), "bar", Map.of("id", "v2"))); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "map", + "type": {"type": "map", "values": {"type": "record", "name": "Version", "fields": [{"name": "id", "type": "string"}]}} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withArray() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("array", List.of("foo", "bar")); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "array", + "type": {"type": "array", "items": "string"} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withArray_andRecord() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("array", List.of(Map.of("id", "v1"), Map.of("id", "v2"))); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "array", + "type": {"type": "array", "items": {"type": "record", "name": "Version", "fields": [{"name": "id", "type": "string"}]}} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withEnum() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("state", "SUCCESS"); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "state", + "type": {"name": "StateEnum", "type": "enum", "symbols": ["SUCCESS", "FAILED"]} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + + @Test + void produceAvro_withFixed() throws Exception { + RunContext runContext = runContextFactory.of(ImmutableMap.of()); + String topic = "tu_" + IdUtils.create(); + + Map value = Map.of("base64", Base64.getEncoder().encode("Hello, World!".getBytes(UTF_8))); + + Produce task = Produce.builder() + .properties(Map.of("bootstrap.servers", this.bootstrap)) + .serdeProperties(Map.of("schema.registry.url", this.registry)) + .keySerializer(SerdeType.STRING) + .valueSerializer(SerdeType.AVRO) + .topic(topic) + .valueAvroSchema(""" + { + "type": "record", + "name": "Sample", + "namespace": "io.kestra.examples", + "fields": [ + { + "name": "base64", + "type": {"name": "Base64", "type": "fixed", "size": 16} + } + ] + } + """) + .from(Map.of("value", value)) + .build(); + + Produce.Output output = task.run(runContext); + assertThat(output.getMessagesCount(), is(1)); + } + @Test void shouldConsumeGivenTopicPattern() throws Exception { RunContext runContext = runContextFactory.of(ImmutableMap.of());