Skip to content

Commit

Permalink
Implement CommandGetCatalogs
Browse files Browse the repository at this point in the history
  • Loading branch information
rafael-telles committed Aug 13, 2021
1 parent 310a079 commit 698817d
Show file tree
Hide file tree
Showing 9 changed files with 189 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ public FlightInfo getFlightInfoPreparedStatement(
throw CallStatus.NOT_FOUND.withDescription("PreparedStatement not found.").toRuntimeException();
}

return getFlightInfoForFlightSqlCommands(commandPreparedStatementQuery, flightDescriptor, preparedStatement);
Schema schema = preparedStatement.getSchema();
return getFlightInfoForFlightSqlCommands(commandPreparedStatementQuery, flightDescriptor, schema);
}

@Override
Expand Down Expand Up @@ -317,13 +318,16 @@ public void getStreamSqlInfo(CommandGetSqlInfo commandGetSqlInfo,
public FlightInfo getFlightInfoCatalogs(
CommandGetCatalogs commandGetCatalogs, CallContext callContext,
FlightDescriptor flightDescriptor) {
throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException();
return getFlightInfoForFlightSqlCommands(commandGetCatalogs, flightDescriptor, getSchemaCatalogs().getSchema());
}

@Override
public void getStreamCatalogs(CallContext callContext, Ticket ticket,
ServerStreamListener serverStreamListener) {
throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException();
final CallHeaders headers = retrieveHeadersFromCallContext(callContext);
final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers);

flightWorkManager.getCatalogs(serverStreamListener, allocator, session);
}

@Override
Expand Down Expand Up @@ -439,8 +443,7 @@ private UserSession getUserSessionFromCallContext(CallContext callContext) {
}

private <T extends Message> FlightInfo getFlightInfoForFlightSqlCommands(
T commandPreparedStatementQuery, FlightDescriptor flightDescriptor, FlightPreparedStatement preparedStatement) {
Schema schema = preparedStatement.getSchema();
T commandPreparedStatementQuery, FlightDescriptor flightDescriptor, Schema schema) {

final Ticket ticket = new Ticket(pack(commandPreparedStatementQuery).toByteArray());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,18 @@ public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStat
workerProvider.get().submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP);
}

public void getCatalogs(FlightProducer.ServerStreamListener listener, BufferAllocator allocator,
UserSession userSession) {
final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId();
final UserRequest userRequest =
new UserRequest(UserProtos.RpcType.GET_CATALOGS, UserProtos.GetCatalogsReq.newBuilder().build());

final UserResponseHandler responseHandler = new GetCatalogsResponseHandler(allocator, listener);

workerProvider.get()
.submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP);
}

@VisibleForTesting
static String getQuery(FlightDescriptor descriptor) {
if (!descriptor.isCommand()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright (C) 2017-2019 Dremio Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.dremio.service.flight.impl;

import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.util.Text;

import com.dremio.common.utils.protos.QueryWritableBatch;
import com.dremio.exec.proto.GeneralRPCProtos;
import com.dremio.exec.proto.UserProtos;
import com.dremio.exec.rpc.RpcOutcomeListener;
import com.dremio.exec.work.protector.UserResponseHandler;
import com.dremio.exec.work.protector.UserResult;

class GetCatalogsResponseHandler implements UserResponseHandler {
private final BufferAllocator allocator;
private final FlightProducer.ServerStreamListener listener;

public GetCatalogsResponseHandler(BufferAllocator allocator, FlightProducer.ServerStreamListener listener) {
this.allocator = allocator;
this.listener = listener;
}

@Override
public void sendData(RpcOutcomeListener<GeneralRPCProtos.Ack> outcomeListener,
QueryWritableBatch result) {
}

@Override
public void completed(UserResult result) {
UserProtos.GetCatalogsResp catalogsResp = result.unwrap(UserProtos.GetCatalogsResp.class);

try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA,
allocator)) {
listener.start(vectorSchemaRoot);

vectorSchemaRoot.allocateNew();
VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");

int i = 0;
for (UserProtos.CatalogMetadata catalogMetadata : catalogsResp.getCatalogsList()) {
catalogNameVector.setSafe(i, new Text(catalogMetadata.getCatalogName()));
i++;
}

vectorSchemaRoot.setRowCount(catalogsResp.getCatalogsCount());
listener.putNext();
listener.completed();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,19 @@
import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
Expand Down Expand Up @@ -198,17 +202,13 @@ public void describeTo(Description description) {
}
}

public abstract FlightInfo getFlightInfo(String query) throws IOException, SQLException;

private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query)
throws IOException, SQLException {
private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query) throws SQLException {
// Assumption is that we have exactly one endpoint returned.
return (DremioFlightService.FLIGHT_LEGACY_AUTH_MODE.equals(wrapper.getAuthMode()))?
wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket()):
wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket(), wrapper.getTokenCallOption());
Ticket ticket = getFlightInfo(query).getEndpoints().get(0).getTicket();
return wrapper.getClient().getStream(ticket, getCallOptions());
}

