Skip to content

Commit

Permalink
Merge pull request #1505 from fl4via/UNDERTOW-2293
Browse files Browse the repository at this point in the history
[UNDERTOW-2293] Rewrite the fix for UNDERTOW-2243 in a way to prevent extra chunked messages
  • Loading branch information
fl4via authored Jul 19, 2023
2 parents 6689dad + b2c69a0 commit 7d87eef
Show file tree
Hide file tree
Showing 9 changed files with 504 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@
import io.undertow.util.Protocols;
import io.undertow.util.RedirectBuilder;
import io.undertow.util.StatusCodes;
import org.xnio.Bits;

import static io.undertow.util.URLUtils.isAbsoluteUrl;

/**
* Implementation of {@code HttpServletResponse}.
*
* @author Stuart Douglas
* @author <a href="mailto:[email protected]">Richard Opalka</a>
* @author Flavia Rainone
*/
public final class HttpServletResponseImpl implements HttpServletResponse {

Expand All @@ -74,15 +78,22 @@ public final class HttpServletResponseImpl implements HttpServletResponse {
private PrintWriter writer;
private Integer bufferSize;
private long contentLength = -1;
private boolean insideInclude = false;
private Locale locale;
private boolean responseDone = false;

private boolean ignoredFlushPerformed = false;

private boolean treatAsCommitted = false;
private int flags = 0;
// response is inside include
private static final int INSIDE_INCLUDE_FLAG = 1 << 0x00;
// response is done
private static final int RESPONSE_DONE_FLAG = 1 << 0x01;
// ignored flush has been performed
private static final int IGNORED_FLUSH_PERFORMED_FLAG = 1 << 0x02;
// prevents anything to be added to response
private static final int TREAT_AS_COMMITTED_FLAG = 1 << 0x03;
// indicates that we are closing the response and no further content should be written
private static final int CONTENT_FULLY_WRITTEN_FLAG = 1 << 0x04;
//if a content type has been set either implicitly or implicitly
private static final int CHARSET_SET_FLAG = 1 << 0x05;

private boolean charsetSet = false; //if a content type has been set either implicitly or implicitly
private Locale locale;
private String contentType;
private String charset;
private Supplier<Map<String, String>> trailerSupplier;
Expand All @@ -99,7 +110,7 @@ public HttpServerExchange getExchange() {

@Override
public void addCookie(final Cookie newCookie) {
if (insideInclude) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG)) {
return;
}
final ServletCookieAdaptor servletCookieAdaptor = new ServletCookieAdaptor(newCookie);
Expand All @@ -116,7 +127,7 @@ public boolean containsHeader(final String name) {

@Override
public void sendError(final int sc, final String msg) throws IOException {
if(insideInclude) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG)) {
//not 100% sure this is the correct action
return;
}
Expand All @@ -135,7 +146,7 @@ public void sendError(final int sc, final String msg) throws IOException {
exchange.setStatusCode(sc);
if(src.isRunningInsideHandler()) {
//all we do is set the error on the context, we handle it when the request is returned
treatAsCommitted = true;
flags |= TREAT_AS_COMMITTED_FLAG;
src.setError(sc, msg);
} else {
//if the src is null there is no outer handler, as we are in an asnc request
Expand All @@ -148,7 +159,7 @@ public void doErrorDispatch(int sc, String error) throws IOException {
responseState = ResponseState.NONE;
resetBuffer();
exchange.getResponseHeaders().remove(Headers.CONTENT_LENGTH);
treatAsCommitted = false;
flags &= ~TREAT_AS_COMMITTED_FLAG;
final String location = servletContext.getDeployment().getErrorPages().getErrorLocation(sc);
if (location != null) {
RequestDispatcherImpl requestDispatcher = new RequestDispatcherImpl(location, servletContext);
Expand Down Expand Up @@ -230,7 +241,7 @@ public void setHeader(final HttpString name, final String value) {
if(name == null) {
throw UndertowServletMessages.MESSAGES.headerNameWasNull();
}
if (insideInclude || ignoredFlushPerformed) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG | IGNORED_FLUSH_PERFORMED_FLAG)) {
return;
}
if(name.equals(Headers.CONTENT_TYPE)) {
Expand All @@ -252,7 +263,7 @@ public void addHeader(final HttpString name, final String value) {
if(name == null) {
throw UndertowServletMessages.MESSAGES.headerNameWasNull();
}
if (insideInclude || ignoredFlushPerformed || treatAsCommitted) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG | IGNORED_FLUSH_PERFORMED_FLAG | TREAT_AS_COMMITTED_FLAG | CONTENT_FULLY_WRITTEN_FLAG)) {
return;
}
if(name.equals(Headers.CONTENT_TYPE) && !exchange.getResponseHeaders().contains(Headers.CONTENT_TYPE)) {
Expand All @@ -274,7 +285,7 @@ public void addIntHeader(final String name, final int value) {

@Override
public void setStatus(final int sc) {
if (insideInclude || treatAsCommitted) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG | TREAT_AS_COMMITTED_FLAG)) {
return;
}
if (responseStarted()) {
Expand Down Expand Up @@ -332,7 +343,7 @@ public String getCharacterEncoding() {
@Override
public String getContentType() {
if (contentType != null) {
if (charsetSet) {
if (Bits.anyAreSet(flags, CHARSET_SET_FLAG)) {
return contentType + ";charset=" + getCharacterEncoding();
} else {
return contentType;
Expand All @@ -354,7 +365,7 @@ public ServletOutputStream getOutputStream() {
@Override
public PrintWriter getWriter() throws IOException {
if (writer == null) {
if (!charsetSet) {
if (!Bits.anyAreSet(flags, CHARSET_SET_FLAG)) {
//servet 5.5
setCharacterEncoding(getCharacterEncoding());
}
Expand All @@ -381,10 +392,14 @@ private void createOutputStream() {

@Override
public void setCharacterEncoding(final String charset) {
if (insideInclude || responseStarted() || writer != null || isCommitted()) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG) || responseStarted() || writer != null || isCommitted()) {
return;
}
charsetSet = charset != null;
if (charset != null) {
flags |= CHARSET_SET_FLAG;
} else {
flags &= ~CHARSET_SET_FLAG;
}
this.charset = charset;
if (contentType != null) {
exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, getContentType());
Expand All @@ -398,7 +413,7 @@ public void setContentLength(final int len) {

@Override
public void setContentLengthLong(final long len) {
if (insideInclude || responseStarted()) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG) || responseStarted()) {
return;
}
if(len >= 0) {
Expand All @@ -410,31 +425,35 @@ public void setContentLengthLong(final long len) {
}

boolean isIgnoredFlushPerformed() {
return ignoredFlushPerformed;
return Bits.anyAreSet(flags, IGNORED_FLUSH_PERFORMED_FLAG);
}

void setIgnoredFlushPerformed(boolean ignoredFlushPerformed) {
this.ignoredFlushPerformed = ignoredFlushPerformed;
if (ignoredFlushPerformed) {
flags |= IGNORED_FLUSH_PERFORMED_FLAG;
} else {
flags &= ~IGNORED_FLUSH_PERFORMED_FLAG;
}
}

private boolean responseStarted() {
return exchange.isResponseStarted() || ignoredFlushPerformed || treatAsCommitted;
return exchange.isResponseStarted() || Bits.anyAreSet(flags, IGNORED_FLUSH_PERFORMED_FLAG | TREAT_AS_COMMITTED_FLAG);
}

@Override
public void setContentType(final String type) {
if (type == null || insideInclude || responseStarted()) {
if (type == null || Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG) || responseStarted()) {
return;
}
ContentTypeInfo ct = servletContext.parseContentType(type);
contentType = ct.getContentType();
boolean useCharset = false;
if(ct.getCharset() != null && writer == null && !isCommitted()) {
charset = ct.getCharset();
charsetSet = true;
flags |= CHARSET_SET_FLAG;
useCharset = true;
}
if(useCharset || !charsetSet) {
if(useCharset || !Bits.anyAreSet(flags, CHARSET_SET_FLAG)) {
exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, ct.getHeader());
} else if(ct.getCharset() == null) {
exchange.getResponseHeaders().put(Headers.CONTENT_TYPE, ct.getHeader() + "; charset=" + charset);
Expand Down Expand Up @@ -472,7 +491,7 @@ public void flushBuffer() throws IOException {
}

public void closeStreamAndWriter() throws IOException {
if(treatAsCommitted) {
if(Bits.anyAreSet(flags, TREAT_AS_COMMITTED_FLAG)) {
return;
}
if (writer != null) {
Expand Down Expand Up @@ -525,17 +544,17 @@ public void reset() {
responseState = ResponseState.NONE;
exchange.getResponseHeaders().clear();
exchange.setStatusCode(StatusCodes.OK);
treatAsCommitted = false;
flags &= ~(TREAT_AS_COMMITTED_FLAG | CONTENT_FULLY_WRITTEN_FLAG);
}

@Override
public void setLocale(final Locale loc) {
if (insideInclude || responseStarted()) {
if (Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG) || responseStarted()) {
return;
}
this.locale = loc;
exchange.getResponseHeaders().put(Headers.CONTENT_LANGUAGE, loc.getLanguage() + "-" + loc.getCountry());
if (!charsetSet && writer == null) {
if (!Bits.anyAreSet(flags, CHARSET_SET_FLAG) && writer == null) {
final Map<String, String> localeCharsetMapping = servletContext.getDeployment().getDeploymentInfo().getLocaleCharsetMapping();
// first try DD provided mappings
String charset = null;
Expand Down Expand Up @@ -572,10 +591,10 @@ public Locale getLocale() {
}

public void responseDone() {
if (responseDone || treatAsCommitted) {
if (Bits.anyAreSet(flags, RESPONSE_DONE_FLAG | TREAT_AS_COMMITTED_FLAG)) {
return;
}
responseDone = true;
flags |= RESPONSE_DONE_FLAG;
try {
closeStreamAndWriter();
} catch (IOException e) {
Expand All @@ -586,11 +605,15 @@ public void responseDone() {
}

public boolean isInsideInclude() {
return insideInclude;
return Bits.anyAreSet(flags, INSIDE_INCLUDE_FLAG);
}

public void setInsideInclude(final boolean insideInclude) {
this.insideInclude = insideInclude;
if (insideInclude) {
this.flags |= INSIDE_INCLUDE_FLAG;
} else {
this.flags &= ~INSIDE_INCLUDE_FLAG;
}
}

public void setServletContext(final ServletContextImpl servletContext) {
Expand Down Expand Up @@ -779,7 +802,7 @@ private static String escapeHtml(String msg) {
}

public boolean isTreatAsCommitted() {
return treatAsCommitted;
return Bits.anyAreSet(flags, TREAT_AS_COMMITTED_FLAG);
}

@Override
Expand Down Expand Up @@ -810,4 +833,11 @@ public Supplier<Map<String, String>> getTrailerFields() {
return trailerSupplier;
}

/**
* Marks this response as closed for writing extra bytes, including the addition of headers.
*/
void setContentFullyWritten() {
this.flags |= CONTENT_FULLY_WRITTEN_FLAG;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ void updateWritten(final long len) throws IOException {
this.written += len;
long contentLength = servletRequestContext.getOriginalResponse().getContentLength();
if (contentLength != -1 && this.written >= contentLength) {
flushInternal();
servletRequestContext.getOriginalResponse().setContentFullyWritten();
}
}

Expand All @@ -389,21 +389,7 @@ void updateWrittenAsync(final long len) throws IOException {
this.written += len;
long contentLength = servletRequestContext.getOriginalResponse().getContentLength();
if (contentLength != -1 && this.written >= contentLength) {
setFlags(FLAG_CLOSED);
//if buffersToWrite is set we are already flushing
//so we don't have to do anything
if (buffersToWrite == null && pendingFile == null) {
if (flushBufferAsync(true)) {
channel.shutdownWrites();
setFlags(FLAG_DELEGATE_SHUTDOWN);
channel.flush();
if (pooledBuffer != null) {
pooledBuffer.close();
buffer = null;
pooledBuffer = null;
}
}
}
servletRequestContext.getOriginalResponse().setContentFullyWritten();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* JBoss, Home of Professional Open Source.
* Copyright 2023 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.undertow.servlet.test.response.writer;

import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import java.io.IOException;
import java.io.PrintWriter;

/**
* Asynchronous version of {@link ExceptionWriterServlet}.
*
* @author rmartinc
*/
public class AsyncExceptionWriterServlet extends jakarta.servlet.http.HttpServlet {

@Override
protected void doGet(final HttpServletRequest req, final HttpServletResponse resp) throws ServletException, IOException {
final var asyncContext = req.startAsync();
new Thread(()->{
try {
resp.setContentType("text/plain;charset=UTF-8");
try (PrintWriter writer = resp.getWriter()) {
new Exception("TestException").printStackTrace(writer);
}
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
asyncContext.complete();
}
}).start();
}
}
Loading

0 comments on commit 7d87eef

Please sign in to comment.