diff --git a/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyChannel.java b/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyChannel.java index 72c96efeed..753a2c5e91 100644 --- a/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyChannel.java +++ b/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyChannel.java @@ -35,8 +35,12 @@ public NettyChannel(io.netty.channel.Channel channel, ProtocolConfig config) { this.config = config; if (channel != null) { // can't get the remote address while using udp, so the remoteAddress is null - this.remoteAddress = ((InetSocketAddress) channel.remoteAddress()); - this.localAddress = (InetSocketAddress) channel.localAddress(); + if (channel.remoteAddress() instanceof InetSocketAddress) { + this.remoteAddress = ((InetSocketAddress) channel.remoteAddress()); + } + if (channel.localAddress() instanceof InetSocketAddress) { + this.localAddress = (InetSocketAddress) channel.localAddress(); + } } // listen for the close event if (channel != null && channel.closeFuture() != null) { diff --git a/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyCodecAdapter.java b/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyCodecAdapter.java index ec41574204..670d02c51f 100644 --- a/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyCodecAdapter.java +++ b/trpc-transport/trpc-transport-netty/src/main/java/com/tencent/trpc/transport/netty/NettyCodecAdapter.java @@ -170,6 +170,9 @@ private void decode(ChannelHandlerContext ctx, ByteBuf input, List out) } } } while (message.isReadable()); + } catch (Exception e) { + message.skipBytes(message.readableBytes()); + throw new TransportException("tcp|decode failure", e); } finally { NettyChannelManager.removeChannelIfDisconnected(ctx.channel()); } diff --git a/trpc-transport/trpc-transport-netty/src/test/java/com/tencent/trpc/transport/netty/NettyCodecAdapterTest.java b/trpc-transport/trpc-transport-netty/src/test/java/com/tencent/trpc/transport/netty/NettyCodecAdapterTest.java new file mode 100644 index 0000000000..17e0d7bd0b --- /dev/null +++ b/trpc-transport/trpc-transport-netty/src/test/java/com/tencent/trpc/transport/netty/NettyCodecAdapterTest.java @@ -0,0 +1,82 @@ +package com.tencent.trpc.transport.netty; + +import com.tencent.trpc.core.common.config.ProtocolConfig; +import com.tencent.trpc.core.exception.ErrorCode; +import com.tencent.trpc.core.exception.TRpcException; +import com.tencent.trpc.core.exception.TransportException; +import com.tencent.trpc.core.transport.codec.Codec; +import io.netty.buffer.AbstractByteBufAllocator; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderException; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + + +public class NettyCodecAdapterTest { + + @Test + public void testTcpDecodeIllegalPacket1() { + Codec codec = mock(Codec.class); + doThrow(TRpcException.newFrameException(ErrorCode.TRPC_CLIENT_DECODE_ERR, "the request protocol is not trpc")) + .when(codec).decode(any(), any()); + + + ProtocolConfig protocolConfig = new ProtocolConfig(); + // set batchDecoder true + protocolConfig.setBatchDecoder(true); + NettyCodecAdapter nettyCodecAdapter = NettyCodecAdapter.createTcpCodecAdapter(codec, protocolConfig); + + ChannelHandler decoder = nettyCodecAdapter.getDecoder(); + EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + embeddedChannel.pipeline().addLast(decoder); + + ByteBuf byteBuf = AbstractByteBufAllocator.DEFAULT.heapBuffer(); + byteBuf.writeBytes("testTcpDecodeIllegalPacket1".getBytes(StandardCharsets.UTF_8)); + + // write illegal packet + EmbeddedChannel tmpEmbeddedChannel = embeddedChannel; + DecoderException decoderException = Assert.assertThrows(DecoderException.class, () -> { + tmpEmbeddedChannel.writeInbound(byteBuf); + }); + + Assert.assertTrue(decoderException.getCause() instanceof TransportException); + Assert.assertEquals(byteBuf.refCnt(), 0); + } + + @Test + public void testTcpDecodeIllegalPacket2() { + Codec codec = mock(Codec.class); + doThrow(TRpcException.newFrameException(ErrorCode.TRPC_CLIENT_DECODE_ERR, "the request protocol is not trpc")) + .when(codec).decode(any(), any()); + + + ProtocolConfig protocolConfig = new ProtocolConfig(); + // set batchDecoder false + protocolConfig.setBatchDecoder(false); + NettyCodecAdapter nettyCodecAdapter = NettyCodecAdapter.createTcpCodecAdapter(codec, protocolConfig); + + ChannelHandler decoder = nettyCodecAdapter.getDecoder(); + EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + embeddedChannel.pipeline().addLast(decoder); + + ByteBuf byteBuf = AbstractByteBufAllocator.DEFAULT.heapBuffer(); + byteBuf.writeBytes("testTcpDecodeIllegalPacket1".getBytes(StandardCharsets.UTF_8)); + + // write illegal packet + EmbeddedChannel tmpEmbeddedChannel = embeddedChannel; + DecoderException decoderException = Assert.assertThrows(DecoderException.class, () -> { + tmpEmbeddedChannel.writeInbound(byteBuf); + }); + + Assert.assertTrue(decoderException.getCause() instanceof TransportException); + Assert.assertEquals(byteBuf.refCnt(), 0); + } +}