private FlightStream executeQuery(String query) throws IOException, SQLException {
private FlightStream executeQuery(String query) throws SQLException {
// Assumption is that we have exactly one endpoint returned.
return executeQuery(getFlightClientWrapper(), query);
}
Expand All @@ -235,4 +235,13 @@ private List<String> executeQueryWithStringResults(String query) throws Exceptio
return actualStringResults;
}
}

abstract CallOption[] getCallOptions();

public FlightInfo getFlightInfo(String query) throws SQLException {
final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper();

final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8));
return wrapper.getClient().getInfo(command, getCallOptions());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (C) 2017-2019 Dremio Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.dremio.service.flight;

import java.sql.SQLException;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.junit.Assert;
import org.junit.Test;

public abstract class AbstractTestFlightSqlServer extends AbstractTestFlightServer {

@Override
public FlightInfo getFlightInfo(String query) throws SQLException {
final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper();

final FlightSqlClient.PreparedStatement preparedStatement =
clientWrapper.getSqlClient().prepare(query, getCallOptions());

return preparedStatement.execute();
}

@Test
public void testGetCatalogs() {
FlightSqlClient flightSqlClient = getFlightClientWrapper().getSqlClient();
CallOption[] callOptions = getCallOptions();

FlightInfo flightInfo = flightSqlClient.getCatalogs(callOptions);
try (FlightStream stream = flightSqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket(), callOptions)) {
Assert.assertTrue(stream.next());
VectorSchemaRoot root = stream.getRoot();
Assert.assertEquals(1, root.getRowCount());

String catalogName = ((VarCharVector) root.getVector("catalog_name")).getObject(0).toString();
Assert.assertEquals("DREMIO", catalogName);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.nio.charset.StandardCharsets;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.junit.BeforeClass;
Expand All @@ -44,10 +45,7 @@ protected String getAuthMode() {
}

@Override
public FlightInfo getFlightInfo(String query) {
final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper();

final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8));
return wrapper.getClient().getInfo(command);
CallOption[] getCallOptions() {
return new CallOption[0];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.nio.charset.StandardCharsets;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.junit.BeforeClass;
Expand All @@ -43,10 +44,14 @@ protected String getAuthMode() {
}

@Override
public FlightInfo getFlightInfo(String query) {
CallOption[] getCallOptions() {
final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper();
return new CallOption[] { wrapper.getTokenCallOption() };
}

@Override
public FlightInfo getFlightInfo(String query) {
final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8));
return wrapper.getClient().getInfo(command, wrapper.getTokenCallOption());
return getFlightClientWrapper().getClient().getInfo(command, getCallOptions());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,23 @@

import java.sql.SQLException;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

import com.dremio.service.flight.impl.FlightWorkManager;

/**
* Test FlightServer with basic authentication using FlightSql producer.
*/
public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightServer {
public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightSqlServer {
@BeforeClass
public static void setup() throws Exception {
setupBaseFlightQueryTest(
Expand All @@ -45,10 +51,7 @@ protected String getAuthMode() {
}

@Override
public FlightInfo getFlightInfo(String query) throws SQLException {
final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper();

final FlightSqlClient.PreparedStatement preparedStatement = clientWrapper.getSqlClient().prepare(query);
return preparedStatement.execute();
CallOption[] getCallOptions() {
return new CallOption[0];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.sql.SQLException;

import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.junit.BeforeClass;
Expand All @@ -28,7 +29,7 @@
/**
* Test FlightServer with bearer token authentication using FlightSql producer.
*/
public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightServer {
public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightSqlServer {
@BeforeClass
public static void setup() throws Exception {
setupBaseFlightQueryTest(
Expand All @@ -44,12 +45,8 @@ protected String getAuthMode() {
}

@Override
public FlightInfo getFlightInfo(String query) throws SQLException {
final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper();

final FlightSqlClient.PreparedStatement preparedStatement =
clientWrapper.getSqlClient().prepare(query, clientWrapper.getTokenCallOption());

return preparedStatement.execute();
CallOption[] getCallOptions() {
final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper();
return new CallOption[] { wrapper.getTokenCallOption() };
}
}

0 comments on commit 698817d

Please sign in to comment.