Skip to content

Commit

Permalink
Support overriding User-Agent in Tsunami scans
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629044389
Change-Id: I331a2c9feea1687bc6ff232dc866b127b93862f7
  • Loading branch information
Tsunami Team authored and copybara-github committed Apr 29, 2024
1 parent ed68b9d commit 1d464cb
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ public final class HttpClientCliOptions implements CliOption {
+ " --http-client-connect-timeout-seconds.")
Integer writeTimeoutSeconds;

@Parameter(
names = "--http-client-user-agent",
description = "User-Agent to use in HTTP requests.")
public String userAgent = HttpClient.TSUNAMI_USER_AGENT;

@Override
public void validate() {
validateTimeout("--http-client-call-timeout-seconds", callTimeoutSeconds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Strings.isNullOrEmpty;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

import com.google.inject.AbstractModule;
Expand Down Expand Up @@ -139,9 +140,10 @@ HttpClient provideOkHttpHttpClient(
@TrustAllCertificates boolean trustAllCertificates,
ConnectionFactory connectionFactory,
@LogId String logId,
@ConnectTimeout Duration connectTimeout) {
@ConnectTimeout Duration connectTimeout,
@UserAgent String userAgent) {
return new OkHttpHttpClient(
okHttpClient, trustAllCertificates, connectionFactory, logId, connectTimeout);
okHttpClient, trustAllCertificates, connectionFactory, logId, connectTimeout, userAgent);
}

@Provides
Expand Down Expand Up @@ -260,6 +262,15 @@ int provideWriteTimeoutSeconds(
return 10;
}

@Provides
@UserAgent
String provideUserAgent(HttpClientCliOptions httpClientCliOptions) {
if (!isNullOrEmpty(httpClientCliOptions.userAgent)) {
return httpClientCliOptions.userAgent;
}
return HttpClient.TSUNAMI_USER_AGENT;
}

@Qualifier
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER, ElementType.METHOD, ElementType.FIELD})
Expand Down Expand Up @@ -310,6 +321,11 @@ int provideWriteTimeoutSeconds(
@Target({ElementType.PARAMETER, ElementType.METHOD, ElementType.FIELD})
@interface MaxRequests {}

@Qualifier
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.PARAMETER, ElementType.METHOD, ElementType.FIELD})
@interface UserAgent {}

