Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement static compression and encryption pipeline #858

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void packetSent(Session session, Packet packet) {
public void connected(ConnectedEvent event) {
log.info("CLIENT Connected");

event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().send(new PingPacket("hello"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public void serverClosed(ServerClosedEvent event) {
public void sessionAdded(SessionAddedEvent event) {
log.info("SERVER Session Added: {}:{}", event.getSession().getHost(), event.getSession().getPort());
((TestProtocol) event.getSession().getPacketProtocol()).setSecretKey(this.key);
event.getSession().enableEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
event.getSession().setEncryption(((TestProtocol) event.getSession().getPacketProtocol()).getEncryption());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.geysermc.mcprotocollib.network.codec.PacketDefinition;
import org.geysermc.mcprotocollib.network.codec.PacketSerializer;
import org.geysermc.mcprotocollib.network.crypt.AESEncryption;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.packet.DefaultPacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketHeader;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
Expand All @@ -23,7 +23,7 @@ public class TestProtocol extends PacketProtocol {
private static final Logger log = LoggerFactory.getLogger(TestProtocol.class);
private final PacketHeader header = new DefaultPacketHeader();
private final PacketRegistry registry = new PacketRegistry();
private AESEncryption encrypt;
private EncryptionConfig encrypt;

@SuppressWarnings("unused")
public TestProtocol() {
Expand Down Expand Up @@ -51,7 +51,7 @@ public PingPacket deserialize(ByteBuf buf, PacketCodecHelper helper, PacketDefin
});

try {
this.encrypt = new AESEncryption(key);
this.encrypt = new EncryptionConfig(new AESEncryption(key));
} catch (GeneralSecurityException e) {
log.error("Failed to create encryption", e);
}
Expand All @@ -67,7 +67,7 @@ public PacketHeader getPacketHeader() {
return this.header;
}

public PacketEncryption getEncryption() {
public EncryptionConfig getEncryption() {
return this.encrypt;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
public class MinecraftProtocolTest {
private static final Logger log = LoggerFactory.getLogger(MinecraftProtocolTest.class);
private static final boolean SPAWN_SERVER = true;
private static final boolean VERIFY_USERS = false;
private static final boolean ENCRYPT_CONNECTION = true;
private static final boolean SHOULD_AUTHENTICATE = false;
private static final String HOST = "127.0.0.1";
private static final int PORT = 25565;
private static final ProxyInfo PROXY = null;
Expand All @@ -63,7 +64,8 @@ public static void main(String[] args) {

Server server = new TcpServer(HOST, PORT, MinecraftProtocol::new);
server.setGlobalFlag(MinecraftConstants.SESSION_SERVICE_KEY, sessionService);
server.setGlobalFlag(MinecraftConstants.VERIFY_USERS_KEY, VERIFY_USERS);
server.setGlobalFlag(MinecraftConstants.ENCRYPT_CONNECTION, ENCRYPT_CONNECTION);
server.setGlobalFlag(MinecraftConstants.SHOULD_AUTHENTICATE, SHOULD_AUTHENTICATE);
server.setGlobalFlag(MinecraftConstants.SERVER_INFO_BUILDER_KEY, session ->
new ServerStatusInfo(
Component.text("Hello world!"),
Expand Down Expand Up @@ -100,7 +102,7 @@ public static void main(String[] args) {
))
);

server.setGlobalFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, 100);
server.setGlobalFlag(MinecraftConstants.SERVER_COMPRESSION_THRESHOLD, 256);
server.addListener(new ServerAdapter() {
@Override
public void serverClosed(ServerClosedEvent event) {
Expand Down Expand Up @@ -177,7 +179,7 @@ private static void status() {

private static void login() {
MinecraftProtocol protocol;
if (VERIFY_USERS) {
if (SHOULD_AUTHENTICATE) {
StepFullJavaSession.FullJavaSession fullJavaSession;
try {
fullJavaSession = MinecraftAuth.JAVA_CREDENTIALS_LOGIN.getFromInput(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package org.geysermc.mcprotocollib.network;

import io.netty.util.AttributeKey;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;

public class NetworkConstants {
public static final AttributeKey<CompressionConfig> COMPRESSION_ATTRIBUTE_KEY = AttributeKey.valueOf("compression");
public static final AttributeKey<EncryptionConfig> ENCRYPTION_ATTRIBUTE_KEY = AttributeKey.valueOf("encryption");
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;
import org.geysermc.mcprotocollib.network.event.session.SessionEvent;
import org.geysermc.mcprotocollib.network.event.session.SessionListener;
import org.geysermc.mcprotocollib.network.packet.Packet;
Expand Down Expand Up @@ -183,26 +184,21 @@ public interface Session {
void callPacketSent(Packet packet);

/**
* Gets the compression packet length threshold for this session (-1 = disabled).
* Sets the compression config for this session.
*
* @return This session's compression threshold.
* @param compressionConfig the compression to compress with,
* or null to disable compression
*/
int getCompressionThreshold();
void setCompression(@Nullable CompressionConfig compressionConfig);

/**
* Sets the compression packet length threshold for this session (-1 = disabled).
* Sets encryption for this session.
*
* @param threshold The new compression threshold.
* @param validateDecompression whether to validate that the decompression fits within size checks.
*/
void setCompressionThreshold(int threshold, boolean validateDecompression);

/**
* Enables encryption for this session.
* @param encryptionConfig the encryption to encrypt with,
* or null to disable encryption
*
* @param encryption the encryption to encrypt with
*/
void enableEncryption(PacketEncryption encryption);
void setEncryption(@Nullable EncryptionConfig encryptionConfig);

/**
* Returns true if the session is connected.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.compression;

public record CompressionConfig(int threshold, PacketCompression compression, boolean validateDecompression) {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package org.geysermc.mcprotocollib.network.crypt;

public record EncryptionConfig(PacketEncryption encryption) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ public void initChannel(Channel channel) {
pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));

pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(getCodecHelper()));

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,60 +4,71 @@
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.Session;
import org.geysermc.mcprotocollib.network.compression.PacketCompression;
import lombok.RequiredArgsConstructor;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.compression.CompressionConfig;

import java.util.List;

@RequiredArgsConstructor
public class TcpPacketCompression extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private static final int MAX_UNCOMPRESSED_SIZE = 8 * 1024 * 1024; // 8MiB

private final Session session;
private final PacketCompression compression;
private final boolean validateDecompression;

public TcpPacketCompression(Session session, PacketCompression compression, boolean validateDecompression) {
this.session = session;
this.compression = compression;
this.validateDecompression = validateDecompression;
}
private final PacketCodecHelper helper;

@Override
public void handlerRemoved(ChannelHandlerContext ctx) {
this.compression.close();
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
return;
}

config.compression().close();
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}

int uncompressed = msg.readableBytes();
if (uncompressed > MAX_UNCOMPRESSED_SIZE) {
throw new IllegalArgumentException("Packet too big (is " + uncompressed + ", should be less than " + MAX_UNCOMPRESSED_SIZE + ")");
}

ByteBuf outBuf = ctx.alloc().directBuffer(uncompressed);
if (uncompressed < this.session.getCompressionThreshold()) {
if (uncompressed < config.threshold()) {
// Under the threshold, there is nothing to do.
this.session.getCodecHelper().writeVarInt(outBuf, 0);
this.helper.writeVarInt(outBuf, 0);
outBuf.writeBytes(msg);
} else {
this.session.getCodecHelper().writeVarInt(outBuf, uncompressed);
compression.deflate(msg, outBuf);
this.helper.writeVarInt(outBuf, uncompressed);
config.compression().deflate(msg, outBuf);
}

out.add(outBuf);
}

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
int claimedUncompressedSize = this.session.getCodecHelper().readVarInt(in);
CompressionConfig config = ctx.channel().attr(NetworkConstants.COMPRESSION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}

int claimedUncompressedSize = this.helper.readVarInt(in);
if (claimedUncompressedSize == 0) {
out.add(in.retain());
return;
}

if (validateDecompression) {
if (claimedUncompressedSize < this.session.getCompressionThreshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + this.session.getCompressionThreshold());
if (config.validateDecompression()) {
if (claimedUncompressedSize < config.threshold()) {
throw new DecoderException("Badly compressed packet - size of " + claimedUncompressedSize + " is below server threshold of " + config.threshold());
}

if (claimedUncompressedSize > MAX_UNCOMPRESSED_SIZE) {
Expand All @@ -67,7 +78,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {

ByteBuf uncompressed = ctx.alloc().directBuffer(claimedUncompressedSize);
try {
compression.inflate(in, uncompressed, claimedUncompressedSize);
config.compression().inflate(in, uncompressed, claimedUncompressedSize);
out.add(uncompressed);
} catch (Exception e) {
uncompressed.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.EncoderException;
import io.netty.handler.codec.MessageToMessageCodec;
import org.geysermc.mcprotocollib.network.crypt.PacketEncryption;
import org.geysermc.mcprotocollib.network.NetworkConstants;
import org.geysermc.mcprotocollib.network.crypt.EncryptionConfig;

import java.util.List;

public class TcpPacketEncryptor extends MessageToMessageCodec<ByteBuf, ByteBuf> {
private final PacketEncryption encryption;

public TcpPacketEncryptor(PacketEncryption encryption) {
this.encryption = encryption;
}

@Override
public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(msg.retain());
return;
}

ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), msg);

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().encrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
Expand All @@ -35,13 +36,19 @@ public void encode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out) {

@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
EncryptionConfig config = ctx.channel().attr(NetworkConstants.ENCRYPTION_ATTRIBUTE_KEY).get();
if (config == null) {
out.add(in.retain());
return;
}

ByteBuf heapBuf = this.ensureHeapBuffer(ctx.alloc(), in).slice();

int inBytes = heapBuf.readableBytes();
int baseOffset = heapBuf.arrayOffset() + heapBuf.readerIndex();

try {
encryption.decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
config.encryption().decrypt(heapBuf.array(), baseOffset, inBytes, heapBuf.array(), baseOffset);
out.add(heapBuf);
} catch (Exception e) {
heapBuf.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ public void initChannel(Channel channel) {
pipeline.addLast("read-timeout", new ReadTimeoutHandler(session.getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(session.getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));

pipeline.addLast("encryption", new TcpPacketEncryptor());
pipeline.addLast("sizer", new TcpPacketSizer(protocol.getPacketHeader(), session.getCodecHelper()));
pipeline.addLast("compression", new TcpPacketCompression(session.getCodecHelper()));

pipeline.addLast("codec", new TcpPacketCodec(session, false));
pipeline.addLast("manager", session);
Expand Down
Loading
Loading