initial draft implementation of thundering herd mitigation #410

Draft · wants to merge 1 commit into base: main
5 changes: 1 addition & 4 deletions lib/logstash/inputs/beats.rb
@@ -118,9 +118,6 @@ class LogStash::Inputs::Beats < LogStash::Inputs::Base
# Close Idle clients after X seconds of inactivity.
config :client_inactivity_timeout, :validate => :number, :default => 60

# Beats handler executor thread
config :executor_threads, :validate => :number, :default => LogStash::Config::CpuCoreStrategy.maximum

def register
# For Logstash 2.4 we need to make sure that the logger is correctly set for the
# java classes before actually loading them.
@@ -162,7 +159,7 @@ def register
end # def register

def create_server
server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout, @executor_threads)
server = org.logstash.beats.Server.new(@host, @port, @client_inactivity_timeout)
if @ssl
ssl_context_builder = new_ssl_context_builder
if client_authentification?
5 changes: 2 additions & 3 deletions spec/inputs/beats_spec.rb
@@ -25,16 +25,15 @@

context "#register" do
context "host related configuration" do
let(:config) { super.merge("host" => host, "port" => port, "client_inactivity_timeout" => client_inactivity_timeout, "executor_threads" => threads) }
let(:config) { super.merge("host" => host, "port" => port, "client_inactivity_timeout" => client_inactivity_timeout) }
let(:host) { "192.168.1.20" }
let(:port) { 9000 }
let(:client_inactivity_timeout) { 400 }
let(:threads) { 10 }

subject(:plugin) { LogStash::Inputs::Beats.new(config) }

it "sends the required options to the server" do
expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout, threads)
expect(org.logstash.beats.Server).to receive(:new).with(host, port, client_inactivity_timeout)
subject.register
end
end
1 change: 0 additions & 1 deletion src/main/java/org/logstash/beats/BeatsHandler.java
@@ -2,7 +2,6 @@

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.AttributeKey;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

83 changes: 79 additions & 4 deletions src/main/java/org/logstash/beats/BeatsParser.java
@@ -3,8 +3,10 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.DecoderException;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

@@ -14,12 +16,13 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.zip.Inflater;
import java.util.zip.InflaterOutputStream;


public class BeatsParser extends ByteToMessageDecoder {
private final static Logger logger = LogManager.getLogger(BeatsParser.class);
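// Upper bound on direct memory available to Netty; the back-pressure checks below are expressed as fractions of this value.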
private final static long maxDirectMemory = io.netty.util.internal.PlatformDependent.maxDirectMemory();

private Batch batch;

@@ -45,15 +48,19 @@ private enum States {
private int requiredBytes = 0;
private int sequence = 0;
private boolean decodingCompressedBuffer = false;
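// Latest observed direct-memory usage, plus a flag recording that this handler has already closed its channel due to memory pressure.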
private long usedDirectMemory;
private boolean closeCalled = false;

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {

if(!hasEnoughBytes(in)) {
if (decodingCompressedBuffer){
throw new InvalidFrameProtocolException("Insufficient bytes in compressed content to decode: " + currentState);
}
return;
}
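// Sample current direct-memory usage from the allocator; the cast assumes the channel uses Netty's default PooledByteBufAllocator.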
usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory();

switch (currentState) {
case READ_HEADER: {
@@ -178,6 +185,14 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {

case READ_COMPRESSED_FRAME: {
logger.trace("Running: READ_COMPRESSED_FRAME");
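// Before inflating, check that decompression will not push direct-memory usage above roughly 90% of the maximum;
// if it would, stop reading and close this channel rather than risk running out of memory.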

if (usedDirectMemory + requiredBytes > maxDirectMemory * 0.90) {
ctx.channel().config().setAutoRead(false);
ctx.close();
closeCalled = true;
throw new IOException("not enough direct memory to decompress frame from channel " + ctx.channel().id());
}

// Use the compressed size as the safe start for the buffer.
ByteBuf buffer = inflateCompressedFrame(ctx, in);
transition(States.READ_HEADER);
@@ -190,14 +205,18 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
} finally {
decodingCompressedBuffer = false;
buffer.release();
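// Briefly pause reads on this channel (about 5 ms) after each decompressed frame before accepting more data.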
ctx.channel().config().setAutoRead(false);
ctx.channel().eventLoop().schedule(() -> {
ctx.channel().config().setAutoRead(true);
}, 5, TimeUnit.MILLISECONDS);
transition(States.READ_HEADER);
}
break;
}
case READ_JSON: {
logger.trace("Running: READ_JSON");
((V2Batch)batch).addMessage(sequence, in, requiredBytes);
if(batch.isComplete()) {
((V2Batch) batch).addMessage(sequence, in, requiredBytes);
if (batch.isComplete()) {
if(logger.isTraceEnabled()) {
logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence);
}
@@ -212,14 +231,15 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
}

private ByteBuf inflateCompressedFrame(final ChannelHandlerContext ctx, final ByteBuf in) throws IOException {

ByteBuf buffer = ctx.alloc().buffer(requiredBytes);
Inflater inflater = new Inflater();
try (
ByteBufOutputStream buffOutput = new ByteBufOutputStream(buffer);
InflaterOutputStream inflaterStream = new InflaterOutputStream(buffOutput, inflater)
) {
in.readBytes(inflaterStream, requiredBytes);
}finally{
}finally {
inflater.end();
}
return buffer;
@@ -247,4 +267,59 @@ private void batchComplete() {
batch = null;
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
//System.out.println("channelRead(" + ctx.channel().isActive() + ": " + ctx.channel().id() + ":" + currentState + ":" + decodingCompressedBuffer);
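// If this handler has already closed the channel because of memory pressure, release any buffered reads instead of decoding them.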
if (closeCalled) {
((ByteBuf) msg).release();
//if(batch != null) batch.release();
return;
}
usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory();

// If we're just beginning a new frame on this channel,
// don't accumulate more data for 200 ms if usage of direct memory is above 40%
//
// The goal here is to avoid thundering herd: many beats connecting and sending data
// at the same time. As some channels progress to other states they'll use more memory
// but also give it back once a full batch is read.
if ((!decodingCompressedBuffer) && (this.currentState != States.READ_COMPRESSED_FRAME)) {
if (usedDirectMemory > (maxDirectMemory * 0.40)) {
ctx.channel().config().setAutoRead(false);
//System.out.println("pausing reads on " + ctx.channel().id());
ctx.channel().eventLoop().schedule(() -> {
//System.out.println("resuming reads on " + ctx.channel().id());
ctx.channel().config().setAutoRead(true);
}, 200, TimeUnit.MILLISECONDS);
} else {
//System.out.println("no need to pause reads on " + ctx.channel().id());
}
} else if (usedDirectMemory > maxDirectMemory * 0.90) {
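// This channel is mid-frame (compressed data is still being decoded) and direct memory is almost exhausted:
// stop reading, close the channel, and release whatever we were holding.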
ctx.channel().config().setAutoRead(false);
ctx.close();
closeCalled = true;
((ByteBuf) msg).release();
if (batch != null) batch.release();
throw new IOException("direct memory nearly exhausted, closing channel " + ctx.channel().id());
}
super.channelRead(ctx, msg);
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.warn("Exception caught on channel {} in state {}: {}: {}", ctx.channel().id(), this.currentState, cause.getClass(), cause.getMessage());
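// Decoder, memory, and I/O failures all mean this connection can no longer be parsed safely: stop reading and close it (only once).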
if (cause instanceof DecoderException) {
ctx.channel().config().setAutoRead(false);
if (!closeCalled) ctx.close();
} else if (cause instanceof OutOfMemoryError) {
cause.printStackTrace();
ctx.channel().config().setAutoRead(false);
if (!closeCalled) ctx.close();
} else if (cause instanceof IOException) {
ctx.channel().config().setAutoRead(false);
if (!closeCalled) ctx.close();
} else {
super.exceptionCaught(ctx, cause);
}
}
}
2 changes: 1 addition & 1 deletion src/main/java/org/logstash/beats/Runner.java
@@ -20,7 +20,7 @@ static public void main(String[] args) throws Exception {
// Check for leaks.
// ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.PARANOID);

Server server = new Server("0.0.0.0", DEFAULT_PORT, 15, Runtime.getRuntime().availableProcessors());
Server server = new Server("0.0.0.0", DEFAULT_PORT, 15);

if(args.length > 0 && args[0].equals("ssl")) {
logger.debug("Using SSL");
17 changes: 5 additions & 12 deletions src/main/java/org/logstash/beats/Server.java
@@ -17,19 +17,17 @@ public class Server {

private final int port;
private final String host;
private final int beatsHeandlerThreadCount;
private NioEventLoopGroup workGroup;
private IMessageListener messageListener = new MessageListener();
private SslHandlerProvider sslHandlerProvider;
private BeatsInitializer beatsInitializer;

private final int clientInactivityTimeoutSeconds;

public Server(String host, int p, int clientInactivityTimeoutSeconds, int threadCount) {
public Server(String host, int p, int clientInactivityTimeoutSeconds) {
this.host = host;
port = p;
this.clientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds;
beatsHeandlerThreadCount = threadCount;
}

public void setSslHandlerProvider(SslHandlerProvider sslHandlerProvider){
@@ -49,7 +47,7 @@ public Server listen() throws InterruptedException {
try {
logger.info("Starting server on port: {}", this.port);

beatsInitializer = new BeatsInitializer(messageListener, clientInactivityTimeoutSeconds, beatsHeandlerThreadCount);
beatsInitializer = new BeatsInitializer(messageListener, clientInactivityTimeoutSeconds);

ServerBootstrap server = new ServerBootstrap();
server.group(workGroup)
@@ -99,21 +97,18 @@ private class BeatsInitializer extends ChannelInitializer<SocketChannel> {
private final String CONNECTION_HANDLER = "connection-handler";
private final String BEATS_ACKER = "beats-acker";


private final int DEFAULT_IDLESTATEHANDLER_THREAD = 4;
private final int IDLESTATE_WRITER_IDLE_TIME_SECONDS = 5;

private final EventExecutorGroup idleExecutorGroup;
private final EventExecutorGroup beatsHandlerExecutorGroup;
private final IMessageListener localMessageListener;
private final int localClientInactivityTimeoutSeconds;

BeatsInitializer(IMessageListener messageListener, int clientInactivityTimeoutSeconds, int beatsHandlerThread) {
BeatsInitializer(IMessageListener messageListener, int clientInactivityTimeoutSeconds) {
// Keeps a local copy of Server settings, so they can't be modified once it starts listening
this.localMessageListener = messageListener;
this.localClientInactivityTimeoutSeconds = clientInactivityTimeoutSeconds;
idleExecutorGroup = new DefaultEventExecutorGroup(DEFAULT_IDLESTATEHANDLER_THREAD);
beatsHandlerExecutorGroup = new DefaultEventExecutorGroup(beatsHandlerThread);
}

public void initChannel(SocketChannel socket){
@@ -126,11 +121,10 @@ public void initChannel(SocketChannel socket){
new IdleStateHandler(localClientInactivityTimeoutSeconds, IDLESTATE_WRITER_IDLE_TIME_SECONDS, localClientInactivityTimeoutSeconds));
pipeline.addLast(BEATS_ACKER, new AckEncoder());
pipeline.addLast(CONNECTION_HANDLER, new ConnectionHandler());
pipeline.addLast(beatsHandlerExecutorGroup, new BeatsParser(), new BeatsHandler(localMessageListener));
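// The parser and handler are now added without a dedicated EventExecutorGroup, so they run on the channel's own event loop.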
pipeline.addLast(new BeatsParser());
pipeline.addLast(new BeatsHandler(localMessageListener));
}



@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
logger.warn("Exception caught in channel initializer", cause);
@@ -144,7 +138,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
public void shutdownEventExecutor() {
try {
idleExecutorGroup.shutdownGracefully().sync();
beatsHandlerExecutorGroup.shutdownGracefully().sync();
} catch (InterruptedException e) {
throw new IllegalStateException(e);
}
7 changes: 3 additions & 4 deletions src/test/java/org/logstash/beats/ServerTest.java
@@ -33,7 +33,6 @@ public class ServerTest {
private int randomPort;
private EventLoopGroup group;
private final String host = "0.0.0.0";
private final int threadCount = 10;

@Before
public void setUp() {
@@ -50,7 +49,7 @@ public void testServerShouldTerminateConnectionWhenExceptionHappen() throws InterruptedException {

final CountDownLatch latch = new CountDownLatch(concurrentConnections);

final Server server = new Server(host, randomPort, inactivityTime, threadCount);
final Server server = new Server(host, randomPort, inactivityTime);
final AtomicBoolean otherCause = new AtomicBoolean(false);
server.setMessageListener(new MessageListener() {
public void onNewConnection(ChannelHandlerContext ctx) {
@@ -114,7 +113,7 @@ public void testServerShouldTerminateConnectionIdleForTooLong() throws InterruptedException {

final CountDownLatch latch = new CountDownLatch(concurrentConnections);
final AtomicBoolean exceptionClose = new AtomicBoolean(false);
final Server server = new Server(host, randomPort, inactivityTime, threadCount);
final Server server = new Server(host, randomPort, inactivityTime);
server.setMessageListener(new MessageListener() {
@Override
public void onNewConnection(ChannelHandlerContext ctx) {
@@ -170,7 +169,7 @@ public void run() {

@Test
public void testServerShouldAcceptConcurrentConnection() throws InterruptedException {
final Server server = new Server(host, randomPort, 30, threadCount);
final Server server = new Server(host, randomPort, 30);
SpyListener listener = new SpyListener();
server.setMessageListener(listener);
Runnable serverTask = new Runnable() {