Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle complex types when generating Avro values #94

Merged
merged 5 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 35 additions & 62 deletions src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package io.kestra.plugin.kafka;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import io.confluent.kafka.serializers.KafkaAvroDeserializer;
import io.confluent.kafka.serializers.KafkaJsonDeserializer;
import io.confluent.kafka.serializers.KafkaJsonSerializer;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.kafka.serdes.GenericRecordToMapDeserializer;
import io.kestra.plugin.kafka.serdes.KafkaAvroSerializer;
import io.kestra.plugin.kafka.serdes.MapToGenericRecordSerializer;
import io.kestra.plugin.kafka.serdes.SerdeType;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import lombok.*;
import lombok.experimental.SuperBuilder;
Expand Down Expand Up @@ -50,69 +53,39 @@ protected static Properties createProperties(Map<String, String> mapProperties,
return properties;
}

protected static Serializer<?> getTypedSerializer(SerdeType s) throws Exception {
switch (s) {
case STRING:
return new StringSerializer();
case INTEGER:
return new IntegerSerializer();
case FLOAT:
return new FloatSerializer();
case DOUBLE:
return new DoubleSerializer();
case LONG:
return new LongSerializer();
case SHORT:
return new ShortSerializer();
case BYTE_ARRAY:
return new ByteArraySerializer();
case BYTE_BUFFER:
return new ByteBufferSerializer();
case BYTES:
return new BytesSerializer();
case UUID:
return new UUIDSerializer();
case VOID:
return new VoidSerializer();
case AVRO:
return new KafkaAvroSerializer();
case JSON:
return new KafkaJsonSerializer<>();
default:
throw new Exception();
}
/**
 * Resolves the Kafka {@link Serializer} implementation for the given serde type.
 *
 * @param s          the serde type selected by the task configuration
 * @param avroSchema the target Avro schema; only used for {@code AVRO}, where the
 *                   raw map is first converted to Avro generic data before being
 *                   handed to the Confluent serializer. NOTE(review): may be null
 *                   for non-AVRO types; if {@code AVRO} is selected with a null
 *                   schema, serialization presumably fails downstream — confirm
 *                   callers always provide a schema for AVRO.
 * @return a serializer instance for the requested type (exhaustive over {@link SerdeType})
 */
protected static Serializer<?> getTypedSerializer(SerdeType s, @Nullable AvroSchema avroSchema) {
    return switch (s) {
        case STRING -> new StringSerializer();
        case INTEGER -> new IntegerSerializer();
        case FLOAT -> new FloatSerializer();
        case DOUBLE -> new DoubleSerializer();
        case LONG -> new LongSerializer();
        case SHORT -> new ShortSerializer();
        case BYTE_ARRAY -> new ByteArraySerializer();
        case BYTE_BUFFER -> new ByteBufferSerializer();
        case BYTES -> new BytesSerializer();
        case UUID -> new UUIDSerializer();
        case VOID -> new VoidSerializer();
        // AVRO wraps the Confluent serializer so plain maps are converted to GenericRecord first
        case AVRO -> new MapToGenericRecordSerializer(new KafkaAvroSerializer(), avroSchema);
        case JSON -> new KafkaJsonSerializer<>();
    };
}

protected static Deserializer<?> getTypedDeserializer(SerdeType s) throws Exception {
switch (s) {
case STRING:
return new StringDeserializer();
case INTEGER:
return new IntegerDeserializer();
case FLOAT:
return new FloatDeserializer();
case DOUBLE:
return new DoubleDeserializer();
case LONG:
return new LongDeserializer();
case SHORT:
return new ShortDeserializer();
case BYTE_ARRAY:
return new ByteArrayDeserializer();
case BYTE_BUFFER:
return new ByteBufferDeserializer();
case BYTES:
return new BytesDeserializer();
case UUID:
return new UUIDDeserializer();
case VOID:
return new VoidDeserializer();
case AVRO:
return new GenericRecordToMapDeserializer(new KafkaAvroDeserializer());
case JSON:
return new KafkaJsonDeserializer<>();
default:
throw new Exception();
}
/**
 * Resolves the Kafka {@link Deserializer} implementation for the given serde type.
 *
 * <p>No schema parameter is needed here: for {@code AVRO} the schema travels with
 * the payload / schema registry, and the resulting {@code GenericRecord} is
 * converted back to a plain map by {@link GenericRecordToMapDeserializer}.
 *
 * @param s the serde type selected by the task configuration
 * @return a deserializer instance for the requested type (exhaustive over {@link SerdeType})
 */
protected static Deserializer<?> getTypedDeserializer(SerdeType s) {
    return switch (s) {
        case STRING -> new StringDeserializer();
        case INTEGER -> new IntegerDeserializer();
        case FLOAT -> new FloatDeserializer();
        case DOUBLE -> new DoubleDeserializer();
        case LONG -> new LongDeserializer();
        case SHORT -> new ShortDeserializer();
        case BYTE_ARRAY -> new ByteArrayDeserializer();
        case BYTE_BUFFER -> new ByteBufferDeserializer();
        case BYTES -> new BytesDeserializer();
        case UUID -> new UUIDDeserializer();
        case VOID -> new VoidDeserializer();
        // AVRO unwraps GenericRecord payloads into plain maps for downstream tasks
        case AVRO -> new GenericRecordToMapDeserializer(new KafkaAvroDeserializer());
        case JSON -> new KafkaJsonDeserializer<>();
    };
}
}
46 changes: 11 additions & 35 deletions src/main/java/io/kestra/plugin/kafka/Produce.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.kestra.plugin.kafka;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.annotations.PluginProperty;
Expand All @@ -9,16 +11,15 @@
import io.kestra.core.serializers.FileSerde;
import io.kestra.core.utils.IdUtils;
import io.kestra.plugin.kafka.serdes.SerdeType;
import jakarta.annotation.Nullable;
import java.util.Optional;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
Expand Down Expand Up @@ -177,8 +178,8 @@ public Output run(RunContext runContext) throws Exception {

Properties serdesProperties = createProperties(this.serdeProperties, runContext);

Serializer keySerial = getTypedSerializer(this.keySerializer);
Serializer valSerial = getTypedSerializer(this.valueSerializer);
Serializer keySerial = getTypedSerializer(this.keySerializer, parseAvroSchema(runContext, keyAvroSchema));
Serializer valSerial = getTypedSerializer(this.valueSerializer, parseAvroSchema(runContext, valueAvroSchema));

keySerial.configure(serdesProperties, true);
valSerial.configure(serdesProperties, false);
Expand Down Expand Up @@ -234,24 +235,9 @@ public Output run(RunContext runContext) throws Exception {
}
}

private GenericRecord buildAvroRecord(RunContext runContext, String dataSchema, Map<String, Object> map) throws Exception {
Schema.Parser parser = new Schema.Parser();
Schema schema = parser.parse(runContext.render(dataSchema));
return buildAvroRecord(schema, map);
}

private GenericRecord buildAvroRecord(Schema schema, Map<String, Object> map) {
GenericRecord avroRecord = new GenericData.Record(schema);
for (String k : map.keySet()) {
Object value = map.get(k);
Schema fieldSchema = schema.getField(k).schema();
if (fieldSchema.getType().equals(Schema.Type.RECORD)) {
avroRecord.put(k, buildAvroRecord(fieldSchema, (Map<String, Object>) value));
} else {
avroRecord.put(k, value);
}
}
return avroRecord;
/**
 * Renders the (possibly templated) Avro schema string and parses it into an
 * {@link AvroSchema}.
 *
 * @param runContext the run context used to render template expressions in the schema
 * @param avroSchema the raw schema string from the task configuration; may be null
 *                   when the corresponding serializer is not AVRO
 * @return the parsed schema, or {@code null} when no schema was configured
 * @throws IllegalVariableEvaluationException if rendering the schema template fails
 */
@Nullable
private static AvroSchema parseAvroSchema(RunContext runContext, @Nullable String avroSchema) throws IllegalVariableEvaluationException {
    return Optional.ofNullable(avroSchema).map(throwFunction(runContext::render)).map(AvroSchema::new).orElse(null);
}

@SuppressWarnings("unchecked")
Expand All @@ -271,18 +257,8 @@ private ProducerRecord<Object, Object> producerRecord(RunContext runContext, Kaf

map = runContext.render(map);

if (this.keySerializer == SerdeType.AVRO) {
key = buildAvroRecord(runContext, this.keyAvroSchema, (Map<String, Object>) map.get("key"));
} else {
key = map.get("key");
}

if (this.valueSerializer == SerdeType.AVRO) {
value = buildAvroRecord(runContext, this.valueAvroSchema, (Map<String, Object>) map.get("value"));
} else {
value = map.get("value");
}

key = map.get("key");
value = map.get("value");

if (map.containsKey("topic")) {
topic = runContext.render((String) map.get("topic"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.kestra.plugin.kafka.serdes;

import io.confluent.kafka.schemaregistry.avro.AvroSchema;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericArray;
import org.apache.avro.generic.GenericEnumSymbol;
import org.apache.avro.generic.GenericFixed;
import org.apache.avro.generic.GenericRecord;
import org.apache.kafka.common.serialization.Serializer;

/**
 * Kafka {@link Serializer} that converts plain Java values (maps, collections,
 * strings, numbers, byte arrays, ...) into Avro generic data conforming to a
 * target {@link AvroSchema}, then delegates the wire encoding to a
 * {@link KafkaAvroSerializer}.
 *
 * <p>Supports nested records, maps, arrays, enums, fixed types and nullable
 * unions. Union handling picks the first non-null branch, so optional values
 * work but true polymorphic unions are not supported.
 */
public class MapToGenericRecordSerializer implements Serializer<Object> {

    private final KafkaAvroSerializer serializer;
    private final AvroSchema schema;

    /**
     * @param serializer the Confluent Avro serializer that performs the actual encoding
     * @param schema     the schema the incoming values must conform to
     */
    public MapToGenericRecordSerializer(KafkaAvroSerializer serializer, AvroSchema schema) {
        this.serializer = serializer;
        this.schema = schema;
    }

    @Override
    public void configure(Map<String, ?> configs, boolean isKey) {
        this.serializer.configure(configs, isKey);
    }

    @Override
    public byte[] serialize(String topic, Object data) {
        // Convert the raw value tree to Avro generic data before delegating.
        return serializer.serialize(topic, buildValue(schema.rawSchema(), data));
    }

    @Override
    public void close() {
        this.serializer.close();
    }

    /**
     * Recursively converts {@code data} into the Avro generic representation
     * required by {@code schema}. Scalars pass through unchanged; null is
     * preserved (valid for nullable unions / NULL schemas).
     */
    @SuppressWarnings("unchecked")
    private static Object buildValue(Schema schema, Object data) {
        if (data == null) {
            return null;
        }
        return switch (schema.getType()) {
            case UNION -> buildUnionValue(schema, data);
            case RECORD -> buildRecordValue(schema, (Map<String, ?>) data);
            case MAP -> buildMapValue(schema, (Map<String, ?>) data);
            case ARRAY -> buildArrayValue(schema, (Collection<?>) data);
            case ENUM -> buildEnumValue(schema, (String) data);
            case FIXED -> buildFixedValue(schema, (byte[]) data);
            case STRING, BYTES, INT, LONG, FLOAT, DOUBLE, BOOLEAN, NULL -> data;
        };
    }

    private static Object buildUnionValue(Schema schema, Object value) {
        // TODO using the first non-null schema allows support for optional values, but not polymorphism
        for (Schema s : schema.getTypes()) {
            if (!s.getType().equals(Schema.Type.NULL)) {
                return buildValue(s, value);
            }
        }
        // A union of only NULL branches cannot hold a non-null value.
        throw new IllegalArgumentException("Union schema " + schema + " has no non-null branch for value of type "
            + value.getClass().getName());
    }

    private static GenericRecord buildRecordValue(Schema schema, Map<String, ?> data) {
        final var record = new org.apache.avro.generic.GenericData.Record(schema);
        data.forEach((key, value) -> {
            final var field = schema.getField(key);
            if (field == null) {
                // Fail fast with a clear message instead of the bare NPE that
                // schema.getField(key).schema() would throw for an unknown key.
                throw new IllegalArgumentException(
                    "Field '" + key + "' is not declared in record schema '" + schema.getFullName() + "'");
            }
            record.put(key, buildValue(field.schema(), value));
        });
        return record;
    }

    private static Map<String, ?> buildMapValue(Schema schema, Map<String, ?> data) {
        // LinkedHashMap keeps the caller's entry order.
        final var record = new LinkedHashMap<String, Object>();
        data.forEach((key, value) -> record.put(key, buildValue(schema.getValueType(), value)));
        return record;
    }

    private static GenericArray<?> buildArrayValue(Schema schema, Collection<?> data) {
        final var values = data.stream().map(value -> buildValue(schema.getElementType(), value)).toList();
        return new org.apache.avro.generic.GenericData.Array<>(schema, values);
    }

    private static GenericEnumSymbol<?> buildEnumValue(Schema schema, String data) {
        return new org.apache.avro.generic.GenericData.EnumSymbol(schema, data);
    }

    private static GenericFixed buildFixedValue(Schema schema, byte[] data) {
        return new org.apache.avro.generic.GenericData.Fixed(schema, data);
    }
}
Loading