Skip to content

Commit

Permalink
refactor(properties): migrate to v2 properties
Browse files Browse the repository at this point in the history
  • Loading branch information
mgabelle committed Nov 5, 2024
1 parent 750cf99 commit 17e69d9
Show file tree
Hide file tree
Showing 12 changed files with 368 additions and 375 deletions.
20 changes: 9 additions & 11 deletions src/main/java/io/kestra/plugin/kafka/AbstractKafkaConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.confluent.kafka.serializers.KafkaAvroDeserializer;
import io.confluent.kafka.serializers.KafkaJsonDeserializer;
import io.confluent.kafka.serializers.KafkaJsonSerializer;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.Task;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.kafka.serdes.GenericRecordToMapDeserializer;
Expand All @@ -18,10 +19,7 @@
import org.apache.kafka.common.serialization.*;

import java.nio.file.Path;
import java.util.Base64;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import java.util.*;

import static io.kestra.core.utils.Rethrow.throwBiConsumer;

Expand All @@ -32,21 +30,21 @@
@NoArgsConstructor
public abstract class AbstractKafkaConnection extends Task implements KafkaConnectionInterface {
@NotNull
protected Map<String, String> properties;
protected Property<Map<String, String>> properties;

@Builder.Default
protected Map<String, String> serdeProperties = Collections.emptyMap();
protected Property<Map<String, String>> serdeProperties = Property.of(new HashMap<>());

protected static Properties createProperties(Map<String, String> mapProperties, RunContext runContext) throws Exception {
protected static Properties createProperties(Property<Map<String, String>> mapProperties, RunContext runContext) throws Exception {
Properties properties = new Properties();

mapProperties
final Map<String, String> renderedMapProperties = runContext.render(mapProperties).asMap(String.class, String.class);
renderedMapProperties
.forEach(throwBiConsumer((key, value) -> {
if (key.equals(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG) || key.equals(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)) {
Path path = runContext.workingDir().createTempFile(Base64.getDecoder().decode(runContext.render(value).replace("\n", "")));
Path path = runContext.workingDir().createTempFile(Base64.getDecoder().decode(value.replace("\n", "")));
properties.put(key, path.toAbsolutePath().toString());
} else {
properties.put(key, runContext.render(value));
properties.put(key, value);
}
}));

Expand Down
120 changes: 63 additions & 57 deletions src/main/java/io/kestra/plugin/kafka/Consume.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.executions.metrics.Counter;
import io.kestra.core.models.property.Property;
import io.kestra.core.models.tasks.RunnableTask;
import io.kestra.core.runners.RunContext;
import io.kestra.core.serializers.FileSerde;
import io.kestra.core.utils.Await;
import io.kestra.core.utils.Rethrow;
import io.kestra.plugin.kafka.serdes.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
Expand Down Expand Up @@ -100,26 +100,26 @@
public class Consume extends AbstractKafkaConnection implements RunnableTask<Consume.Output>, ConsumeInterface {
private Object topic;

private String topicPattern;
private Property<String> topicPattern;

private List<Integer> partitions;
private Property<List<Integer>> partitions;

private String groupId;
private Property<String> groupId;

@Builder.Default
private SerdeType keyDeserializer = SerdeType.STRING;
private Property<SerdeType> keyDeserializer = Property.of(SerdeType.STRING);

@Builder.Default
private SerdeType valueDeserializer = SerdeType.STRING;
private Property<SerdeType> valueDeserializer = Property.of(SerdeType.STRING);

private String since;
private Property<String> since;

@Builder.Default
private Duration pollDuration = Duration.ofSeconds(5);
private Property<Duration> pollDuration = Property.of(Duration.ofSeconds(5));

private Integer maxRecords;
private Property<Integer> maxRecords;

private Duration maxDuration;
private Property<Duration> maxDuration;

@Getter(AccessLevel.PACKAGE)
private ConsumerSubscription subscription;
Expand All @@ -130,12 +130,12 @@ public KafkaConsumer<Object, Object> consumer(RunContext runContext) throws Exce
Thread.currentThread().setContextClassLoader(this.getClass().getClassLoader());

final Properties consumerProps = createProperties(this.properties, runContext);

if (this.groupId != null) {
consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, runContext.render(groupId));
final Optional<String> renderedGroupId = runContext.render(groupId).as(String.class);
if (renderedGroupId.isPresent()) {
consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, renderedGroupId.get());
} else if (consumerProps.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
// groupId can be passed from properties
this.groupId = consumerProps.getProperty(ConsumerConfig.GROUP_ID_CONFIG);
this.groupId = Property.of(consumerProps.getProperty(ConsumerConfig.GROUP_ID_CONFIG));
}

if (!consumerProps.contains(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)) {
Expand All @@ -153,8 +153,8 @@ public KafkaConsumer<Object, Object> consumer(RunContext runContext) throws Exce
// by default, enable Avro LogicalType
serdesProperties.put(KafkaAvroSerializerConfig.AVRO_USE_LOGICAL_TYPE_CONVERTERS_CONFIG, true);

final Deserializer keyDeserializer = getTypedDeserializer(this.keyDeserializer);
final Deserializer valDeserializer = getTypedDeserializer(this.valueDeserializer);
final Deserializer keyDeserializer = getTypedDeserializer(runContext.render(this.keyDeserializer).as(SerdeType.class).orElse(SerdeType.STRING));
final Deserializer valDeserializer = getTypedDeserializer(runContext.render(this.valueDeserializer).as(SerdeType.class).orElse(SerdeType.STRING));

keyDeserializer.configure(serdesProperties, true);
valDeserializer.configure(serdesProperties, false);
Expand All @@ -170,7 +170,7 @@ public Output run(RunContext runContext) throws Exception {
KafkaConsumer<Object, Object> consumer = this.consumer(runContext)
) {
this.subscription = topicSubscription(runContext);
this.subscription.subscribe(consumer, this);
this.subscription.subscribe(runContext, consumer, this);

Map<String, Integer> count = new HashMap<>();
AtomicInteger total = new AtomicInteger();
Expand All @@ -179,24 +179,23 @@ public Output run(RunContext runContext) throws Exception {
boolean empty;

do {
records = consumer.poll(this.pollDuration);
records = consumer.poll(runContext.render(this.pollDuration).as(Duration.class).orElse(Duration.ofSeconds(5)));
empty = records.isEmpty();

records.forEach(throwConsumer(record -> {
FileSerde.write(output, this.recordToMessage(record));
records.forEach(throwConsumer(consumerRecord -> {
FileSerde.write(output, this.recordToMessage(consumerRecord));

total.getAndIncrement();
count.compute(record.topic(), (s, integer) -> integer == null ? 1 : integer + 1);
count.compute(consumerRecord.topic(), (s, integer) -> integer == null ? 1 : integer + 1);
}));
}
while (!this.ended(empty, total, started));
while (!this.ended(runContext, empty, total, started));

if (this.groupId != null) {
consumer.commitSync();
}

// flush & close
consumer.close();
output.flush();

count
Expand All @@ -209,29 +208,31 @@ public Output run(RunContext runContext) throws Exception {
}
}

public Message recordToMessage(ConsumerRecord<Object, Object> record) {
public Message recordToMessage(ConsumerRecord<Object, Object> consumerRecord) {
return Message.builder()
.key(record.key())
.value(record.value())
.headers(processHeaders(record.headers()))
.topic(record.topic())
.partition(record.partition())
.timestamp(Instant.ofEpochMilli(record.timestamp()))
.offset(record.offset())
.key(consumerRecord.key())
.value(consumerRecord.value())
.headers(processHeaders(consumerRecord.headers()))
.topic(consumerRecord.topic())
.partition(consumerRecord.partition())
.timestamp(Instant.ofEpochMilli(consumerRecord.timestamp()))
.offset(consumerRecord.offset())
.build();
}

@SuppressWarnings("RedundantIfStatement")
private boolean ended(Boolean empty, AtomicInteger count, ZonedDateTime start) {
if (empty) {
private boolean ended(RunContext runContext, Boolean empty, AtomicInteger count, ZonedDateTime start) throws IllegalVariableEvaluationException {
if (Boolean.TRUE.equals(empty)) {
return true;
}

if (this.maxRecords != null && count.get() > this.maxRecords) {
final Optional<Integer> renderedMaxRecords = runContext.render(this.maxRecords).as(Integer.class);
if (renderedMaxRecords.isPresent() && count.get() > renderedMaxRecords.get()) {
return true;
}

if (this.maxDuration != null && ZonedDateTime.now().toEpochSecond() > start.plus(this.maxDuration).toEpochSecond()) {
final Optional<Duration> renderedPollDuration = runContext.render(this.pollDuration).as(Duration.class);
if (renderedPollDuration.isPresent() && ZonedDateTime.now().toEpochSecond() > start.plus(renderedPollDuration.get()).toEpochSecond()) {
return true;
}

Expand All @@ -241,33 +242,38 @@ private boolean ended(Boolean empty, AtomicInteger count, ZonedDateTime start) {
public ConsumerSubscription topicSubscription(final RunContext runContext) throws IllegalVariableEvaluationException {
validateConfiguration();

if (this.topic != null && (partitions != null && !partitions.isEmpty())) {
final Optional<String> renderedGroupId = runContext.render(groupId).as(String.class);
final List<Integer> renderedPartitions = runContext.render(partitions).asList(String.class);

if (this.topic != null && !renderedPartitions.isEmpty()) {
List<TopicPartition> topicPartitions = getTopicPartitions(runContext);
return TopicPartitionsSubscription.forTopicPartitions(groupId, topicPartitions, evaluateSince(runContext));
return TopicPartitionsSubscription.forTopicPartitions(renderedGroupId.orElse(null), topicPartitions, evaluateSince(runContext));
}

if (this.topic != null && groupId == null) {
if (this.topic != null && renderedGroupId.isEmpty()) {
return TopicPartitionsSubscription.forTopics(null, evaluateTopics(runContext), evaluateSince(runContext));
}

if (this.topic != null) {
return new TopicListSubscription(groupId, evaluateTopics(runContext));
return new TopicListSubscription(renderedGroupId.get(), evaluateTopics(runContext));
}

if (this.topicPattern != null) {
final Optional<String> renderedPattern = runContext.render(topicPattern).as(String.class);
if (renderedPattern.isPresent()) {
try {
return new TopicPatternSubscription(groupId, Pattern.compile(this.topicPattern));
return new TopicPatternSubscription(renderedGroupId.orElse(null), Pattern.compile(renderedPattern.get()));
} catch (PatternSyntaxException e) {
throw new IllegalArgumentException("Invalid regex for `topicPattern`: " + this.topicPattern);
throw new IllegalArgumentException("Invalid regex for `topicPattern`: " + renderedPattern.get());
}
}
throw new IllegalArgumentException("Failed to create KafkaConsumer subscription");
}

/**
 * Builds the cartesian product of the rendered topics and the rendered partition ids.
 *
 * @param runContext the run context used to render dynamic properties
 * @return one {@link TopicPartition} per (topic, partition) pair
 * @throws IllegalVariableEvaluationException if a property cannot be rendered
 */
private List<TopicPartition> getTopicPartitions(RunContext runContext) throws IllegalVariableEvaluationException {
    List<String> topics = evaluateTopics(runContext);
    // FIX: `partitions` is a Property<List<Integer>>, so items must be rendered as
    // Integer.class — TopicPartition requires an int partition id, and rendering as
    // String.class yields an incompatible List<String>.
    final List<Integer> renderedPartitions = runContext.render(partitions).asList(Integer.class);
    return topics.stream()
        .flatMap(topic -> renderedPartitions.stream().map(partition -> new TopicPartition(topic, partition)))
        .toList();
}

Expand All @@ -277,8 +283,8 @@ private List<TopicPartition> getTopicPartitions(RunContext runContext) throws Il
@SuppressWarnings("unchecked")
private List<String> evaluateTopics(final RunContext runContext) throws IllegalVariableEvaluationException {
List<String> topics;
if (this.topic instanceof String) {
topics = List.of(runContext.render((String) this.topic));
if (this.topic instanceof String topicString) {
topics = List.of(runContext.render(topicString));
} else if (this.topic instanceof List) {
topics = runContext.render((List<String>) this.topic);
} else {
Expand All @@ -292,8 +298,7 @@ private List<String> evaluateTopics(final RunContext runContext) throws IllegalV
*/
@Nullable
private Long evaluateSince(final RunContext runContext) throws IllegalVariableEvaluationException {
return Optional.ofNullable(this.since)
.map(Rethrow.throwFunction(runContext::render))
return runContext.render(this.since).as(String.class)
.map(ZonedDateTime::parse)
.map(ChronoZonedDateTime::toInstant)
.map(Instant::toEpochMilli)
Expand Down Expand Up @@ -329,7 +334,7 @@ static List<Pair<String, String>> processHeaders(final Headers headers) {
return StreamSupport
.stream(headers.spliterator(), false)
.map(header -> Pair.of(header.key(), new String(header.value(), StandardCharsets.UTF_8)))
.collect(Collectors.toList());
.toList();
}

@Builder
Expand All @@ -352,13 +357,14 @@ public static class Output implements io.kestra.core.models.tasks.Output {
@VisibleForTesting
interface ConsumerSubscription {

void subscribe(Consumer<Object, Object> consumer, ConsumeInterface consumeInterface);
void subscribe(RunContext runContext, Consumer<Object, Object> consumer, ConsumeInterface consumeInterface) throws IllegalVariableEvaluationException;

default void waitForSubscription(final Consumer<Object, Object> consumer,
final ConsumeInterface consumeInterface) {
default void waitForSubscription(RunContext runContext,
final Consumer<Object, Object> consumer,
final ConsumeInterface consumeInterface) throws IllegalVariableEvaluationException {
var timeout = consumeInterface.getMaxDuration() != null ?
consumeInterface.getMaxDuration() :
consumeInterface.getPollDuration();
runContext.render(consumeInterface.getMaxDuration()).as(Duration.class).orElse(null) :
runContext.render(consumeInterface.getPollDuration()).as(Duration.class).orElse(null);
// Wait for the subscription to happen, this avoids possible no result for the first poll due to the poll timeout
Await.until(() -> !consumer.subscription().isEmpty(), timeout);
}
Expand All @@ -370,7 +376,7 @@ default void waitForSubscription(final Consumer<Object, Object> consumer,
@VisibleForTesting
record TopicPatternSubscription(String groupId, Pattern pattern) implements ConsumerSubscription {
@Override
public void subscribe(final Consumer<Object, Object> consumer,
public void subscribe(RunContext runContext, final Consumer<Object, Object> consumer,
final ConsumeInterface consumeInterface) {
consumer.subscribe(pattern);
}
Expand All @@ -388,9 +394,9 @@ public String toString() {
record TopicListSubscription(String groupId, List<String> topics) implements ConsumerSubscription {

@Override
public void subscribe(final Consumer<Object, Object> consumer, final ConsumeInterface consumeInterface) {
public void subscribe(RunContext runContext, final Consumer<Object, Object> consumer, final ConsumeInterface consumeInterface) throws IllegalVariableEvaluationException {
consumer.subscribe(topics);
waitForSubscription(consumer, consumeInterface);
waitForSubscription(runContext, consumer, consumeInterface);
}

@Override
Expand Down Expand Up @@ -438,7 +444,7 @@ public static TopicPartitionsSubscription forTopics(final String groupId,
}

@Override
public void subscribe(final Consumer<Object, Object> consumer, final ConsumeInterface consumeInterface) {
public void subscribe(RunContext runContext, final Consumer<Object, Object> consumer, final ConsumeInterface consumeInterface) {
if (this.topicPartitions == null) {
this.topicPartitions = allPartitionsForTopics(consumer, topics);
}
Expand Down
13 changes: 4 additions & 9 deletions src/main/java/io/kestra/plugin/kafka/ConsumeInterface.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package io.kestra.plugin.kafka;

import io.kestra.core.models.annotations.PluginProperty;
import io.kestra.plugin.kafka.serdes.SerdeType;
import io.kestra.core.models.property.Property;
import io.swagger.v3.oas.annotations.media.Schema;

import java.time.Duration;
import java.util.List;

import jakarta.validation.constraints.NotNull;

Expand All @@ -15,21 +13,18 @@ public interface ConsumeInterface extends KafkaConsumerInterface {
title = "The maximum number of records to fetch before stopping the consumption process.",
description = "It's a soft limit evaluated every second."
)
@PluginProperty(dynamic = false)
Integer getMaxRecords();
Property<Integer> getMaxRecords();

@Schema(
title = "The maximum duration to wait for new records before stopping the consumption process.",
description = "It's a soft limit evaluated every second."
)
@PluginProperty(dynamic = false)
Duration getMaxDuration();
Property<Duration> getMaxDuration();

@Schema(
title = "How often to poll for a record.",
description = "If no records are available, the maximum wait duration to wait for new records. "
)
@NotNull
@PluginProperty(dynamic = true)
Duration getPollDuration();
Property<Duration> getPollDuration();
}
Loading

0 comments on commit 17e69d9

Please sign in to comment.