Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enhance realtime trigger lifecycle #49

Merged
merged 1 commit into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions src/main/java/io/kestra/plugin/amqp/Consume.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,32 +104,6 @@ public Consume.Output run(RunContext runContext) throws Exception {
}
}

public Publisher<Message> stream(RunContext runContext) {
return Flux.<Message>create(
fluxSink -> {
try {
ConnectionFactory factory = this.connectionFactory(runContext);

try (
ConsumeThread thread = new ConsumeThread(
factory,
runContext,
this,
throwConsumer(fluxSink::next),
() -> false
);
) {
thread.start();
thread.join();
}
} catch (Throwable e) {
fluxSink.error(e);
} finally {
fluxSink.complete();
}
});
}

@SuppressWarnings("RedundantIfStatement")
private boolean ended(AtomicInteger count, ZonedDateTime start) {
if (this.maxRecords != null && count.get() >= this.maxRecords) {
Expand Down
153 changes: 145 additions & 8 deletions src/main/java/io/kestra/plugin/amqp/RealtimeTrigger.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,40 @@
package io.kestra.plugin.amqp;

import com.rabbitmq.client.CancelCallback;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.DeliverCallback;
import io.kestra.core.models.annotations.Example;
import io.kestra.core.models.annotations.Plugin;
import io.kestra.core.models.conditions.ConditionContext;
import io.kestra.core.models.executions.Execution;
import io.kestra.core.models.triggers.*;
import io.kestra.core.models.triggers.AbstractTrigger;
import io.kestra.core.models.triggers.RealtimeTriggerInterface;
import io.kestra.core.models.triggers.TriggerContext;
import io.kestra.core.models.triggers.TriggerOutput;
import io.kestra.core.models.triggers.TriggerService;
import io.kestra.core.runners.RunContext;
import io.kestra.plugin.amqp.models.Message;
import io.kestra.plugin.amqp.models.SerdeType;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.*;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.ToString;
import lombok.experimental.SuperBuilder;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

@SuperBuilder
@ToString
Expand Down Expand Up @@ -51,12 +72,16 @@ public class RealtimeTrigger extends AbstractTrigger implements RealtimeTriggerI
@Builder.Default
private String consumerTag = "Kestra";

private Integer maxRecords;
@Builder.Default
private SerdeType serdeType = SerdeType.STRING;

private Duration maxDuration;
@Builder.Default
@Getter(AccessLevel.NONE)
private final AtomicBoolean isActive = new AtomicBoolean(true);

@Builder.Default
private SerdeType serdeType = SerdeType.STRING;
@Getter(AccessLevel.NONE)
private final CountDownLatch waitForTermination = new CountDownLatch(1);

@Override
public Publisher<Execution> evaluate(ConditionContext conditionContext, TriggerContext context) throws Exception {
Expand All @@ -69,12 +94,124 @@ public Publisher<Execution> evaluate(ConditionContext conditionContext, TriggerC
.virtualHost(this.virtualHost)
.queue(this.queue)
.consumerTag(this.consumerTag)
.maxRecords(this.maxRecords)
.maxDuration(this.maxDuration)
.serdeType(this.serdeType)
.build();

return Flux.from(task.stream(conditionContext.getRunContext()))
return Flux.from(publisher(task, conditionContext.getRunContext()))
.map((record) -> TriggerService.generateRealtimeExecution(this, context, record));
}

public Publisher<Message> publisher(final Consume task, final RunContext runContext) {
return Flux.create(
emitter -> {
final AtomicReference<Throwable> error = new AtomicReference<>();
try {
final String queue = runContext.render(task.getQueue());
final String consumerTag = runContext.render(task.getConsumerTag());

ConnectionFactory factory = task.connectionFactory(runContext);
Connection connection = factory.newConnection();
Channel channel = connection.createChannel();

final AtomicBoolean basicCancel = new AtomicBoolean(true);
emitter.onDispose(() -> {
try {
if (channel.isOpen() && channel.getConnection().isOpen()) {
if (basicCancel.compareAndSet(true, false)) {
channel.basicCancel(consumerTag); // stop consuming
}
channel.close();
}
connection.close();
} catch (IOException | TimeoutException e) {
runContext.logger().warn("Error while closing channel or connection: " + e.getMessage());
} finally {
waitForTermination.countDown();
}
});

DeliverCallback deliverCallback = (tag, message) -> {
try {
Message output = Message.of(message.getBody(), task.getSerdeType(), message.getProperties());
emitter.next(output);
channel.basicAck(message.getEnvelope().getDeliveryTag(), false);
} catch (Exception e) {
error.set(e);
isActive.set(false);
}
};

CancelCallback cancelCallback = tag -> {
runContext.logger().info("Consumer {} has been cancelled", consumerTag);
basicCancel.set(false);
isActive.set(false);
};

// create basic consumer
channel.basicConsume(
queue,
false, // auto-ack
consumerTag,
deliverCallback,
cancelCallback,
(tag, sig) -> {}
);

// wait for consumer to be stopped
busyWait();

} catch (Throwable e) {
error.set(e);
} finally {
// dispose
Throwable throwable = error.get();
if (throwable != null) {
emitter.error(throwable);
} else {
emitter.complete();
}
}
});
}

private void busyWait() {
while (isActive.get()) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
isActive.set(false); // proactively stop consuming
}
}
}

/**
* {@inheritDoc}
**/
@Override
public void kill() {
stop(true);
}

/**
* {@inheritDoc}
**/
@Override
public void stop() {
stop(false); // must be non-blocking
}

private void stop(boolean wait) {
if (!isActive.compareAndSet(true, false)) {
return;
}

if (wait) {
try {
this.waitForTermination.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
1 change: 0 additions & 1 deletion src/test/resources/flows/realtime.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ triggers:
- id: watch
type: io.kestra.plugin.amqp.RealtimeTrigger
url: amqp://guest:guest@localhost:5672/my_vhost
maxRecords: 2
queue: amqpTrigger.queue


Expand Down
Loading