diff --git a/core/src/main/java/io/undertow/server/HttpServerExchange.java b/core/src/main/java/io/undertow/server/HttpServerExchange.java index a8e2435fde..3515907dd1 100644 --- a/core/src/main/java/io/undertow/server/HttpServerExchange.java +++ b/core/src/main/java/io/undertow/server/HttpServerExchange.java @@ -80,6 +80,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import static org.xnio.Bits.allAreSet; import static org.xnio.Bits.anyAreClear; @@ -166,7 +167,8 @@ public final class HttpServerExchange extends AbstractAttachable { // mutable state - private int state = 200; + private volatile int state = 200; + private static final AtomicIntegerFieldUpdater stateUpdater = AtomicIntegerFieldUpdater.newUpdater(HttpServerExchange.class, "state"); private HttpString requestMethod = HttpString.EMPTY; private String requestScheme; @@ -488,9 +490,9 @@ public HttpServerExchange setRequestURI(final String requestURI) { public HttpServerExchange setRequestURI(final String requestURI, boolean containsHost) { this.requestURI = requestURI; if (containsHost) { - this.state |= FLAG_URI_CONTAINS_HOST; + setFlags(FLAG_URI_CONTAINS_HOST); } else { - this.state &= ~FLAG_URI_CONTAINS_HOST; + clearFlags(FLAG_URI_CONTAINS_HOST); } return this; } @@ -771,9 +773,9 @@ void updateBytesSent(long bytes) { public HttpServerExchange setPersistent(final boolean persistent) { if (persistent) { - this.state = this.state | FLAG_PERSISTENT; + setFlags(FLAG_PERSISTENT); } else { - this.state = this.state & ~FLAG_PERSISTENT; + clearFlags(FLAG_PERSISTENT); } return this; } @@ -783,7 +785,7 @@ public boolean isDispatched() { } public HttpServerExchange unDispatch() { - state &= ~FLAG_DISPATCHED; + clearFlags(FLAG_DISPATCHED); dispatchTask = null; return this; } @@ -797,7 +799,7 @@ public HttpServerExchange unDispatch() { */ @Deprecated public HttpServerExchange dispatch() { - state |= FLAG_DISPATCHED; + setFlags(FLAG_DISPATCHED); return this; } @@ -833,7 +835,7 @@ public HttpServerExchange dispatch(final Executor executor, final Runnable runna if (executor != null) { this.dispatchExecutor = executor; } - state |= FLAG_DISPATCHED; + setFlags(FLAG_DISPATCHED); if(anyAreSet(state, FLAG_SHOULD_RESUME_READS | FLAG_SHOULD_RESUME_WRITES)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -900,9 +902,9 @@ boolean isInCall() { HttpServerExchange setInCall(boolean value) { if (value) { - state |= FLAG_IN_CALL; + setFlags(FLAG_IN_CALL); } else { - state &= ~FLAG_IN_CALL; + clearFlags(FLAG_IN_CALL); } return this; } @@ -1276,7 +1278,7 @@ public boolean isResponseStarted() { public StreamSourceChannel getRequestChannel() { if (requestChannel != null) { if(anyAreSet(state, FLAG_REQUEST_RESET)) { - state &= ~FLAG_REQUEST_RESET; + clearFlags(FLAG_REQUEST_RESET); return requestChannel; } return null; @@ -1295,7 +1297,7 @@ public StreamSourceChannel getRequestChannel() { } void resetRequestChannel() { - state |= FLAG_REQUEST_RESET; + setFlags(FLAG_REQUEST_RESET); } public boolean isRequestChannelAvailable() { @@ -1336,8 +1338,7 @@ public boolean isResponseComplete() { * the socket or implement a transfer coding. */ void terminateRequest() { - int oldVal = state; - if (allAreSet(oldVal, FLAG_REQUEST_TERMINATED)) { + if (allAreSet(state, FLAG_REQUEST_TERMINATED)) { // idempotent return; } @@ -1345,8 +1346,8 @@ void terminateRequest() { requestChannel.suspendReads(); requestChannel.requestDone(); } - this.state = oldVal | FLAG_REQUEST_TERMINATED; - if (anyAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) { + setFlags(FLAG_REQUEST_TERMINATED); + if (anyAreSet(state, FLAG_RESPONSE_TERMINATED)) { invokeExchangeCompleteListeners(); } } @@ -1484,8 +1485,7 @@ public HttpServerExchange setStatusCode(final int statusCode) { if (statusCode < 0 || statusCode > 999) { throw new IllegalArgumentException("Invalid response code"); } - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) { + if (allAreSet(state, FLAG_RESPONSE_SENT)) { throw UndertowMessages.MESSAGES.responseAlreadyStarted(); } if(statusCode >= 500) { @@ -1493,7 +1493,8 @@ public HttpServerExchange setStatusCode(final int statusCode) { UndertowLogger.ERROR_RESPONSE.debugf(new RuntimeException(), "Setting error code %s for exchange %s", statusCode, this); } } - this.state = oldVal & ~MASK_RESPONSE_CODE | statusCode & MASK_RESPONSE_CODE; + clearFlags(MASK_RESPONSE_CODE); + setFlags(statusCode & MASK_RESPONSE_CODE); return this; } @@ -1633,8 +1634,7 @@ public OutputStream getOutputStream() { * the socket or implement a transfer coding. */ HttpServerExchange terminateResponse() { - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) { + if (allAreSet(state, FLAG_RESPONSE_TERMINATED)) { // idempotent return this; } @@ -1642,8 +1642,8 @@ HttpServerExchange terminateResponse() { responseChannel.suspendWrites(); responseChannel.responseDone(); } - this.state = oldVal | FLAG_RESPONSE_TERMINATED; - if (anyAreSet(oldVal, FLAG_REQUEST_TERMINATED)) { + setFlags(FLAG_RESPONSE_TERMINATED); + if (anyAreSet(state, FLAG_REQUEST_TERMINATED)) { invokeExchangeCompleteListeners(); } return this; @@ -1878,11 +1878,10 @@ public void handleException(final Channel channel, final IOException exception) * @throws IllegalStateException if the response headers were already sent */ HttpServerExchange startResponse() throws IllegalStateException { - int oldVal = state; - if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) { + if (allAreSet(state, FLAG_RESPONSE_SENT)) { throw UndertowMessages.MESSAGES.responseAlreadyStarted(); } - this.state = oldVal | FLAG_RESPONSE_SENT; + setFlags(FLAG_RESPONSE_SENT); log.tracef("Starting to write response for %s", this); return this; @@ -2077,7 +2076,7 @@ protected boolean isFinished() { @Override public void resumeWrites() { if (isInCall()) { - state |= FLAG_SHOULD_RESUME_WRITES; + setFlags(FLAG_SHOULD_RESUME_WRITES); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2088,7 +2087,7 @@ public void resumeWrites() { @Override public void suspendWrites() { - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); super.suspendWrites(); } @@ -2099,7 +2098,7 @@ public void wakeupWrites() { } if (isInCall()) { wakeup = true; - state |= FLAG_SHOULD_RESUME_WRITES; + setFlags(FLAG_SHOULD_RESUME_WRITES); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2120,10 +2119,10 @@ public void runResume() { } else { if (wakeup) { wakeup = false; - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); delegate.wakeupWrites(); } else { - state &= ~FLAG_SHOULD_RESUME_WRITES; + clearFlags(FLAG_SHOULD_RESUME_WRITES); delegate.resumeWrites(); } } @@ -2250,7 +2249,7 @@ protected boolean isFinished() { public void resumeReads() { readsResumed = true; if (isInCall()) { - state |= FLAG_SHOULD_RESUME_READS; + setFlags(FLAG_SHOULD_RESUME_READS); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2263,7 +2262,7 @@ public void resumeReads() { public void wakeupReads() { if (isInCall()) { wakeup = true; - state |= FLAG_SHOULD_RESUME_READS; + setFlags(FLAG_SHOULD_RESUME_READS); if(anyAreSet(state, FLAG_DISPATCHED)) { throw UndertowMessages.MESSAGES.resumedAndDispatched(); } @@ -2320,7 +2319,7 @@ public void awaitReadable() throws IOException { @Override public void suspendReads() { readsResumed = false; - state &= ~(FLAG_SHOULD_RESUME_READS); + clearFlags(FLAG_SHOULD_RESUME_READS); super.suspendReads(); } @@ -2490,10 +2489,10 @@ public void runResume() { } else { if (wakeup) { wakeup = false; - state &= ~FLAG_SHOULD_RESUME_READS; + clearFlags(FLAG_SHOULD_RESUME_READS); delegate.wakeupReads(); } else { - state &= ~FLAG_SHOULD_RESUME_READS; + clearFlags(FLAG_SHOULD_RESUME_READS); delegate.resumeReads(); } } @@ -2558,4 +2557,18 @@ public T create() { public String toString() { return "HttpServerExchange{ " + getRequestMethod().toString() + " " + getRequestURI() + '}'; } + + private void setFlags(int flags) { + int old; + do { + old = state; + } while (!stateUpdater.compareAndSet(this, old, old | flags)); + } + + private void clearFlags(int flags) { + int old; + do { + old = state; + } while (!stateUpdater.compareAndSet(this, old, old & ~flags)); + } }