diff --git a/src/main/java/org/logstash/beats/BeatsParser.java b/src/main/java/org/logstash/beats/BeatsParser.java
index 61337d3b..5e03663f 100644
--- a/src/main/java/org/logstash/beats/BeatsParser.java
+++ b/src/main/java/org/logstash/beats/BeatsParser.java
@@ -3,8 +3,10 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.ByteBufOutputStream;
+import io.netty.buffer.PooledByteBufAllocator;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.handler.codec.ByteToMessageDecoder;
+import io.netty.handler.codec.DecoderException;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -14,12 +16,14 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 import java.util.zip.Inflater;
 import java.util.zip.InflaterOutputStream;
 
 public class BeatsParser extends ByteToMessageDecoder {
     private final static Logger logger = LogManager.getLogger(BeatsParser.class);
+    private final static long maxDirectMemory = io.netty.util.internal.PlatformDependent.maxDirectMemory();
 
     private Batch batch;
@@ -45,15 +49,18 @@ private enum States {
     private int requiredBytes = 0;
     private int sequence = 0;
     private boolean decodingCompressedBuffer = false;
+    private long usedDirectMemory;
+    private boolean closeCalled = false;
 
     @Override
     protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws InvalidFrameProtocolException, IOException {
-        if(!hasEnoughBytes(in)) {
-            if (decodingCompressedBuffer){
+        if (!hasEnoughBytes(in)) {
+            if (decodingCompressedBuffer) {
                 throw new InvalidFrameProtocolException("Insufficient bytes in compressed content to decode: " + currentState);
             }
             return;
         }
+        usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory();
 
         switch (currentState) {
             case READ_HEADER: {
@@ -178,6 +185,13 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
             case READ_COMPRESSED_FRAME: {
                 logger.trace("Running: READ_COMPRESSED_FRAME");
+
+                if (usedDirectMemory + requiredBytes > maxDirectMemory * 0.90) {
+                    ctx.channel().config().setAutoRead(false);
+                    ctx.close();
+                    closeCalled = true;
+                    throw new IOException("not enough memory to decompress this from " + ctx.channel().id());
+                }
                 inflateCompressedFrame(ctx, in, (buffer) -> {
                     transition(States.READ_HEADER);
 
                     decodingCompressedBuffer = true;
@@ -188,6 +202,10 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
                         }
                     } finally {
                         decodingCompressedBuffer = false;
+                        ctx.channel().config().setAutoRead(false);
+                        ctx.channel().eventLoop().schedule(() -> {
+                            ctx.channel().config().setAutoRead(true);
+                        }, 5, TimeUnit.MILLISECONDS);
                         transition(States.READ_HEADER);
                     }
                 });
@@ -195,9 +213,9 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) t
             }
             case READ_JSON: {
                 logger.trace("Running: READ_JSON");
-                ((V2Batch)batch).addMessage(sequence, in, requiredBytes);
-                if(batch.isComplete()) {
-                    if(logger.isTraceEnabled()) {
+                ((V2Batch) batch).addMessage(sequence, in, requiredBytes);
+                if (batch.isComplete()) {
+                    if (logger.isTraceEnabled()) {
                         logger.trace("Sending batch size: " + this.batch.size() + ", windowSize: " + batch.getBatchSize() + " , seq: " + sequence);
                     }
                     out.add(batch);
@@ -256,6 +274,62 @@ private void batchComplete() {
         batch = null;
     }
 
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+        //System.out.println("channelRead(" + ctx.channel().isActive() + ": " + ctx.channel().id() + ":" + currentState + ":" + decodingCompressedBuffer);
+        if (closeCalled) {
+            ((ByteBuf) msg).release();
+            //if(batch != null) batch.release();
+            return;
+        }
+        usedDirectMemory = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory();
+
+        // If we're just beginning a new frame on this channel,
+        // don't accumulate more data for 200 ms if usage of direct memory is above 40%.
+        //
+        // The goal here is to avoid a thundering herd: many beats connecting and sending data
+        // at the same time. As some channels progress to other states they'll use more memory,
+        // but they also give it back once a full batch is read.
+        if ((!decodingCompressedBuffer) && (this.currentState != States.READ_COMPRESSED_FRAME)) {
+            if (usedDirectMemory > (maxDirectMemory * 0.40)) {
+                ctx.channel().config().setAutoRead(false);
+                //System.out.println("pausing reads on " + ctx.channel().id());
+                ctx.channel().eventLoop().schedule(() -> {
+                    //System.out.println("resuming reads on " + ctx.channel().id());
+                    ctx.channel().config().setAutoRead(true);
+                }, 200, TimeUnit.MILLISECONDS);
+            } else {
+                //System.out.println("no need to pause reads on " + ctx.channel().id());
+            }
+        } else if (usedDirectMemory > maxDirectMemory * 0.90) {
+            ctx.channel().config().setAutoRead(false);
+            ctx.close();
+            closeCalled = true;
+            ((ByteBuf) msg).release();
+            if (batch != null) batch.release();
+            throw new IOException("about to explode, cut them all down " + ctx.channel().id());
+        }
+        super.channelRead(ctx, msg);
+    }
+
+    @Override
+    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+        logger.info(cause.getClass().toString() + ":" + ctx.channel().id().toString() + ":" + this.currentState + "|" + cause.getMessage());
+        if (cause instanceof DecoderException) {
+            ctx.channel().config().setAutoRead(false);
+            if (!closeCalled) ctx.close();
+        } else if (cause instanceof OutOfMemoryError) {
+            logger.error("direct memory exhausted while reading", cause);
+            ctx.channel().config().setAutoRead(false);
+            if (!closeCalled) ctx.close();
+        } else if (cause instanceof IOException) {
+            ctx.channel().config().setAutoRead(false);
+            if (!closeCalled) ctx.close();
+        } else {
+            super.exceptionCaught(ctx, cause);
+        }
+    }
+
     @FunctionalInterface
     private interface CheckedConsumer<T> {
         void accept(T t) throws IOException;
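
The core technique in this patch is cooperative backpressure driven by the pooled allocator's direct-memory metric: when usage crosses a soft threshold, the handler turns off `autoRead` so Netty stops pulling bytes off the socket, then schedules a resume on the channel's own event loop. The sketch below is a minimal stand-alone illustration of that pattern, not part of the patch; the class name is hypothetical, and the 40% / 200 ms values mirror the soft limit used above.

```java
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.internal.PlatformDependent;

import java.util.concurrent.TimeUnit;

// Hypothetical handler sketching the throttling pattern used by BeatsParser.
public class DirectMemoryBackpressureHandler extends ChannelInboundHandlerAdapter {

    private static final long MAX_DIRECT = PlatformDependent.maxDirectMemory();

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        // Assumes the channel uses the pooled allocator; the cast fails otherwise.
        long used = ((PooledByteBufAllocator) ctx.alloc()).metric().usedDirectMemory();
        if (used > MAX_DIRECT * 0.40) {
            // Soft limit crossed: stop reading from the socket for a while
            // instead of buffering frames we may not have the memory to hold.
            ctx.channel().config().setAutoRead(false);
            ctx.channel().eventLoop().schedule(
                    () -> ctx.channel().config().setAutoRead(true),
                    200, TimeUnit.MILLISECONDS);
        }
        super.channelRead(ctx, msg);
    }
}
```

Scheduling the resume on `ctx.channel().eventLoop()` keeps every `autoRead` flip on the thread that already runs the channel's handlers, so no extra synchronization is needed. The 200 ms pause at the soft limit versus the 5 ms pause after each decompressed batch is the trade-off between spreading out a thundering herd and adding latency to a healthy channel.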
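Both overridden methods cast `ctx.alloc()` to `PooledByteBufAllocator`, so the patch implicitly assumes the channel was built with the pooled allocator (Netty's usual default, though it can be overridden via the `io.netty.allocator.type` system property or a channel option). A small setup sketch, with a hypothetical helper name, that pins accepted connections to the pooled allocator so the cast cannot fail:

```java
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.ChannelOption;

public final class AllocatorSetup {
    private AllocatorSetup() {}

    // Hypothetical helper: force the pooled allocator on child channels so the
    // (PooledByteBufAllocator) ctx.alloc() casts in BeatsParser always succeed.
    public static ServerBootstrap withPooledAllocator(ServerBootstrap bootstrap) {
        return bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
    }
}
```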