Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UNDERTOW-2436] fix HttpServerExchange state flag race conditions #1661

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 49 additions & 36 deletions core/src/main/java/io/undertow/server/HttpServerExchange.java
Original file line number Diff line number Diff line change
@@ -80,6 +80,7 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

import static org.xnio.Bits.allAreSet;
import static org.xnio.Bits.anyAreClear;
@@ -166,7 +167,8 @@ public final class HttpServerExchange extends AbstractAttachable {

// mutable state

private int state = 200;
private volatile int state = 200;
private static final AtomicIntegerFieldUpdater<HttpServerExchange> stateUpdater = AtomicIntegerFieldUpdater.newUpdater(HttpServerExchange.class, "state");
private HttpString requestMethod = HttpString.EMPTY;
private String requestScheme;

@@ -488,9 +490,9 @@ public HttpServerExchange setRequestURI(final String requestURI) {
public HttpServerExchange setRequestURI(final String requestURI, boolean containsHost) {
this.requestURI = requestURI;
if (containsHost) {
this.state |= FLAG_URI_CONTAINS_HOST;
setFlags(FLAG_URI_CONTAINS_HOST);
} else {
this.state &= ~FLAG_URI_CONTAINS_HOST;
clearFlags(FLAG_URI_CONTAINS_HOST);
}
return this;
}
@@ -771,9 +773,9 @@ void updateBytesSent(long bytes) {

public HttpServerExchange setPersistent(final boolean persistent) {
if (persistent) {
this.state = this.state | FLAG_PERSISTENT;
setFlags(FLAG_PERSISTENT);
} else {
this.state = this.state & ~FLAG_PERSISTENT;
clearFlags(FLAG_PERSISTENT);
}
return this;
}
@@ -783,7 +785,7 @@ public boolean isDispatched() {
}

public HttpServerExchange unDispatch() {
state &= ~FLAG_DISPATCHED;
clearFlags(FLAG_DISPATCHED);
dispatchTask = null;
return this;
}
@@ -797,7 +799,7 @@ public HttpServerExchange unDispatch() {
*/
@Deprecated
public HttpServerExchange dispatch() {
state |= FLAG_DISPATCHED;
setFlags(FLAG_DISPATCHED);
return this;
}

@@ -833,7 +835,7 @@ public HttpServerExchange dispatch(final Executor executor, final Runnable runna
if (executor != null) {
this.dispatchExecutor = executor;
}
state |= FLAG_DISPATCHED;
setFlags(FLAG_DISPATCHED);
if(anyAreSet(state, FLAG_SHOULD_RESUME_READS | FLAG_SHOULD_RESUME_WRITES)) {
throw UndertowMessages.MESSAGES.resumedAndDispatched();
}
@@ -900,9 +902,9 @@ boolean isInCall() {

HttpServerExchange setInCall(boolean value) {
if (value) {
state |= FLAG_IN_CALL;
setFlags(FLAG_IN_CALL);
} else {
state &= ~FLAG_IN_CALL;
clearFlags(FLAG_IN_CALL);
}
return this;
}
@@ -1276,7 +1278,7 @@ public boolean isResponseStarted() {
public StreamSourceChannel getRequestChannel() {
if (requestChannel != null) {
if(anyAreSet(state, FLAG_REQUEST_RESET)) {
state &= ~FLAG_REQUEST_RESET;
clearFlags(FLAG_REQUEST_RESET);
return requestChannel;
}
return null;
@@ -1295,7 +1297,7 @@ public StreamSourceChannel getRequestChannel() {
}

void resetRequestChannel() {
state |= FLAG_REQUEST_RESET;
setFlags(FLAG_REQUEST_RESET);
}

public boolean isRequestChannelAvailable() {
@@ -1336,17 +1338,16 @@ public boolean isResponseComplete() {
* the socket or implement a transfer coding.
*/
void terminateRequest() {
int oldVal = state;
if (allAreSet(oldVal, FLAG_REQUEST_TERMINATED)) {
if (allAreSet(state, FLAG_REQUEST_TERMINATED)) {
// idempotent
return;
}
if (requestChannel != null) {
requestChannel.suspendReads();
requestChannel.requestDone();
}
this.state = oldVal | FLAG_REQUEST_TERMINATED;
if (anyAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) {
setFlags(FLAG_REQUEST_TERMINATED);
if (anyAreSet(state, FLAG_RESPONSE_TERMINATED)) {
invokeExchangeCompleteListeners();
}
}
@@ -1484,16 +1485,16 @@ public HttpServerExchange setStatusCode(final int statusCode) {
if (statusCode < 0 || statusCode > 999) {
throw new IllegalArgumentException("Invalid response code");
}
int oldVal = state;
if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) {
if (allAreSet(state, FLAG_RESPONSE_SENT)) {
throw UndertowMessages.MESSAGES.responseAlreadyStarted();
}
if(statusCode >= 500) {
if(UndertowLogger.ERROR_RESPONSE.isDebugEnabled()) {
UndertowLogger.ERROR_RESPONSE.debugf(new RuntimeException(), "Setting error code %s for exchange %s", statusCode, this);
}
}
this.state = oldVal & ~MASK_RESPONSE_CODE | statusCode & MASK_RESPONSE_CODE;
clearFlags(MASK_RESPONSE_CODE);
setFlags(statusCode & MASK_RESPONSE_CODE);
return this;
}

@@ -1633,17 +1634,16 @@ public OutputStream getOutputStream() {
* the socket or implement a transfer coding.
*/
HttpServerExchange terminateResponse() {
int oldVal = state;
if (allAreSet(oldVal, FLAG_RESPONSE_TERMINATED)) {
if (allAreSet(state, FLAG_RESPONSE_TERMINATED)) {
// idempotent
return this;
}
if(responseChannel != null) {
responseChannel.suspendWrites();
responseChannel.responseDone();
}
this.state = oldVal | FLAG_RESPONSE_TERMINATED;
if (anyAreSet(oldVal, FLAG_REQUEST_TERMINATED)) {
setFlags(FLAG_RESPONSE_TERMINATED);
if (anyAreSet(state, FLAG_REQUEST_TERMINATED)) {
invokeExchangeCompleteListeners();
}
return this;
@@ -1878,11 +1878,10 @@ public void handleException(final Channel channel, final IOException exception)
* @throws IllegalStateException if the response headers were already sent
*/
HttpServerExchange startResponse() throws IllegalStateException {
int oldVal = state;
if (allAreSet(oldVal, FLAG_RESPONSE_SENT)) {
if (allAreSet(state, FLAG_RESPONSE_SENT)) {
throw UndertowMessages.MESSAGES.responseAlreadyStarted();
}
this.state = oldVal | FLAG_RESPONSE_SENT;
setFlags(FLAG_RESPONSE_SENT);

log.tracef("Starting to write response for %s", this);
return this;
@@ -2077,7 +2076,7 @@ protected boolean isFinished() {
@Override
public void resumeWrites() {
if (isInCall()) {
state |= FLAG_SHOULD_RESUME_WRITES;
setFlags(FLAG_SHOULD_RESUME_WRITES);
if(anyAreSet(state, FLAG_DISPATCHED)) {
throw UndertowMessages.MESSAGES.resumedAndDispatched();
}
@@ -2088,7 +2087,7 @@ public void resumeWrites() {

@Override
public void suspendWrites() {
state &= ~FLAG_SHOULD_RESUME_WRITES;
clearFlags(FLAG_SHOULD_RESUME_WRITES);
super.suspendWrites();
}

@@ -2099,7 +2098,7 @@ public void wakeupWrites() {
}
if (isInCall()) {
wakeup = true;
state |= FLAG_SHOULD_RESUME_WRITES;
setFlags(FLAG_SHOULD_RESUME_WRITES);
if(anyAreSet(state, FLAG_DISPATCHED)) {
throw UndertowMessages.MESSAGES.resumedAndDispatched();
}
@@ -2120,10 +2119,10 @@ public void runResume() {
} else {
if (wakeup) {
wakeup = false;
state &= ~FLAG_SHOULD_RESUME_WRITES;
clearFlags(FLAG_SHOULD_RESUME_WRITES);
delegate.wakeupWrites();
} else {
state &= ~FLAG_SHOULD_RESUME_WRITES;
clearFlags(FLAG_SHOULD_RESUME_WRITES);
delegate.resumeWrites();
}
}
@@ -2250,7 +2249,7 @@ protected boolean isFinished() {
public void resumeReads() {
readsResumed = true;
if (isInCall()) {
state |= FLAG_SHOULD_RESUME_READS;
setFlags(FLAG_SHOULD_RESUME_READS);
if(anyAreSet(state, FLAG_DISPATCHED)) {
throw UndertowMessages.MESSAGES.resumedAndDispatched();
}
@@ -2263,7 +2262,7 @@ public void resumeReads() {
public void wakeupReads() {
if (isInCall()) {
wakeup = true;
state |= FLAG_SHOULD_RESUME_READS;
setFlags(FLAG_SHOULD_RESUME_READS);
if(anyAreSet(state, FLAG_DISPATCHED)) {
throw UndertowMessages.MESSAGES.resumedAndDispatched();
}
@@ -2320,7 +2319,7 @@ public void awaitReadable() throws IOException {
@Override
public void suspendReads() {
readsResumed = false;
state &= ~(FLAG_SHOULD_RESUME_READS);
clearFlags(FLAG_SHOULD_RESUME_READS);
super.suspendReads();
}

@@ -2490,10 +2489,10 @@ public void runResume() {
} else {
if (wakeup) {
wakeup = false;
state &= ~FLAG_SHOULD_RESUME_READS;
clearFlags(FLAG_SHOULD_RESUME_READS);
delegate.wakeupReads();
} else {
state &= ~FLAG_SHOULD_RESUME_READS;
clearFlags(FLAG_SHOULD_RESUME_READS);
delegate.resumeReads();
}
}
@@ -2558,4 +2557,18 @@ public T create() {
public String toString() {
return "HttpServerExchange{ " + getRequestMethod().toString() + " " + getRequestURI() + '}';
}

private void setFlags(int flags) {
int old;
do {
old = state;
} while (!stateUpdater.compareAndSet(this, old, old | flags));
}

private void clearFlags(int flags) {
int old;
do {
old = state;
} while (!stateUpdater.compareAndSet(this, old, old & ~flags));
}
}