From 618bb2ba42465e2d8a5eac264c12e4c1ee843138 Mon Sep 17 00:00:00 2001 From: Florian Hussonnois Date: Thu, 23 May 2024 12:04:23 +0200 Subject: [PATCH] feat: enhance realtime trigger lifecycle part-of: kestra-io/kestra#3767 --- .../java/io/kestra/plugin/amqp/Consume.java | 26 --- .../kestra/plugin/amqp/RealtimeTrigger.java | 153 +++++++++++++++++- src/test/resources/flows/realtime.yaml | 1 - 3 files changed, 145 insertions(+), 35 deletions(-) diff --git a/src/main/java/io/kestra/plugin/amqp/Consume.java b/src/main/java/io/kestra/plugin/amqp/Consume.java index 7606c9f..17ce49c 100644 --- a/src/main/java/io/kestra/plugin/amqp/Consume.java +++ b/src/main/java/io/kestra/plugin/amqp/Consume.java @@ -104,32 +104,6 @@ public Consume.Output run(RunContext runContext) throws Exception { } } - public Publisher stream(RunContext runContext) { - return Flux.create( - fluxSink -> { - try { - ConnectionFactory factory = this.connectionFactory(runContext); - - try ( - ConsumeThread thread = new ConsumeThread( - factory, - runContext, - this, - throwConsumer(fluxSink::next), - () -> false - ); - ) { - thread.start(); - thread.join(); - } - } catch (Throwable e) { - fluxSink.error(e); - } finally { - fluxSink.complete(); - } - }); - } - @SuppressWarnings("RedundantIfStatement") private boolean ended(AtomicInteger count, ZonedDateTime start) { if (this.maxRecords != null && count.get() >= this.maxRecords) { diff --git a/src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java b/src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java index 77ebdd1..8f243e3 100644 --- a/src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java +++ b/src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java @@ -1,19 +1,40 @@ package io.kestra.plugin.amqp; +import com.rabbitmq.client.CancelCallback; +import com.rabbitmq.client.Channel; +import com.rabbitmq.client.Connection; +import com.rabbitmq.client.ConnectionFactory; +import com.rabbitmq.client.DeliverCallback; import io.kestra.core.models.annotations.Example; import io.kestra.core.models.annotations.Plugin; import io.kestra.core.models.conditions.ConditionContext; import io.kestra.core.models.executions.Execution; -import io.kestra.core.models.triggers.*; +import io.kestra.core.models.triggers.AbstractTrigger; +import io.kestra.core.models.triggers.RealtimeTriggerInterface; +import io.kestra.core.models.triggers.TriggerContext; +import io.kestra.core.models.triggers.TriggerOutput; +import io.kestra.core.models.triggers.TriggerService; +import io.kestra.core.runners.RunContext; import io.kestra.plugin.amqp.models.Message; import io.kestra.plugin.amqp.models.SerdeType; import io.swagger.v3.oas.annotations.media.Schema; -import lombok.*; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.ToString; import lombok.experimental.SuperBuilder; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; +import java.io.IOException; import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; @SuperBuilder @ToString @@ -51,12 +72,16 @@ public class RealtimeTrigger extends AbstractTrigger implements RealtimeTriggerI @Builder.Default private String consumerTag = "Kestra"; - private Integer maxRecords; + @Builder.Default + private SerdeType serdeType = SerdeType.STRING; - private Duration maxDuration; + @Builder.Default + @Getter(AccessLevel.NONE) + private final AtomicBoolean isActive = new AtomicBoolean(true); @Builder.Default - private SerdeType serdeType = SerdeType.STRING; + @Getter(AccessLevel.NONE) + private final CountDownLatch waitForTermination = new CountDownLatch(1); @Override public Publisher evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception { @@ -69,12 +94,124 @@ public Publisher evaluate(ConditionContext conditionContext, TriggerC .virtualHost(this.virtualHost) .queue(this.queue) .consumerTag(this.consumerTag) - .maxRecords(this.maxRecords) - .maxDuration(this.maxDuration) .serdeType(this.serdeType) .build(); - return Flux.from(task.stream(conditionContext.getRunContext())) + return Flux.from(publisher(task, conditionContext.getRunContext())) .map((record) -> TriggerService.generateRealtimeExecution(this, context, record)); } + + public Publisher publisher(final Consume task, final RunContext runContext) { + return Flux.create( + emitter -> { + final AtomicReference error = new AtomicReference<>(); + try { + final String queue = runContext.render(task.getQueue()); + final String consumerTag = runContext.render(task.getConsumerTag()); + + ConnectionFactory factory = task.connectionFactory(runContext); + Connection connection = factory.newConnection(); + Channel channel = connection.createChannel(); + + final AtomicBoolean basicCancel = new AtomicBoolean(true); + emitter.onDispose(() -> { + try { + if (channel.isOpen() && channel.getConnection().isOpen()) { + if (basicCancel.compareAndSet(true, false)) { + channel.basicCancel(consumerTag); // stop consuming + } + channel.close(); + } + connection.close(); + } catch (IOException | TimeoutException e) { + runContext.logger().warn("Error while closing channel or connection: " + e.getMessage()); + } finally { + waitForTermination.countDown(); + } + }); + + DeliverCallback deliverCallback = (tag, message) -> { + try { + Message output = Message.of(message.getBody(), task.getSerdeType(), message.getProperties()); + emitter.next(output); + channel.basicAck(message.getEnvelope().getDeliveryTag(), false); + } catch (Exception e) { + error.set(e); + isActive.set(false); + } + }; + + CancelCallback cancelCallback = tag -> { + runContext.logger().info("Consumer {} has been cancelled", consumerTag); + basicCancel.set(false); + isActive.set(false); + }; + + // create basic consumer + channel.basicConsume( + queue, + false, // auto-ack + consumerTag, + deliverCallback, + cancelCallback, + (tag, sig) -> {} + ); + + // wait for consumer to be stopped + busyWait(); + + } catch (Throwable e) { + error.set(e); + } finally { + // dispose + Throwable throwable = error.get(); + if (throwable != null) { + emitter.error(throwable); + } else { + emitter.complete(); + } + } + }); + } + + private void busyWait() { + while (isActive.get()) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + isActive.set(false); // proactively stop consuming + } + } + } + + /** + * {@inheritDoc} + **/ + @Override + public void kill() { + stop(true); + } + + /** + * {@inheritDoc} + **/ + @Override + public void stop() { + stop(false); // must be non-blocking + } + + private void stop(boolean wait) { + if (!isActive.compareAndSet(true, false)) { + return; + } + + if (wait) { + try { + this.waitForTermination.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } } diff --git a/src/test/resources/flows/realtime.yaml b/src/test/resources/flows/realtime.yaml index 3966495..b5a639c 100644 --- a/src/test/resources/flows/realtime.yaml +++ b/src/test/resources/flows/realtime.yaml @@ -5,7 +5,6 @@ triggers: - id: watch type: io.kestra.plugin.amqp.RealtimeTrigger url: amqp://guest:guest@localhost:5672/my_vhost - maxRecords: 2 queue: amqpTrigger.queue