Skip to content

Commit

Permalink
CALCITE-6529: Use persistent sessionContext in AvaticaCommonsHttpClie…
Browse files Browse the repository at this point in the history
…ntImpl
  • Loading branch information
Villő Szűcs committed Sep 13, 2024
1 parent 188f35a commit 06e7d61
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class AvaticaCommonsHttpClientImpl implements AvaticaHttpClient, HttpClie
protected CredentialsProvider credentialsProvider = null;
protected Lookup<AuthSchemeFactory> authRegistry = null;
protected Object userToken;
protected HttpClientContext persistentContext;

public AvaticaCommonsHttpClientImpl(URL url) {
this.uri = toURI(Objects.requireNonNull(url));
Expand All @@ -109,34 +110,33 @@ protected void initializeClient(PoolingHttpClientConnectionManager pool,
HttpClientBuilder httpClientBuilder = HttpClients.custom().setConnectionManager(pool)
.setDefaultRequestConfig(requestConfig);
this.client = httpClientBuilder.build();

this.persistentContext = HttpClientContext.create();
// Set the credentials if they were provided.
if (null != this.credentialsProvider) {
persistentContext.setCredentialsProvider(credentialsProvider);
persistentContext.setAuthSchemeRegistry(authRegistry);
persistentContext.setAuthCache(authCache);
}
if (null != userToken) {
persistentContext.setUserToken(userToken);
}

}

@Override public byte[] send(byte[] request) {
while (true) {
HttpClientContext context = HttpClientContext.create();

// Set the credentials if they were provided.
if (null != this.credentialsProvider) {
context.setCredentialsProvider(credentialsProvider);
context.setAuthSchemeRegistry(authRegistry);
context.setAuthCache(authCache);
}

if (null != userToken) {
context.setUserToken(userToken);
}

ByteArrayEntity entity = new ByteArrayEntity(request, ContentType.APPLICATION_OCTET_STREAM);

// Create the client with the AuthSchemeRegistry and manager
HttpPost post = new HttpPost(uri);
post.setEntity(entity);

try (CloseableHttpResponse response = execute(post, context)) {
try (CloseableHttpResponse response = execute(post, persistentContext)) {
final int statusCode = response.getCode();
if (HttpURLConnection.HTTP_OK == statusCode
|| HttpURLConnection.HTTP_INTERNAL_ERROR == statusCode) {
userToken = context.getUserToken();
userToken = persistentContext.getUserToken();
return EntityUtils.toByteArray(response.getEntity());
} else if (HttpURLConnection.HTTP_UNAVAILABLE == statusCode) {
LOG.debug("Failed to connect to server (HTTP/503), retrying");
Expand Down Expand Up @@ -184,6 +184,8 @@ CloseableHttpResponse execute(HttpPost post, HttpClientContext context)
throw new IllegalArgumentException("Unsupported authentiation type: " + authType);
}
this.authRegistry = authRegistryBuilder.build();
persistentContext.setCredentialsProvider(credentialsProvider);
persistentContext.setAuthSchemeRegistry(authRegistry);
}

@Override public void setGSSCredential(GSSCredential credential) {
Expand All @@ -205,6 +207,8 @@ CloseableHttpResponse execute(HttpPost post, HttpClientContext context)
((BasicCredentialsProvider) this.credentialsProvider)
.setCredentials(anyAuthScope, EmptyCredentials.INSTANCE);
}
persistentContext.setCredentialsProvider(credentialsProvider);
persistentContext.setAuthSchemeRegistry(authRegistry);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,33 @@
package org.apache.calcite.avatica.remote;

import org.apache.calcite.avatica.AvaticaUtils;
import org.apache.calcite.avatica.ConnectionConfig;

import org.apache.hc.client5.http.classic.methods.HttpPost;
import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse;
import org.apache.hc.client5.http.protocol.HttpClientContext;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager;
import org.apache.hc.core5.http.NoHttpResponseException;
import org.apache.hc.core5.http.io.entity.ByteArrayEntity;
import org.apache.hc.core5.http.io.entity.StringEntity;

import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.io.ByteArrayInputStream;
import java.net.HttpURLConnection;
import java.net.URL;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import static java.nio.charset.StandardCharsets.UTF_8;
Expand Down Expand Up @@ -64,9 +72,11 @@ public class AvaticaCommonsHttpClientImplTest {

final AvaticaCommonsHttpClientImpl client =
spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1")));
client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock(
ConnectionConfig.class));

doAnswer(failThenSucceed).when(client)
.execute(any(HttpPost.class), any(HttpClientContext.class));
.execute(any(HttpPost.class), eq(client.persistentContext));

when(badResponse.getCode()).thenReturn(HttpURLConnection.HTTP_UNAVAILABLE);

Expand Down Expand Up @@ -96,9 +106,11 @@ public class AvaticaCommonsHttpClientImplTest {

final AvaticaCommonsHttpClientImpl client =
spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1")));
client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock(
ConnectionConfig.class));

doAnswer(failThenSucceed).when(client)
.execute(any(HttpPost.class), any(HttpClientContext.class));
.execute(any(HttpPost.class), eq(client.persistentContext));

when(badResponse.getCode()).thenReturn(HttpURLConnection.HTTP_UNAVAILABLE);

Expand All @@ -109,6 +121,63 @@ public class AvaticaCommonsHttpClientImplTest {
assertEquals("success", AvaticaUtils.newStringUtf8(responseBytes));
}

@Test
public void testPersistentContextReusedAcrossRequests() throws Exception {
final AvaticaCommonsHttpClientImpl client =
spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1")));
client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock(
ConnectionConfig.class));

CloseableHttpResponse response = mock(CloseableHttpResponse.class);
when(response.getCode()).thenReturn(HttpURLConnection.HTTP_OK);

ByteArrayEntity entity = mock(ByteArrayEntity.class);
when(entity.getContent()).thenReturn(new ByteArrayInputStream(new byte[0]));
when(response.getEntity()).thenReturn(entity);

doReturn(response).when(client)
.execute(any(HttpPost.class), eq(client.persistentContext));

client.send(new byte[0]);
client.send(new byte[0]);

// Verify that persistentContext was reused and not created again
verify(client, times(2)).execute(any(HttpPost.class),
eq(client.persistentContext));
}

@Test
public void testPersistentContextThreadSafety() throws Exception {
final AvaticaCommonsHttpClientImpl client =
spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1")));
client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock(
ConnectionConfig.class));

doReturn(mock(CloseableHttpResponse.class)).when(client)
.execute(any(HttpPost.class), eq(client.persistentContext));

Runnable requestTask = () -> {
try {
client.send(new byte[0]);
} catch (Exception e) {
fail("Threaded request failed with exception: " + e.getMessage());
}
};

int threadCount = 5;
Thread[] threads = new Thread[threadCount];
for (int i = 0; i < threadCount; i++) {
threads[i] = new Thread(requestTask);
threads[i].start();
}

for (Thread thread : threads) {
thread.join();
}

verify(client, times(threadCount)).execute(any(HttpPost.class), eq(client.persistentContext));
}

}

// End AvaticaCommonsHttpClientImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ private static void setupUsers(File keytabDir) throws KrbException {
// Passes the GSSCredential into the HTTP client implementation
final AvaticaCommonsHttpClientImpl httpClient =
new AvaticaCommonsHttpClientImpl(httpServerUrl);
httpClient.setGSSCredential(credential);
httpClient.setHttpClientPool(pool, config);
httpClient.setGSSCredential(credential);

return httpClient.send(new byte[0]);
}
Expand Down

0 comments on commit 06e7d61

Please sign in to comment.