diff --git a/core/src/main/java/io/undertow/UndertowLogger.java b/core/src/main/java/io/undertow/UndertowLogger.java index f771ed8ab9..7fd2147acd 100644 --- a/core/src/main/java/io/undertow/UndertowLogger.java +++ b/core/src/main/java/io/undertow/UndertowLogger.java @@ -484,4 +484,8 @@ void nodeConfigCreated(URI connectionURI, String balancer, String domain, String @LogMessage(level = WARN) @Message(id = 5106, value = "Content mismatch for '%s'. Expected length '%s', but was '%s'.") void contentEntryMismatch(Object key, long indicatedSize, long written); -} + + @LogMessage(level = WARN) + @Message(id = 5107, value = "Failed to set web socket timeout.") + void failedToSetWSTimeout(@Cause Exception e); +} \ No newline at end of file diff --git a/core/src/main/java/io/undertow/server/handlers/cache/ResponseCachingStreamSinkConduit.java b/core/src/main/java/io/undertow/server/handlers/cache/ResponseCachingStreamSinkConduit.java index 0206ce7623..6d4008d1e6 100644 --- a/core/src/main/java/io/undertow/server/handlers/cache/ResponseCachingStreamSinkConduit.java +++ b/core/src/main/java/io/undertow/server/handlers/cache/ResponseCachingStreamSinkConduit.java @@ -167,4 +167,4 @@ public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOExcep public int writeFinal(ByteBuffer src) throws IOException { return Conduits.writeFinalBasic(this, src); } -} +} \ No newline at end of file diff --git a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java index 4b76338091..d52eb8725d 100644 --- a/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java +++ b/core/src/main/java/io/undertow/websockets/core/WebSocketChannel.java @@ -17,6 +17,7 @@ */ package io.undertow.websockets.core; +import io.undertow.UndertowLogger; import io.undertow.conduits.IdleTimeoutConduit; import io.undertow.server.protocol.framed.AbstractFramedChannel; import io.undertow.server.protocol.framed.AbstractFramedStreamSourceChannel; @@ -28,6 +29,8 @@ import org.xnio.ChannelListeners; import org.xnio.IoUtils; import org.xnio.OptionMap; +import org.xnio.Options; + import io.undertow.connector.ByteBufferPool; import io.undertow.connector.PooledByteBuffer; import org.xnio.StreamConnection; @@ -50,7 +53,16 @@ * @author Stuart Douglas */ public abstract class WebSocketChannel extends AbstractFramedChannel { - + /** + * Configure a read timeout for a web socket, in milliseconds. If its present it will override {@link org.xnio.Options.READ_TIMEOUT}. If the given amount of time elapses without + * a successful read taking place, the socket's next read will throw a {@link ReadTimeoutException}. + */ + public static final String WEB_SOCKETS_READ_TIMEOUT = "io.undertow.websockets.core.read-timeout"; + /** + * Configure a write timeout for a web socket, in milliseconds. If its present it will override {@link org.xnio.Options.WRITE_TIMEOUT}. If the given amount of time elapses without + * a successful write taking place, the socket's next write will throw a {@link WriteTimeoutException}. + */ + public static final String WEB_SOCKETS_WRITE_TIMEOUT = "io.undertow.websockets.core.write-timeout"; private final boolean client; private final WebSocketVersion version; @@ -106,6 +118,22 @@ protected WebSocketChannel(final StreamConnection connectedStreamChannel, ByteBu this.hasReservedOpCode = extensionFunction.hasExtensionOpCode(); this.subProtocol = subProtocol; this.peerConnections = peerConnections; + final String webSocketReadTimeout = System.getProperty(WEB_SOCKETS_READ_TIMEOUT); + if(webSocketReadTimeout != null) { + try { + this.setOption(Options.READ_TIMEOUT, Integer.parseInt(webSocketReadTimeout)); + } catch (Exception e) { + UndertowLogger.ROOT_LOGGER.failedToSetWSTimeout(e); + } + } + final String webSocketWriteTimeout = System.getProperty(WEB_SOCKETS_WRITE_TIMEOUT); + if(webSocketWriteTimeout != null) { + try { + this.setOption(Options.WRITE_TIMEOUT, Integer.parseInt(webSocketWriteTimeout)); + } catch (Exception e) { + UndertowLogger.ROOT_LOGGER.failedToSetWSTimeout(e); + } + } addCloseTask(new ChannelListener() { @Override public void handleEvent(WebSocketChannel channel) { diff --git a/core/src/test/java/io/undertow/testutils/DefaultServer.java b/core/src/test/java/io/undertow/testutils/DefaultServer.java index a04bd949a0..75525d48c3 100644 --- a/core/src/test/java/io/undertow/testutils/DefaultServer.java +++ b/core/src/test/java/io/undertow/testutils/DefaultServer.java @@ -443,7 +443,7 @@ public static boolean startServer() { } else { server = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), 7777 + PROXY_OFFSET), acceptListener, serverOptions); - proxyOpenListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + proxyOpenListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); proxyAcceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(proxyOpenListener)); proxyServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), proxyAcceptListener, serverOptions); loadBalancingProxyClient = new LoadBalancingProxyClient(GSSAPIAuthenticationMechanism.EXCLUSIVITY_CHECKER) @@ -466,7 +466,7 @@ public static boolean startServer() { server = ssl.createSslConnectionServer(worker, new InetSocketAddress(getHostAddress("default"), 7777 + PROXY_OFFSET), acceptListener, serverOptions); server.resumeAccepts(); - proxyOpenListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + proxyOpenListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); proxyAcceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(proxyOpenListener)); proxyServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), proxyAcceptListener, serverOptions); loadBalancingProxyClient = new LoadBalancingProxyClient(GSSAPIAuthenticationMechanism.EXCLUSIVITY_CHECKER) @@ -488,13 +488,13 @@ public static boolean startServer() { proxyOpenListener.setRootHandler(proxyHandler); proxyServer.resumeAccepts(); } else if (h2c || h2cUpgrade) { - openListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.ENABLE_HTTP2, true, UndertowOptions.HTTP2_PADDING_SIZE, 10)); + openListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.ENABLE_HTTP2, true).set(UndertowOptions.HTTP2_PADDING_SIZE, 10).getMap()); acceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(openListener)); InetSocketAddress targetAddress = new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT) + PROXY_OFFSET); server = worker.createStreamConnectionServer(targetAddress, acceptListener, serverOptions); - proxyOpenListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + proxyOpenListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); proxyAcceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(proxyOpenListener)); proxyServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), proxyAcceptListener, serverOptions); loadBalancingProxyClient = new LoadBalancingProxyClient(GSSAPIAuthenticationMechanism.EXCLUSIVITY_CHECKER) @@ -519,13 +519,13 @@ public static boolean startServer() { } else if (https) { XnioSsl clientSsl = new UndertowXnioSsl(xnio, OptionMap.EMPTY, SSL_BUFFER_POOL, createClientSslContext()); - openListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + openListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); acceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(openListener)); server = ssl.createSslConnectionServer(worker, new InetSocketAddress(getHostAddress("default"), 7777 + PROXY_OFFSET), acceptListener, serverOptions); server.getAcceptSetter().set(acceptListener); server.resumeAccepts(); - proxyOpenListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + proxyOpenListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); proxyAcceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(proxyOpenListener)); proxyServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), proxyAcceptListener, serverOptions); loadBalancingProxyClient = new LoadBalancingProxyClient(GSSAPIAuthenticationMechanism.EXCLUSIVITY_CHECKER) @@ -545,7 +545,7 @@ public static boolean startServer() { if (h2) { UndertowLogger.ROOT_LOGGER.error("HTTP2 selected but Netty ALPN was not on the boot class path"); } - openListener = new HttpOpenListener(pool, OptionMap.builder().set(UndertowOptions.BUFFER_PIPELINED_DATA, true).set(UndertowOptions.ENABLE_CONNECTOR_STATISTICS, true).set(UndertowOptions.REQUIRE_HOST_HTTP11, true).getMap()); + openListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).set(UndertowOptions.ENABLE_CONNECTOR_STATISTICS, true).set(UndertowOptions.REQUIRE_HOST_HTTP11, true).getMap()); acceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(openListener)); if (!proxy) { server = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), acceptListener, serverOptions); @@ -553,7 +553,7 @@ public static boolean startServer() { InetSocketAddress targetAddress = new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT) + PROXY_OFFSET); server = worker.createStreamConnectionServer(targetAddress, acceptListener, serverOptions); - proxyOpenListener = new HttpOpenListener(pool, OptionMap.create(UndertowOptions.BUFFER_PIPELINED_DATA, true)); + proxyOpenListener = new HttpOpenListener(pool, OptionMap.builder().addAll(serverOptions).set(UndertowOptions.BUFFER_PIPELINED_DATA, true).getMap()); proxyAcceptListener = ChannelListeners.openListenerAdapter(wrapOpenListener(proxyOpenListener)); proxyServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(getHostAddress(DEFAULT)), getHostPort(DEFAULT)), proxyAcceptListener, serverOptions); loadBalancingProxyClient = new LoadBalancingProxyClient(GSSAPIAuthenticationMechanism.EXCLUSIVITY_CHECKER) diff --git a/core/src/test/java/io/undertow/websockets/core/protocol/WebSocketTimeoutTestCase.java b/core/src/test/java/io/undertow/websockets/core/protocol/WebSocketTimeoutTestCase.java new file mode 100644 index 0000000000..9f96bd818b --- /dev/null +++ b/core/src/test/java/io/undertow/websockets/core/protocol/WebSocketTimeoutTestCase.java @@ -0,0 +1,181 @@ +/* + * JBoss, Home of Professional Open Source. + * Copyright 2023 Red Hat, Inc., and individual contributors + * as indicated by the @author tags. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.undertow.websockets.core.protocol; + +import java.io.IOException; +import java.net.URI; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.xnio.FutureResult; +import org.xnio.OptionMap; +import org.xnio.Options; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketVersion; +import io.netty.util.CharsetUtil; +import io.undertow.testutils.DefaultServer; +import io.undertow.testutils.HttpOneOnly; +import io.undertow.util.NetworkUtils; +import io.undertow.websockets.WebSocketConnectionCallback; +import io.undertow.websockets.WebSocketProtocolHandshakeHandler; +import io.undertow.websockets.core.AbstractReceiveListener; +import io.undertow.websockets.core.BufferedTextMessage; +import io.undertow.websockets.core.WebSocketChannel; +import io.undertow.websockets.core.WebSockets; +import io.undertow.websockets.spi.WebSocketHttpExchange; +import io.undertow.websockets.utils.FrameChecker; +import io.undertow.websockets.utils.WebSocketTestClient; + +@RunWith(DefaultServer.class) +@HttpOneOnly +public class WebSocketTimeoutTestCase { + + protected void beforeTest(int regularTimeouts, int wsReadTimeout, int wsWriteTimeout) { + DefaultServer.stopServer(); + DefaultServer.setServerOptions(OptionMap.builder() + .set(Options.READ_TIMEOUT, regularTimeouts) + .set(Options.WRITE_TIMEOUT, regularTimeouts).getMap()); + + DefaultServer.setUndertowOptions(OptionMap.builder() + .set(Options.READ_TIMEOUT, regularTimeouts) + .set(Options.WRITE_TIMEOUT, regularTimeouts).getMap()); + DefaultServer.startServer(); + SCHEDULER = Executors.newScheduledThreadPool(2); + System.setProperty(WebSocketChannel.WEB_SOCKETS_READ_TIMEOUT, ""+wsReadTimeout); + System.setProperty(WebSocketChannel.WEB_SOCKETS_WRITE_TIMEOUT, ""+wsWriteTimeout); + } + + @After + public void afterTest() { + DefaultServer.stopServer(); + DefaultServer.setServerOptions(OptionMap.EMPTY); + DefaultServer.setUndertowOptions(OptionMap.EMPTY); + SCHEDULER.shutdown(); + System.clearProperty(WebSocketChannel.WEB_SOCKETS_READ_TIMEOUT); + System.clearProperty(WebSocketChannel.WEB_SOCKETS_WRITE_TIMEOUT); + } + + protected static final int TESTABLE_TIMEOUT_VALUE = 2000; + protected static final int NON_TESTABLE_TIMEOUT_VALUE = 30180; + protected static final int DEFAULTS_IO_TIMEOTU_VALUE = 500; + private ScheduledExecutorService SCHEDULER; + + protected WebSocketVersion getVersion() { + return WebSocketVersion.V13; + } + + + @Test + public void testServerReadTimeout() throws Exception { + beforeTest(DEFAULTS_IO_TIMEOTU_VALUE, TESTABLE_TIMEOUT_VALUE, NON_TESTABLE_TIMEOUT_VALUE); + final AtomicBoolean connected = new AtomicBoolean(false); + DefaultServer.setRootHandler(new WebSocketProtocolHandshakeHandler(new WebSocketConnectionCallback() { + @Override + public void onConnect(final WebSocketHttpExchange exchange, final WebSocketChannel channel) { + connected.set(true); + channel.getReceiveSetter().set(new AbstractReceiveListener() { + @Override + protected void onFullTextMessage(WebSocketChannel channel, BufferedTextMessage message) throws IOException { + String string = message.getData(); + + if (string.equals("hello")) { + WebSockets.sendText("world", channel, null); + } else { + WebSockets.sendText(string, channel, null); + } + } + }); + channel.resumeReceives(); + } + })); + + final FutureResult latch = new FutureResult(); + WebSocketTestClient client = new WebSocketTestClient(getVersion(), new URI("ws://" + NetworkUtils.formatPossibleIpv6Address(DefaultServer.getHostAddress("default")) + ":" + DefaultServer.getHostPort("default") + "/")); + client.connect(); + client.send(new TextWebSocketFrame(Unpooled.copiedBuffer("hello", CharsetUtil.US_ASCII)), new FrameChecker(TextWebSocketFrame.class, "world".getBytes(CharsetUtil.US_ASCII), latch)); + latch.getIoFuture().get(); + + final long watchStart = System.currentTimeMillis(); + final long watchTimeout = System.currentTimeMillis()+TESTABLE_TIMEOUT_VALUE+500; + final FutureResult timeoutLatch = new FutureResult(); + ReadTimeoutChannelGuard readTimeoutChannelGuard = new ReadTimeoutChannelGuard(client, timeoutLatch, watchTimeout); + + final ScheduledFuture sf = SCHEDULER.scheduleAtFixedRate(readTimeoutChannelGuard, 0, 50, TimeUnit.MILLISECONDS); + readTimeoutChannelGuard.setTaskScheduledFuture(sf); + + final Long watchTimeEnd = timeoutLatch.getIoFuture().get(); + if(watchTimeEnd == -1) { + Assert.fail("Timeout did not happen... in time. Were waiting '"+watchTimeout+"' ms, timeout should happen in '"+TESTABLE_TIMEOUT_VALUE+"' ms."); + } else { + long timeSpent = watchTimeEnd - watchStart; + //lets be generous and give 150ms diff( there is "fuzz" coded for 50ms in undertow as well + if(!(timeSpent<=TESTABLE_TIMEOUT_VALUE+150)) { + Assert.fail("Timeout did not happen... in time. Socket timeout out in '"+timeSpent+"' ms, supposed to happen in '"+TESTABLE_TIMEOUT_VALUE+"' ms."); + } + } + } + + private static class ReadTimeoutChannelGuard implements Runnable{ + private final WebSocketTestClient channel; + private final FutureResult resultHandler; + private final long watchEnd; + private ScheduledFuture sf; + + ReadTimeoutChannelGuard(final WebSocketTestClient channel, final FutureResult resultHandler, final long watchEnd) { + super(); + this.channel = channel; + this.resultHandler = resultHandler; + this.watchEnd = watchEnd; + } + + public void setTaskScheduledFuture(ScheduledFuture sf2) { + this.sf = sf2; + } + + @Override + public void run() { + if(System.currentTimeMillis() > watchEnd) { + sf.cancel(false); + if(channelActive()) { + resultHandler.setResult(new Long(-1)); + } else { + resultHandler.setResult(System.currentTimeMillis()); + } + } else { + if(!channelActive()) { + sf.cancel(false); + resultHandler.setResult(System.currentTimeMillis()); + } + } + } + + private boolean channelActive() { + return channel.isOpen(); + } + + } + } diff --git a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java index 373c186b55..8345338783 100644 --- a/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java +++ b/core/src/test/java/io/undertow/websockets/utils/WebSocketTestClient.java @@ -173,6 +173,27 @@ public void onError(Throwable t) { } } + public boolean isActive() { + if(this.ch != null) { + return this.ch.isActive(); + } + return false; + } + + public boolean isOpen() { + if(this.ch != null) { + return this.ch.isOpen(); + } + return false; + } + + public boolean isWritable() { + if(this.ch != null) { + return this.ch.isWritable(); + } + return false; + } + public interface FrameListener { /** * Is called if an WebSocketFrame was received