fix: handle complex types when generating Avro values (#94)
* fix: handle complex types when generating Avro values

* chore: add TODO comment

* fix: handle null value

* test: add unit tests

* test: add unit tests
yvrng authored and Skraye committed Oct 4, 2024
1 parent 0504335 commit 34b43bf
Showing 4 changed files with 449 additions and 97 deletions.
97 changes: 35 additions & 62 deletions src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java
@@ -1,13 +1,16 @@
package io.kestra.plugin.kafka;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import io.confluent.kafka.serializers.KafkaAvroDeserializer;
import io.confluent.kafka.serializers.KafkaJsonDeserializer;
import io.confluent.kafka.serializers.KafkaJsonSerializer;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.kafka.serdes.GenericRecordToMapDeserializer;
import io.kestra.plugin.kafka.serdes.KafkaAvroSerializer;
import io.kestra.plugin.kafka.serdes.MapToGenericRecordSerializer;
import io.kestra.plugin.kafka.serdes.SerdeType;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;
@@ -50,69 +53,39 @@ protected static Properties createProperties(Map<String, String> mapProperties,
return properties;
}

protected static Serializer<?> getTypedSerializer(SerdeType s) throws Exception {
switch (s) {
case STRING:
return new StringSerializer();
case INTEGER:
return new IntegerSerializer();
case FLOAT:
return new FloatSerializer();
case DOUBLE:
return new DoubleSerializer();
case LONG:
return new LongSerializer();
case SHORT:
return new ShortSerializer();
case BYTE_ARRAY:
return new ByteArraySerializer();
case BYTE_BUFFER:
return new ByteBufferSerializer();
case BYTES:
return new BytesSerializer();
case UUID:
return new UUIDSerializer();
case VOID:
return new VoidSerializer();
case AVRO:
return new KafkaAvroSerializer();
case JSON:
return new KafkaJsonSerializer<>();
default:
throw new Exception();
}
protected static Serializer<?> getTypedSerializer(SerdeType s, @Nullable AvroSchema avroSchema) {
return switch (s) {
case STRING -> new StringSerializer();
case INTEGER -> new IntegerSerializer();
case FLOAT -> new FloatSerializer();
case DOUBLE -> new DoubleSerializer();
case LONG -> new LongSerializer();
case SHORT -> new ShortSerializer();
case BYTE_ARRAY -> new ByteArraySerializer();
case BYTE_BUFFER -> new ByteBufferSerializer();
case BYTES -> new BytesSerializer();
case UUID -> new UUIDSerializer();
case VOID -> new VoidSerializer();
case AVRO -> new MapToGenericRecordSerializer(new KafkaAvroSerializer(), avroSchema);
case JSON -> new KafkaJsonSerializer<>();
};
}

protected static Deserializer<?> getTypedDeserializer(SerdeType s) throws Exception {
switch (s) {
case STRING:
return new StringDeserializer();
case INTEGER:
return new IntegerDeserializer();
case FLOAT:
return new FloatDeserializer();
case DOUBLE:
return new DoubleDeserializer();
case LONG:
return new LongDeserializer();
case SHORT:
return new ShortDeserializer();
case BYTE_ARRAY:
return new ByteArrayDeserializer();
case BYTE_BUFFER:
return new ByteBufferDeserializer();
case BYTES:
return new BytesDeserializer();
case UUID:
return new UUIDDeserializer();
case VOID:
return new VoidDeserializer();
case AVRO:
return new GenericRecordToMapDeserializer(new KafkaAvroDeserializer());
case JSON:
return new KafkaJsonDeserializer<>();
default:
throw new Exception();
}
protected static Deserializer<?> getTypedDeserializer(SerdeType s) {
return switch (s) {
case STRING -> new StringDeserializer();
case INTEGER -> new IntegerDeserializer();
case FLOAT -> new FloatDeserializer();
case DOUBLE -> new DoubleDeserializer();
case LONG -> new LongDeserializer();
case SHORT -> new ShortDeserializer();
case BYTE_ARRAY -> new ByteArrayDeserializer();
case BYTE_BUFFER -> new ByteBufferDeserializer();
case BYTES -> new BytesDeserializer();
case UUID -> new UUIDDeserializer();
case VOID -> new VoidDeserializer();
case AVRO -> new GenericRecordToMapDeserializer(new KafkaAvroDeserializer());
case JSON -> new KafkaJsonDeserializer<>();
};
}
}
46 changes: 11 additions & 35 deletions src/main/java/io/kestra/plugin/kafka/Produce.java
@@ -1,5 +1,7 @@
package io.kestra.plugin.kafka;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
@@ -9,16 +11,15 @@
import io.kestra.core.serializers.FileSerde;
import io.kestra.core.utils.IdUtils;
import io.kestra.plugin.kafka.serdes.SerdeType;
import jakarta.annotation.Nullable;
import java.util.Optional;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
@@ -177,8 +178,8 @@ public Output run(RunContext runContext) throws Exception {

Properties serdesProperties = createProperties(this.serdeProperties, runContext);

Serializer keySerial = getTypedSerializer(this.keySerializer);
Serializer valSerial = getTypedSerializer(this.valueSerializer);
Serializer keySerial = getTypedSerializer(this.keySerializer, parseAvroSchema(runContext, keyAvroSchema));
Serializer valSerial = getTypedSerializer(this.valueSerializer, parseAvroSchema(runContext, valueAvroSchema));

keySerial.configure(serdesProperties, true);
valSerial.configure(serdesProperties, false);
@@ -236,24 +237,9 @@ public Output run(RunContext runContext) throws Exception {
}
}

private GenericRecord buildAvroRecord(RunContext runContext, String dataSchema, Map<String, Object> map) throws Exception {
Schema.Parser parser = new Schema.Parser();
Schema schema = parser.parse(runContext.render(dataSchema));
return buildAvroRecord(schema, map);
}

private GenericRecord buildAvroRecord(Schema schema, Map<String, Object> map) {
GenericRecord avroRecord = new GenericData.Record(schema);
for (String k : map.keySet()) {
Object value = map.get(k);
Schema fieldSchema = schema.getField(k).schema();
if (fieldSchema.getType().equals(Schema.Type.RECORD)) {
avroRecord.put(k, buildAvroRecord(fieldSchema, (Map<String, Object>) value));
} else {
avroRecord.put(k, value);
}
}
return avroRecord;
@Nullable
private static AvroSchema parseAvroSchema(RunContext runContext, @Nullable String avroSchema) throws IllegalVariableEvaluationException {
return Optional.ofNullable(avroSchema).map(throwFunction(runContext::render)).map(AvroSchema::new).orElse(null);
}

@SuppressWarnings("unchecked")
@@ -273,18 +259,8 @@ private ProducerRecord<Object, Object> producerRecord(RunContext runContext, Kaf

map = runContext.render(map);

if (this.keySerializer == SerdeType.AVRO) {
key = buildAvroRecord(runContext, this.keyAvroSchema, (Map<String, Object>) map.get("key"));
} else {
key = map.get("key");
}

if (this.valueSerializer == SerdeType.AVRO) {
value = buildAvroRecord(runContext, this.valueAvroSchema, (Map<String, Object>) map.get("value"));
} else {
value = map.get("value");
}

key = map.get("key");
value = map.get("value");

if (map.containsKey("topic")) {
topic = runContext.render((String) map.get("topic"));
88 changes: 88 additions & 0 deletions src/main/java/io/kestra/plugin/kafka/serdes/MapToGenericRecordSerializer.java
@@ -0,0 +1,88 @@
package io.kestra.plugin.kafka.serdes;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericArray;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.common.serialization.Serializer;

public class MapToGenericRecordSerializer implements Serializer<Object> {

private final KafkaAvroSerializer serializer;
private final AvroSchema schema;

public MapToGenericRecordSerializer(KafkaAvroSerializer serializer, AvroSchema schema) {
this.serializer = serializer;
this.schema = schema;
}

@Override
public void configure(Map<String, ?> configs, boolean isKey) {
this.serializer.configure(configs, isKey);
}

@Override
public byte[] serialize(String topic, Object data) {
return serializer.serialize(topic, buildValue(schema.rawSchema(), data));
}

@Override
public void close() {
this.serializer.close();
}

private static Object buildValue(Schema schema, Object data) {
if (data == null) {
return null;
}
return switch (schema.getType()) {
case UNION -> buildUnionValue(schema, data);
case RECORD -> buildRecordValue(schema, (Map<String, ?>) data);
case MAP -> buildMapValue(schema, (Map<String, ?>) data);
case ARRAY -> buildArrayValue(schema, (Collection<?>) data);
case ENUM -> buildEnumValue(schema, (String) data);
case FIXED -> buildFixedValue(schema, (byte[]) data);
case STRING, BYTES, INT, LONG, FLOAT, DOUBLE, BOOLEAN, NULL -> data;
};
}

private static Object buildUnionValue(Schema schema, Object value) {
// TODO using the first non-null schema allows support for optional values, but not polymorphism
for (Schema s : schema.getTypes()) {
if (!s.getType().equals(Schema.Type.NULL)) {
return buildValue(s, value);
}
}
throw new IllegalArgumentException();
}

private static GenericRecord buildRecordValue(Schema schema, Map<String, ?> data) {
final var record = new org.apache.avro.generic.GenericData.Record(schema);
data.forEach((key, value) -> record.put(key, buildValue(schema.getField(key).schema(), value)));
return record;
}

private static Map<String, ?> buildMapValue(Schema schema, Map<String, ?> data) {
final var record = new LinkedHashMap<String, Object>();
data.forEach((key, value) -> record.put(key, buildValue(schema.getValueType(), value)));
return record;
}

private static GenericArray<?> buildArrayValue(Schema schema, Collection<?> data) {
final var values = data.stream().map(value -> buildValue(schema.getElementType(), value)).toList();
return new org.apache.avro.generic.GenericData.Array<>(schema, values);
}

private static GenericEnumSymbol<?> buildEnumValue(Schema schema, String data) {
return new org.apache.avro.generic.GenericData.EnumSymbol(schema, data);
}

private static GenericFixed buildFixedValue(Schema schema, byte[] data) {
return new org.apache.avro.generic.GenericData.Fixed(schema, data);
}
}