/** Builder for {@link HttpClientModule}. */
public static final class Builder {
private static final int DEFAULT_CONNECTION_POOL_MAX_IDLE = 5;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.common.io.ByteSource;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.protobuf.ByteString;
import com.google.tsunami.common.net.http.javanet.ConnectionFactory;
import com.google.tsunami.proto.NetworkService;
Expand Down Expand Up @@ -59,18 +60,21 @@ final class OkHttpHttpClient extends HttpClient {
private final ConnectionFactory connectionFactory;
private final String logId;
private final Duration connectionTimeout;
private final String userAgent;

OkHttpHttpClient(
OkHttpClient okHttpClient,
boolean trustAllCertificates,
ConnectionFactory connectionFactory,
String logId,
Duration connectionTimeout) {
Duration connectionTimeout,
String userAgent) {
this.okHttpClient = checkNotNull(okHttpClient);
this.trustAllCertificates = trustAllCertificates;
this.connectionFactory = checkNotNull(connectionFactory);
this.logId = logId;
this.connectionTimeout = connectionTimeout;
this.userAgent = isNullOrEmpty(userAgent) ? TSUNAMI_USER_AGENT : userAgent;
}

/**
Expand Down Expand Up @@ -106,7 +110,7 @@ public HttpResponse sendAsIs(HttpRequest httpRequest) throws IOException {
.getAll(headerName)
.forEach(
headerValue -> connection.setRequestProperty(headerName, headerValue)));
connection.setRequestProperty(USER_AGENT, TSUNAMI_USER_AGENT);
connection.setRequestProperty(USER_AGENT, this.userAgent);

if (ImmutableSet.of(HttpMethod.POST, HttpMethod.PUT, HttpMethod.DELETE)
.contains(httpRequest.method())) {
Expand Down Expand Up @@ -164,7 +168,7 @@ public HttpResponse send(HttpRequest httpRequest, @Nullable NetworkService netwo

OkHttpClient callHttpClient = clientWithHostnameAsProxy(networkService);
try (Response okHttpResponse =
callHttpClient.newCall(buildOkHttpRequest(httpRequest)).execute()) {
callHttpClient.newCall(buildOkHttpRequest(httpRequest, this.userAgent)).execute()) {
return parseResponse(okHttpResponse);
}
}
Expand Down Expand Up @@ -197,7 +201,7 @@ public ListenableFuture<HttpResponse> sendAsync(
logId, httpRequest.method(), httpRequest.url());
OkHttpClient callHttpClient = clientWithHostnameAsProxy(networkService);
SettableFuture<HttpResponse> responseFuture = SettableFuture.create();
Call requestCall = callHttpClient.newCall(buildOkHttpRequest(httpRequest));
Call requestCall = callHttpClient.newCall(buildOkHttpRequest(httpRequest, this.userAgent));

try {
requestCall.enqueue(
Expand Down Expand Up @@ -264,7 +268,7 @@ private OkHttpClient clientWithHostnameAsProxy(NetworkService networkService) {
.build();
}

private static Request buildOkHttpRequest(HttpRequest httpRequest) {
private static Request buildOkHttpRequest(HttpRequest httpRequest, String userAgent) {
Request.Builder okRequestBuilder = new Request.Builder().url(httpRequest.url());

httpRequest.headers().names().stream()
Expand All @@ -275,7 +279,7 @@ private static Request buildOkHttpRequest(HttpRequest httpRequest) {
.headers()
.getAll(headerName)
.forEach(headerValue -> okRequestBuilder.addHeader(headerName, headerValue)));
okRequestBuilder.addHeader(USER_AGENT, TSUNAMI_USER_AGENT);
okRequestBuilder.addHeader(USER_AGENT, userAgent);

switch (httpRequest.method()) {
case GET:
Expand Down Expand Up @@ -352,6 +356,7 @@ public static class OkHttpHttpClientBuilder extends Builder<OkHttpHttpClient> {
private final ConnectionFactory connectionFactory;
private String logId;
private Duration connectionTimeout;
private String userAgent;

private OkHttpHttpClientBuilder(OkHttpHttpClient okHttpHttpClient) {
this.okHttpClient = okHttpHttpClient.okHttpClient;
Expand All @@ -360,6 +365,7 @@ private OkHttpHttpClientBuilder(OkHttpHttpClient okHttpHttpClient) {
this.connectionFactory = okHttpHttpClient.connectionFactory;
this.logId = okHttpHttpClient.logId;
this.connectionTimeout = okHttpHttpClient.connectionTimeout;
this.userAgent = okHttpHttpClient.userAgent;
}

@Override
Expand All @@ -386,14 +392,21 @@ public OkHttpHttpClientBuilder setConnectTimeout(Duration connectionTimeout) {
return this;
}

@CanIgnoreReturnValue
public OkHttpHttpClientBuilder setUserAgent(String userAgent) {
this.userAgent = userAgent;
return this;
}

@Override
public OkHttpHttpClient build() {
return new OkHttpHttpClient(
okHttpClient.newBuilder().followRedirects(followRedirects).build(),
trustAllCertificates,
connectionFactory,
logId,
connectionTimeout);
connectionTimeout,
userAgent);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,57 @@ protected void configure() {
mockWebServer.shutdown();
}

@Test
public void send_default_userAgent() throws IOException, InterruptedException {
String responseBody = "test response";
mockWebServer.enqueue(
new MockResponse()
.setResponseCode(HttpStatus.OK.code())
.setHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString())
.setBody(responseBody));
mockWebServer.start();

HttpUrl baseUrl = mockWebServer.url("/");
httpClient.send(get(baseUrl.toString()).withEmptyHeaders().build());

assertThat(mockWebServer.takeRequest().getHeader(USER_AGENT))
.isEqualTo(HttpClient.TSUNAMI_USER_AGENT);
}

@Test
public void send_overridden_userAgent() throws IOException, InterruptedException {
String responseBody = "test response";
mockWebServer.enqueue(
new MockResponse()
.setResponseCode(HttpStatus.OK.code())
.setHeader(CONTENT_TYPE, MediaType.PLAIN_TEXT_UTF_8.toString())
.setBody(responseBody));
mockWebServer.start();

final String userAgentOverride = "User Agent In Override";

HttpClientCliOptions cliOptions = new HttpClientCliOptions();
cliOptions.userAgent = userAgentOverride;
HttpClientConfigProperties configProperties = new HttpClientConfigProperties();
cliOptions.trustAllCertificates = configProperties.trustAllCertificates = true;
HttpClient httpClient =
Guice.createInjector(
new AbstractModule() {
@Override
protected void configure() {
install(new HttpClientModule.Builder().build());
bind(HttpClientCliOptions.class).toInstance(cliOptions);
bind(HttpClientConfigProperties.class).toInstance(configProperties);
}
})
.getInstance(HttpClient.class);

HttpUrl baseUrl = mockWebServer.url("/");
httpClient.send(get(baseUrl.toString()).withEmptyHeaders().build());

assertThat(mockWebServer.takeRequest().getHeader(USER_AGENT)).isEqualTo(userAgentOverride);
}

private MockWebServer startMockWebServerWithSsl(InetAddress serverAddress)
throws GeneralSecurityException, IOException {
MockWebServer mockWebServer = new MockWebServer();
Expand Down

0 comments on commit 1d464cb

Please sign in to comment.