Skip to content

Commit

Permalink
Merge pull request #827 from mbfreder/alb-remoteHost
Browse files Browse the repository at this point in the history
fix: getRemoteHost and getRemotePort for ALB
  • Loading branch information
deki authored May 10, 2024
2 parents ac38ef5 + 46d9637 commit ec98c39
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest {
static final String PROTOCOL_HEADER_NAME = "X-Forwarded-Proto";
static final String HOST_HEADER_NAME = "Host";
static final String PORT_HEADER_NAME = "X-Forwarded-Port";
static final String CLIENT_IP_HEADER = "X-Forwarded-For";


//-------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
import com.amazonaws.serverless.proxy.model.RequestSource;
import com.amazonaws.services.lambda.runtime.Context;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jakarta.servlet.*;
import jakarta.servlet.http.*;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpUpgradeHandler;
import jakarta.ws.rs.core.HttpHeaders;
import jakarta.ws.rs.core.SecurityContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedReader;
import java.io.IOException;
Expand Down Expand Up @@ -435,12 +437,22 @@ public String getRemoteAddr() {
if (request.getRequestContext() == null || request.getRequestContext().getIdentity() == null) {
return "127.0.0.1";
}
if (request.getRequestContext().getElb() != null) {
return request.getHeaders().get(CLIENT_IP_HEADER);
}
return request.getRequestContext().getIdentity().getSourceIp();
}


@Override
public String getRemoteHost() {
if (Objects.nonNull(request.getRequestContext().getElb())) {
String hostHeader = request.getHeaders().get(HttpHeaders.HOST);

// the host header has the form host:port, so we split the string to get the host part
return Arrays.asList(hostHeader.split(":")).get(0);
}

return request.getMultiValueHeaders().getFirst(HttpHeaders.HOST);
}

Expand Down Expand Up @@ -471,6 +483,12 @@ public RequestDispatcher getRequestDispatcher(String s) {

@Override
public int getRemotePort() {
if (Objects.nonNull(request.getRequestContext().getElb())) {
String portHeader = request.getHeaders().get(PORT_HEADER_NAME);
if (Objects.nonNull(portHeader)) {
return Integer.parseInt(portHeader);
}
}
return 0;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,18 @@ void serverName_albHostHeader_returnsHostHeader() {
assertEquals("testapi.us-east-1.elb.amazonaws.com", serverName);
}

@Test
void getRemoteHost_albHostHeader_returnsHostHeader() {
initAwsProxyHttpServletRequestTest("ALB");
AwsProxyRequest proxyReq = new AwsProxyRequestBuilder("/test", "GET")
.alb().build();
proxyReq.getHeaders().put(HttpHeaders.HOST, "testapi.us-east-1.elb.amazonaws.com");
HttpServletRequest servletRequest = new AwsProxyHttpServletRequest(proxyReq, null, null);

String host = servletRequest.getRemoteHost();
assertEquals("testapi.us-east-1.elb.amazonaws.com", host);
}

private AwsProxyRequestBuilder getRequestWithHeaders() {
return new AwsProxyRequestBuilder("/hello", "GET")
.header(CUSTOM_HEADER_KEY, CUSTOM_HEADER_VALUE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ public AwsProxyRequestBuilder(AwsProxyRequest req) {

public AwsProxyRequestBuilder(String path, String httpMethod) {
this.request = new AwsProxyRequest();
this.request.setMultiValueHeaders(new Headers()); // avoid NPE
this.request.setMultiValueHeaders(new Headers());// avoid NPE
this.request.setHeaders(new SingleValueHeaders());
this.request.setHttpMethod(httpMethod);
this.request.setPath(path);
this.request.setMultiValueQueryStringParameters(new MultiValuedTreeMap<>());
Expand Down

0 comments on commit ec98c39

Please sign in to comment.