From 3791911249bc0678754997c26255da677704af13 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Thu, 12 Aug 2021 17:54:37 -0300 Subject: [PATCH 01/13] Refactor DremioFlightProducer to delegate between FlightSql and legacy producers --- services/arrow-flight/pom.xml | 5 + .../service/flight/DremioFlightProducer.java | 84 +++++- .../flight/DremioFlightSqlProducer.java | 270 ++++++++++++++++++ .../flight/AbstractTestFlightServer.java | 20 +- .../service/flight/FlightClientUtils.java | 7 + .../flight/TestFlightServerWithBasicAuth.java | 13 + .../flight/TestFlightServerWithTokenAuth.java | 13 + .../TestFlightSqlServerWithBasicAuth.java | 55 ++++ .../TestFlightSqlServerWithTokenAuth.java | 56 ++++ 9 files changed, 505 insertions(+), 18 deletions(-) create mode 100644 services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java create mode 100644 services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java create mode 100644 services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java diff --git a/services/arrow-flight/pom.xml b/services/arrow-flight/pom.xml index 45d8edb77e..6e74d243ba 100644 --- a/services/arrow-flight/pom.xml +++ b/services/arrow-flight/pom.xml @@ -53,6 +53,11 @@ org.apache.arrow flight-grpc + + org.apache.arrow + flight-sql + 6.0.0-SNAPSHOT + com.google.protobuf protobuf-java diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index c22713976d..9d6ab4053e 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -13,8 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.dremio.service.flight; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSchemas; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; + import javax.inject.Provider; import org.apache.arrow.flight.Action; @@ -31,6 +43,7 @@ import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.Result; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlUtils; import org.apache.arrow.memory.BufferAllocator; import com.dremio.exec.work.protector.UserWorker; @@ -39,12 +52,16 @@ import com.dremio.service.flight.impl.FlightPreparedStatement; import com.dremio.service.flight.impl.FlightWorkManager; import com.dremio.service.flight.impl.FlightWorkManager.RunQueryResponseHandlerFactory; +import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; /** * A FlightProducer implementation which exposes Dremio's catalog and produces results from SQL queries. */ public class DremioFlightProducer implements FlightProducer { + + private final DremioFlightSqlProducer flightSqlProducer; + private final FlightWorkManager flightWorkManager; private final Location location; private final DremioFlightSessionsManager sessionsManager; @@ -52,24 +69,37 @@ public class DremioFlightProducer implements FlightProducer { public DremioFlightProducer(Location location, DremioFlightSessionsManager sessionsManager, Provider workerProvider, Provider optionManagerProvider, - BufferAllocator allocator, RunQueryResponseHandlerFactory runQueryResponseHandlerFactory) { + BufferAllocator allocator, + RunQueryResponseHandlerFactory runQueryResponseHandlerFactory) { this.location = location; this.sessionsManager = sessionsManager; this.allocator = allocator; flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); + + flightSqlProducer = + new DremioFlightSqlProducer(sessionsManager, workerProvider, optionManagerProvider, allocator, + runQueryResponseHandlerFactory); } @Override public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) { + if (isFlightSqlCommand(ticket)) { + this.flightSqlProducer.getStream(callContext, ticket, serverStreamListener); + return; + } + try { final CallHeaders headers = retrieveHeadersFromCallContext(callContext); final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); - final TicketContent.PreparedStatementTicket preparedStatementTicket = TicketContent.PreparedStatementTicket.parseFrom(ticket.getBytes()); + final TicketContent.PreparedStatementTicket preparedStatementTicket = + TicketContent.PreparedStatementTicket.parseFrom(ticket.getBytes()); flightWorkManager.runPreparedStatement(preparedStatementTicket, serverStreamListener, allocator, session); } catch (InvalidProtocolBufferException ex) { - final RuntimeException error = CallStatus.INVALID_ARGUMENT.withCause(ex).withDescription("Invalid ticket used in getStream").toRuntimeException(); + final RuntimeException error = + CallStatus.INVALID_ARGUMENT.withCause(ex).withDescription("Invalid ticket used in getStream") + .toRuntimeException(); serverStreamListener.error(error); throw error; } @@ -82,6 +112,10 @@ public void listFlights(CallContext callContext, Criteria criteria, StreamListen @Override public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) { + if (isFlightSqlCommand(flightDescriptor)) { + return this.flightSqlProducer.getFlightInfo(callContext, flightDescriptor); + } + final CallHeaders headers = retrieveHeadersFromCallContext(callContext); final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); final FlightPreparedStatement flightPreparedStatement = flightWorkManager @@ -90,12 +124,22 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight } @Override - public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { + public Runnable acceptPut(CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + if (isFlightSqlCommand(flightStream.getDescriptor())) { + return this.flightSqlProducer.acceptPut(callContext, flightStream, streamListener); + } + throw CallStatus.UNIMPLEMENTED.withDescription("acceptPut is unimplemented").toRuntimeException(); } @Override public void doAction(CallContext callContext, Action action, StreamListener streamListener) { + if (isFlightSqlAction(action)) { + this.flightSqlProducer.doAction(callContext, action, streamListener); + return; + } + throw CallStatus.UNIMPLEMENTED.withDescription("doAction is unimplemented").toRuntimeException(); } @@ -113,4 +157,36 @@ public void listActions(CallContext callContext, StreamListener stre private CallHeaders retrieveHeadersFromCallContext(CallContext callContext) { return callContext.getMiddleware(FlightConstants.HEADER_KEY).headers(); } + + // TODO: Add this to FlightSqlUtils + private boolean isFlightSqlCommand(Any command) { + return command.is(CommandStatementQuery.class) || command.is(CommandPreparedStatementQuery.class) || + command.is(CommandGetCatalogs.class) || command.is(CommandGetSchemas.class) || + command.is(CommandGetTables.class) || command.is(CommandGetTableTypes.class) || + command.is(CommandGetSqlInfo.class) || command.is(CommandGetPrimaryKeys.class) || + command.is(CommandGetExportedKeys.class) || command.is(CommandGetImportedKeys.class); + } + + private boolean isFlightSqlCommand(byte[] bytes) { + try { + Any command = Any.parseFrom(bytes); + return isFlightSqlCommand(command); + } catch (InvalidProtocolBufferException e) { + return false; + } + } + + private boolean isFlightSqlCommand(FlightDescriptor flightDescriptor) { + return isFlightSqlCommand(flightDescriptor.getCommand()); + } + + private boolean isFlightSqlCommand(Ticket ticket) { + return isFlightSqlCommand(ticket.getBytes()); + } + + // TODO: Add this to FlightSqlUtils + private boolean isFlightSqlAction(Action action) { + String actionType = action.getType(); + return FlightSqlUtils.FLIGHT_SQL_ACTIONS.stream().anyMatch(action2 -> action2.getType().equals(actionType)); + } } diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java new file mode 100644 index 0000000000..ba86be29f1 --- /dev/null +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java @@ -0,0 +1,270 @@ +/* + * Copyright (C) 2017-2019 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.service.flight; + +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSchemas; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; + +import javax.inject.Provider; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; + +import com.dremio.exec.work.protector.UserWorker; +import com.dremio.options.OptionManager; +import com.dremio.service.flight.impl.FlightWorkManager; +import com.dremio.service.flight.impl.FlightWorkManager.RunQueryResponseHandlerFactory; + +/** + * A FlightProducer implementation which exposes Dremio's catalog and produces results from SQL queries. + */ +public class DremioFlightSqlProducer implements FlightSqlProducer { + private final FlightWorkManager flightWorkManager; + private final DremioFlightSessionsManager sessionsManager; + private final BufferAllocator allocator; + + public DremioFlightSqlProducer(DremioFlightSessionsManager sessionsManager, + Provider workerProvider, Provider optionManagerProvider, + BufferAllocator allocator, + RunQueryResponseHandlerFactory runQueryResponseHandlerFactory) { + this.sessionsManager = sessionsManager; + this.allocator = allocator; + + flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); + } + + @Override + public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("listFlights is unimplemented").toRuntimeException(); + } + + @Override + public void createPreparedStatement( + ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement not supported.").toRuntimeException(); + } + + @Override + public void closePreparedStatement( + ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + CallContext callContext, + StreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("closePreparedStatement not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoStatement( + CommandStatementQuery commandStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); + } + + @Override + public SchemaResult getSchemaStatement( + CommandStatementQuery commandStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public void getStreamStatement(CommandStatementQuery commandStatementQuery, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public void getStreamPreparedStatement( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); + } + + @Override + public Runnable acceptPutStatement( + CommandStatementUpdate commandStatementUpdate, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") + .toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") + .toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); + } + + @Override + public void getStreamSqlInfo(CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoCatalogs( + CommandGetCatalogs commandGetCatalogs, CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + } + + @Override + public void getStreamCatalogs(CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSchemas(CommandGetSchemas commandGetSchemas, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); + } + + @Override + public void getStreamSchemas(CommandGetSchemas commandGetSchemas, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoTables(CommandGetTables commandGetTables, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); + } + + @Override + public void getStreamTables(CommandGetTables commandGetTables, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoTableTypes( + CommandGetTableTypes commandGetTableTypes, CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); + } + + @Override + public void getStreamTableTypes(CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys( + CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamPrimaryKeys(CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoExportedKeys( + CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamExportedKeys( + CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoImportedKeys( + CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamImportedKeys( + CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); + } + + @Override + public void close() throws Exception { + + } +} diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java index e2108d03bb..fedb5d4553 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java @@ -18,12 +18,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import java.nio.charset.StandardCharsets; +import java.io.IOException; +import java.sql.SQLException; import java.util.ArrayList; import java.util.List; import org.apache.arrow.flight.FlightClient; -import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; @@ -198,25 +198,17 @@ public void describeTo(Description description) { } } - private static FlightDescriptor toFlightDescriptor(String query) { - return FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); - } - - private FlightInfo getFlightInfo(String query) { - final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); - return (DremioFlightService.FLIGHT_LEGACY_AUTH_MODE.equals(wrapper.getAuthMode()))? - wrapper.getClient().getInfo(toFlightDescriptor(query)): - wrapper.getClient().getInfo(toFlightDescriptor(query), wrapper.getTokenCallOption()); - } + public abstract FlightInfo getFlightInfo(String query) throws IOException, SQLException; - private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query) { + private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query) + throws IOException, SQLException { // Assumption is that we have exactly one endpoint returned. return (DremioFlightService.FLIGHT_LEGACY_AUTH_MODE.equals(wrapper.getAuthMode()))? wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket()): wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket(), wrapper.getTokenCallOption()); } - private FlightStream executeQuery(String query) { + private FlightStream executeQuery(String query) throws IOException, SQLException { // Assumption is that we have exactly one endpoint returned. return executeQuery(getFlightClientWrapper(), query); } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java index 3d924899aa..4b0a74b26d 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java @@ -21,6 +21,7 @@ import org.apache.arrow.flight.FlightClient; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.memory.BufferAllocator; import com.dremio.common.AutoCloseables; @@ -36,6 +37,7 @@ public final class FlightClientUtils { public static final class FlightClientWrapper implements AutoCloseable { private BufferAllocator allocator; private FlightClient client; + private FlightSqlClient sqlClient; private String authMode; private CredentialCallOption tokenCallOption; @@ -43,10 +45,15 @@ public FlightClientWrapper(BufferAllocator allocator, FlightClient client, String authMode) { this.allocator = allocator; this.client = client; + this.sqlClient = new FlightSqlClient(this.client); this.authMode = authMode; this.tokenCallOption = null; } + public FlightSqlClient getSqlClient() { + return sqlClient; + } + public BufferAllocator getAllocator() { return allocator; } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java index 2bf7dd8c45..80f06b9fdf 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.dremio.service.flight; +import java.nio.charset.StandardCharsets; + +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; import org.junit.BeforeClass; import com.dremio.service.flight.impl.FlightWorkManager; @@ -37,4 +42,12 @@ public static void setup() throws Exception { protected String getAuthMode() { return DremioFlightService.FLIGHT_LEGACY_AUTH_MODE; } + + @Override + public FlightInfo getFlightInfo(String query) { + final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); + + final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); + return wrapper.getClient().getInfo(command); + } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java index a18eb88726..5e8f4d9613 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java @@ -13,8 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.dremio.service.flight; +import java.nio.charset.StandardCharsets; + +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; import org.junit.BeforeClass; import com.dremio.service.flight.impl.FlightWorkManager; @@ -36,4 +41,12 @@ public static void setup() throws Exception { protected String getAuthMode() { return DremioFlightService.FLIGHT_AUTH2_AUTH_MODE; } + + @Override + public FlightInfo getFlightInfo(String query) { + final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); + + final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); + return wrapper.getClient().getInfo(command, wrapper.getTokenCallOption()); + } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java new file mode 100644 index 0000000000..f88b2c4d75 --- /dev/null +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2017-2019 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.service.flight; + +import java.sql.SQLException; + +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.junit.BeforeClass; +import org.junit.Ignore; + +import com.dremio.service.flight.impl.FlightWorkManager; + +/** + * Test FlightServer with basic authentication using FlightSql producer. + */ +@Ignore +public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightServer { + @BeforeClass + public static void setup() throws Exception { + setupBaseFlightQueryTest( + false, + true, + "flight.endpoint.port", + FlightWorkManager.RunQueryResponseHandlerFactory.DEFAULT, + DremioFlightService.FLIGHT_LEGACY_AUTH_MODE); + } + + @Override + protected String getAuthMode() { + return DremioFlightService.FLIGHT_LEGACY_AUTH_MODE; + } + + @Override + public FlightInfo getFlightInfo(String query) throws SQLException { + final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper(); + + final FlightSqlClient.PreparedStatement preparedStatement = clientWrapper.getSqlClient().prepare(query); + return preparedStatement.execute(); + } +} diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java new file mode 100644 index 0000000000..1a643a668f --- /dev/null +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2017-2019 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.service.flight; + +import java.sql.SQLException; + +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.junit.BeforeClass; +import org.junit.Ignore; + +import com.dremio.service.flight.impl.FlightWorkManager; + +/** + * Test FlightServer with bearer token authentication using FlightSql producer. + */ +@Ignore +public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightServer { + @BeforeClass + public static void setup() throws Exception { + setupBaseFlightQueryTest( + false, + true, + "flight.endpoint.port", + FlightWorkManager.RunQueryResponseHandlerFactory.DEFAULT); + } + + @Override + protected String getAuthMode() { + return DremioFlightService.FLIGHT_AUTH2_AUTH_MODE; + } + + @Override + public FlightInfo getFlightInfo(String query) throws SQLException { + final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper(); + + final FlightSqlClient.PreparedStatement preparedStatement = + clientWrapper.getSqlClient().prepare(query, clientWrapper.getTokenCallOption()); + + return preparedStatement.execute(); + } +} From a221f0dedbc07493316e7e2f04f5f8a4718c4a96 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 13:12:54 -0300 Subject: [PATCH 02/13] Remove TODOs regarding to FlightSqlUtils --- .../java/com/dremio/service/flight/DremioFlightProducer.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index 9d6ab4053e..b578ffe7e9 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -158,7 +158,6 @@ private CallHeaders retrieveHeadersFromCallContext(CallContext callContext) { return callContext.getMiddleware(FlightConstants.HEADER_KEY).headers(); } - // TODO: Add this to FlightSqlUtils private boolean isFlightSqlCommand(Any command) { return command.is(CommandStatementQuery.class) || command.is(CommandPreparedStatementQuery.class) || command.is(CommandGetCatalogs.class) || command.is(CommandGetSchemas.class) || @@ -184,7 +183,6 @@ private boolean isFlightSqlCommand(Ticket ticket) { return isFlightSqlCommand(ticket.getBytes()); } - // TODO: Add this to FlightSqlUtils private boolean isFlightSqlAction(Action action) { String actionType = action.getType(); return FlightSqlUtils.FLIGHT_SQL_ACTIONS.stream().anyMatch(action2 -> action2.getType().equals(actionType)); From ea6a46245b9d150be6c8bbfd06014b18f5409ae3 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 13:50:47 -0300 Subject: [PATCH 03/13] Refactor DremioFlightProducer so that it implements FlightSqlProducer, remova DremioFlightSqlProducer --- .../service/flight/DremioFlightProducer.java | 215 +++++++++++++- .../flight/DremioFlightSqlProducer.java | 270 ------------------ 2 files changed, 203 insertions(+), 282 deletions(-) delete mode 100644 services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index b578ffe7e9..8dc1ba8519 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -37,13 +37,15 @@ import org.apache.arrow.flight.FlightConstants; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightProducer; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.flight.sql.FlightSqlUtils; +import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.BufferAllocator; import com.dremio.exec.work.protector.UserWorker; @@ -58,9 +60,7 @@ /** * A FlightProducer implementation which exposes Dremio's catalog and produces results from SQL queries. */ -public class DremioFlightProducer implements FlightProducer { - - private final DremioFlightSqlProducer flightSqlProducer; +public class DremioFlightProducer implements FlightSqlProducer { private final FlightWorkManager flightWorkManager; private final Location location; @@ -76,16 +76,12 @@ public DremioFlightProducer(Location location, DremioFlightSessionsManager sessi this.allocator = allocator; flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); - - flightSqlProducer = - new DremioFlightSqlProducer(sessionsManager, workerProvider, optionManagerProvider, allocator, - runQueryResponseHandlerFactory); } @Override public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) { if (isFlightSqlCommand(ticket)) { - this.flightSqlProducer.getStream(callContext, ticket, serverStreamListener); + FlightSqlProducer.super.getStream(callContext, ticket, serverStreamListener); return; } @@ -113,7 +109,7 @@ public void listFlights(CallContext callContext, Criteria criteria, StreamListen @Override public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor) { if (isFlightSqlCommand(flightDescriptor)) { - return this.flightSqlProducer.getFlightInfo(callContext, flightDescriptor); + return FlightSqlProducer.super.getFlightInfo(callContext, flightDescriptor); } final CallHeaders headers = retrieveHeadersFromCallContext(callContext); @@ -127,7 +123,7 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { if (isFlightSqlCommand(flightStream.getDescriptor())) { - return this.flightSqlProducer.acceptPut(callContext, flightStream, streamListener); + return FlightSqlProducer.super.acceptPut(callContext, flightStream, streamListener); } throw CallStatus.UNIMPLEMENTED.withDescription("acceptPut is unimplemented").toRuntimeException(); @@ -136,7 +132,7 @@ public Runnable acceptPut(CallContext callContext, FlightStream flightStream, @Override public void doAction(CallContext callContext, Action action, StreamListener streamListener) { if (isFlightSqlAction(action)) { - this.flightSqlProducer.doAction(callContext, action, streamListener); + FlightSqlProducer.super.doAction(callContext, action, streamListener); return; } @@ -148,6 +144,201 @@ public void listActions(CallContext callContext, StreamListener stre throw CallStatus.UNIMPLEMENTED.withDescription("listActions is unimplemented").toRuntimeException(); } + @Override + public void createPreparedStatement( + FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + CallContext callContext, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement not supported.").toRuntimeException(); + } + + @Override + public void closePreparedStatement( + FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + CallContext callContext, + StreamListener listener) { + throw CallStatus.UNIMPLEMENTED.withDescription("closePreparedStatement not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoStatement( + CommandStatementQuery commandStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoPreparedStatement( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); + } + + @Override + public SchemaResult getSchemaStatement( + CommandStatementQuery commandStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public void getStreamStatement(CommandStatementQuery commandStatementQuery, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public void getStreamPreparedStatement( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); + } + + @Override + public Runnable acceptPutStatement( + FlightSql.CommandStatementUpdate commandStatementUpdate, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementUpdate( + FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") + .toRuntimeException(); + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, FlightStream flightStream, + StreamListener streamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") + .toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); + } + + @Override + public void getStreamSqlInfo(CommandGetSqlInfo commandGetSqlInfo, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoCatalogs( + CommandGetCatalogs commandGetCatalogs, CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + } + + @Override + public void getStreamCatalogs(CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoSchemas(CommandGetSchemas commandGetSchemas, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); + } + + @Override + public void getStreamSchemas(CommandGetSchemas commandGetSchemas, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoTables(CommandGetTables commandGetTables, + CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); + } + + @Override + public void getStreamTables(CommandGetTables commandGetTables, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoTableTypes( + CommandGetTableTypes commandGetTableTypes, CallContext callContext, + FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); + } + + @Override + public void getStreamTableTypes(CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoPrimaryKeys( + CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamPrimaryKeys(CommandGetPrimaryKeys commandGetPrimaryKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoExportedKeys( + CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamExportedKeys( + CommandGetExportedKeys commandGetExportedKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); + } + + @Override + public FlightInfo getFlightInfoImportedKeys( + CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, FlightDescriptor flightDescriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); + } + + @Override + public void getStreamImportedKeys( + CommandGetImportedKeys commandGetImportedKeys, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); + } + + @Override + public void close() throws Exception { + + } + /** * Helper method to retrieve CallHeaders from the CallContext. * diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java deleted file mode 100644 index ba86be29f1..0000000000 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightSqlProducer.java +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Copyright (C) 2017-2019 Dremio Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.dremio.service.flight; - -import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; -import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSchemas; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; - -import javax.inject.Provider; - -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.Criteria; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.PutResult; -import org.apache.arrow.flight.Result; -import org.apache.arrow.flight.SchemaResult; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.sql.FlightSqlProducer; -import org.apache.arrow.memory.BufferAllocator; - -import com.dremio.exec.work.protector.UserWorker; -import com.dremio.options.OptionManager; -import com.dremio.service.flight.impl.FlightWorkManager; -import com.dremio.service.flight.impl.FlightWorkManager.RunQueryResponseHandlerFactory; - -/** - * A FlightProducer implementation which exposes Dremio's catalog and produces results from SQL queries. - */ -public class DremioFlightSqlProducer implements FlightSqlProducer { - private final FlightWorkManager flightWorkManager; - private final DremioFlightSessionsManager sessionsManager; - private final BufferAllocator allocator; - - public DremioFlightSqlProducer(DremioFlightSessionsManager sessionsManager, - Provider workerProvider, Provider optionManagerProvider, - BufferAllocator allocator, - RunQueryResponseHandlerFactory runQueryResponseHandlerFactory) { - this.sessionsManager = sessionsManager; - this.allocator = allocator; - - flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); - } - - @Override - public void listFlights(CallContext callContext, Criteria criteria, StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("listFlights is unimplemented").toRuntimeException(); - } - - @Override - public void createPreparedStatement( - ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, - CallContext callContext, - StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement not supported.").toRuntimeException(); - } - - @Override - public void closePreparedStatement( - ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, - CallContext callContext, - StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("closePreparedStatement not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoStatement( - CommandStatementQuery commandStatementQuery, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoPreparedStatement( - CommandPreparedStatementQuery commandPreparedStatementQuery, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); - } - - @Override - public SchemaResult getSchemaStatement( - CommandStatementQuery commandStatementQuery, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); - } - - @Override - public void getStreamStatement(CommandStatementQuery commandStatementQuery, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); - } - - @Override - public void getStreamPreparedStatement( - CommandPreparedStatementQuery commandPreparedStatementQuery, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); - } - - @Override - public Runnable acceptPutStatement( - CommandStatementUpdate commandStatementUpdate, - CallContext callContext, FlightStream flightStream, - StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); - } - - @Override - public Runnable acceptPutPreparedStatementUpdate( - CommandPreparedStatementUpdate commandPreparedStatementUpdate, - CallContext callContext, FlightStream flightStream, - StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") - .toRuntimeException(); - } - - @Override - public Runnable acceptPutPreparedStatementQuery( - CommandPreparedStatementQuery commandPreparedStatementQuery, - CallContext callContext, FlightStream flightStream, - StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") - .toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo commandGetSqlInfo, - CallContext callContext, - FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); - } - - @Override - public void getStreamSqlInfo(CommandGetSqlInfo commandGetSqlInfo, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSqlInfo not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoCatalogs( - CommandGetCatalogs commandGetCatalogs, CallContext callContext, - FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); - } - - @Override - public void getStreamCatalogs(CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoSchemas(CommandGetSchemas commandGetSchemas, - CallContext callContext, - FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); - } - - @Override - public void getStreamSchemas(CommandGetSchemas commandGetSchemas, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetSchemas not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoTables(CommandGetTables commandGetTables, - CallContext callContext, - FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); - } - - @Override - public void getStreamTables(CommandGetTables commandGetTables, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTables not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoTableTypes( - CommandGetTableTypes commandGetTableTypes, CallContext callContext, - FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); - } - - @Override - public void getStreamTableTypes(CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetTableTypes not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoPrimaryKeys( - CommandGetPrimaryKeys commandGetPrimaryKeys, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); - } - - @Override - public void getStreamPrimaryKeys(CommandGetPrimaryKeys commandGetPrimaryKeys, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetPrimaryKeys not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoExportedKeys( - CommandGetExportedKeys commandGetExportedKeys, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); - } - - @Override - public void getStreamExportedKeys( - CommandGetExportedKeys commandGetExportedKeys, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetExportedKeys not supported.").toRuntimeException(); - } - - @Override - public FlightInfo getFlightInfoImportedKeys( - CommandGetImportedKeys commandGetImportedKeys, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); - } - - @Override - public void getStreamImportedKeys( - CommandGetImportedKeys commandGetImportedKeys, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetImportedKeys not supported.").toRuntimeException(); - } - - @Override - public void close() throws Exception { - - } -} From 6498b2c630942cdbae50c2aa6d7bed05f29fcafc Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 14:58:21 -0300 Subject: [PATCH 04/13] Fix minor style issues --- .../dremio/service/flight/DremioFlightProducer.java | 5 +++-- .../com/dremio/service/flight/FlightClientUtils.java | 10 ++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index 8dc1ba8519..f7ea2768f2 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -80,7 +80,7 @@ public DremioFlightProducer(Location location, DremioFlightSessionsManager sessi @Override public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) { - if (isFlightSqlCommand(ticket)) { + if (isFlightSqlTicket(ticket)) { FlightSqlProducer.super.getStream(callContext, ticket, serverStreamListener); return; } @@ -370,7 +370,8 @@ private boolean isFlightSqlCommand(FlightDescriptor flightDescriptor) { return isFlightSqlCommand(flightDescriptor.getCommand()); } - private boolean isFlightSqlCommand(Ticket ticket) { + private boolean isFlightSqlTicket(Ticket ticket) { + // The byte array on ticket is a serialized FlightSqlCommand return isFlightSqlCommand(ticket.getBytes()); } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java index 4b0a74b26d..6765b4af9d 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/FlightClientUtils.java @@ -35,10 +35,10 @@ public final class FlightClientUtils { * Container class for holding a FlightClient and its associated allocator. */ public static final class FlightClientWrapper implements AutoCloseable { - private BufferAllocator allocator; - private FlightClient client; - private FlightSqlClient sqlClient; - private String authMode; + private final BufferAllocator allocator; + private final FlightClient client; + private final FlightSqlClient sqlClient; + private final String authMode; private CredentialCallOption tokenCallOption; public FlightClientWrapper(BufferAllocator allocator, FlightClient client, @@ -79,8 +79,6 @@ public void close() throws Exception { // Note - client must close first as it creates a child allocator from // the input allocator. AutoCloseables.close(client, allocator); - client = null; - allocator = null; tokenCallOption = null; } } From 310a07930fb5db8fe0ad5a49b9812c340d26fe95 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 16:00:23 -0300 Subject: [PATCH 05/13] Implement CommandPreparedStatementQuery --- .../service/flight/DremioFlightProducer.java | 159 ++++++++++++++---- .../flight/impl/FlightPreparedStatement.java | 20 ++- .../flight/impl/FlightWorkManager.java | 8 +- .../TestFlightSqlServerWithBasicAuth.java | 1 - .../TestFlightSqlServerWithTokenAuth.java | 1 - 5 files changed, 152 insertions(+), 37 deletions(-) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index f7ea2768f2..76f26b760d 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -16,6 +16,10 @@ package com.dremio.service.flight; +import static com.google.protobuf.Any.pack; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; @@ -25,7 +29,12 @@ import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; + +import java.nio.charset.StandardCharsets; +import java.util.concurrent.TimeUnit; import javax.inject.Provider; @@ -36,6 +45,7 @@ import org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightConstants; import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; @@ -45,17 +55,22 @@ import org.apache.arrow.flight.Ticket; import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.flight.sql.FlightSqlUtils; -import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.Schema; +import com.dremio.exec.proto.UserProtos; import com.dremio.exec.work.protector.UserWorker; import com.dremio.options.OptionManager; import com.dremio.sabot.rpc.user.UserSession; import com.dremio.service.flight.impl.FlightPreparedStatement; import com.dremio.service.flight.impl.FlightWorkManager; import com.dremio.service.flight.impl.FlightWorkManager.RunQueryResponseHandlerFactory; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableList; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; /** * A FlightProducer implementation which exposes Dremio's catalog and produces results from SQL queries. @@ -66,6 +81,7 @@ public class DremioFlightProducer implements FlightSqlProducer { private final Location location; private final DremioFlightSessionsManager sessionsManager; private final BufferAllocator allocator; + private final Cache flightPreparedStatementCache; public DremioFlightProducer(Location location, DremioFlightSessionsManager sessionsManager, Provider workerProvider, Provider optionManagerProvider, @@ -76,6 +92,11 @@ public DremioFlightProducer(Location location, DremioFlightSessionsManager sessi this.allocator = allocator; flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); + + flightPreparedStatementCache = CacheBuilder.newBuilder() + .maximumSize(1024) + .expireAfterAccess(30, TimeUnit.MINUTES) + .build(); } @Override @@ -85,13 +106,14 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen return; } - try { - final CallHeaders headers = retrieveHeadersFromCallContext(callContext); - final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); - final TicketContent.PreparedStatementTicket preparedStatementTicket = - TicketContent.PreparedStatementTicket.parseFrom(ticket.getBytes()); + getStreamLegacy(callContext, ticket, serverStreamListener); + } + + private void getStreamLegacy(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) { - flightWorkManager.runPreparedStatement(preparedStatementTicket, serverStreamListener, allocator, session); + final TicketContent.PreparedStatementTicket preparedStatementTicket; + try { + preparedStatementTicket = TicketContent.PreparedStatementTicket.parseFrom(ticket.getBytes()); } catch (InvalidProtocolBufferException ex) { final RuntimeException error = CallStatus.INVALID_ARGUMENT.withCause(ex).withDescription("Invalid ticket used in getStream") @@ -99,6 +121,31 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen serverStreamListener.error(error); throw error; } + + UserProtos.PreparedStatementHandle preparedStatementHandle = preparedStatementTicket.getHandle(); + + runPreparedStatement(callContext, serverStreamListener, preparedStatementHandle); + } + + @Override + public void getStreamPreparedStatement(CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, Ticket ticket, + ServerStreamListener serverStreamListener) { + UserProtos.PreparedStatementHandle preparedStatementHandle; + try { + preparedStatementHandle = + UserProtos.PreparedStatementHandle.parseFrom(commandPreparedStatementQuery.getPreparedStatementHandle()); + } catch (InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid PreparedStatementHandle").toRuntimeException(); + } + + // Check if given PreparedStatement is cached + FlightPreparedStatement preparedStatement = flightPreparedStatementCache.getIfPresent(preparedStatementHandle); + if (preparedStatement == null) { + throw CallStatus.NOT_FOUND.withDescription("PreparedStatement not found.").toRuntimeException(); + } + + runPreparedStatement(callContext, serverStreamListener, preparedStatementHandle); } @Override @@ -112,13 +159,37 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight return FlightSqlProducer.super.getFlightInfo(callContext, flightDescriptor); } - final CallHeaders headers = retrieveHeadersFromCallContext(callContext); - final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); + return getFlightInfoLegacy(callContext, flightDescriptor); + } + + private FlightInfo getFlightInfoLegacy(CallContext callContext, FlightDescriptor flightDescriptor) { + final UserSession session = getUserSessionFromCallContext(callContext); final FlightPreparedStatement flightPreparedStatement = flightWorkManager .createPreparedStatement(flightDescriptor, callContext::isCancelled, session); return flightPreparedStatement.getFlightInfo(location); } + @Override + public FlightInfo getFlightInfoPreparedStatement( + CommandPreparedStatementQuery commandPreparedStatementQuery, + CallContext callContext, FlightDescriptor flightDescriptor) { + final UserProtos.PreparedStatementHandle preparedStatementHandle; + + try { + preparedStatementHandle = + UserProtos.PreparedStatementHandle.parseFrom(commandPreparedStatementQuery.getPreparedStatementHandle()); + } catch (InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid PreparedStatementHandle").toRuntimeException(); + } + + FlightPreparedStatement preparedStatement = flightPreparedStatementCache.getIfPresent(preparedStatementHandle); + if (preparedStatement == null) { + throw CallStatus.NOT_FOUND.withDescription("PreparedStatement not found.").toRuntimeException(); + } + + return getFlightInfoForFlightSqlCommands(commandPreparedStatementQuery, flightDescriptor, preparedStatement); + } + @Override public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener streamListener) { @@ -146,18 +217,39 @@ public void listActions(CallContext callContext, StreamListener stre @Override public void createPreparedStatement( - FlightSql.ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, + ActionCreatePreparedStatementRequest actionCreatePreparedStatementRequest, CallContext callContext, StreamListener streamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("createPreparedStatement not supported.").toRuntimeException(); + final FlightDescriptor flightDescriptor = + FlightDescriptor.command(actionCreatePreparedStatementRequest.getQuery().getBytes(StandardCharsets.UTF_8)); + + final UserSession session = getUserSessionFromCallContext(callContext); + final FlightPreparedStatement flightPreparedStatement = flightWorkManager + .createPreparedStatement(flightDescriptor, callContext::isCancelled, session); + + flightPreparedStatementCache.put(flightPreparedStatement.getServerHandle(), flightPreparedStatement); + + final ActionCreatePreparedStatementResult action = flightPreparedStatement.createAction(); + + streamListener.onNext(new Result(pack(action).toByteArray())); + streamListener.onCompleted(); + } @Override public void closePreparedStatement( - FlightSql.ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, + ActionClosePreparedStatementRequest actionClosePreparedStatementRequest, CallContext callContext, StreamListener listener) { - throw CallStatus.UNIMPLEMENTED.withDescription("closePreparedStatement not supported.").toRuntimeException(); + UserProtos.PreparedStatementHandle preparedStatementHandle; + try { + preparedStatementHandle = + UserProtos.PreparedStatementHandle.parseFrom(actionClosePreparedStatementRequest.getPreparedStatementHandle()); + } catch (InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid PreparedStatementHandle").toRuntimeException(); + } + + flightPreparedStatementCache.invalidate(preparedStatementHandle); } @Override @@ -167,13 +259,6 @@ public FlightInfo getFlightInfoStatement( throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); } - @Override - public FlightInfo getFlightInfoPreparedStatement( - CommandPreparedStatementQuery commandPreparedStatementQuery, - CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); - } - @Override public SchemaResult getSchemaStatement( CommandStatementQuery commandStatementQuery, @@ -188,17 +273,9 @@ public void getStreamStatement(CommandStatementQuery commandStatementQuery, throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); } - @Override - public void getStreamPreparedStatement( - CommandPreparedStatementQuery commandPreparedStatementQuery, - CallContext callContext, Ticket ticket, - ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement not supported.").toRuntimeException(); - } - @Override public Runnable acceptPutStatement( - FlightSql.CommandStatementUpdate commandStatementUpdate, + CommandStatementUpdate commandStatementUpdate, CallContext callContext, FlightStream flightStream, StreamListener streamListener) { throw CallStatus.UNIMPLEMENTED.withDescription("Statement not supported.").toRuntimeException(); @@ -206,7 +283,7 @@ public Runnable acceptPutStatement( @Override public Runnable acceptPutPreparedStatementUpdate( - FlightSql.CommandPreparedStatementUpdate commandPreparedStatementUpdate, + CommandPreparedStatementUpdate commandPreparedStatementUpdate, CallContext callContext, FlightStream flightStream, StreamListener streamListener) { throw CallStatus.UNIMPLEMENTED.withDescription("PreparedStatement with parameter binding not supported.") @@ -339,6 +416,13 @@ public void close() throws Exception { } + private void runPreparedStatement(CallContext callContext, + ServerStreamListener serverStreamListener, + UserProtos.PreparedStatementHandle preparedStatementHandle) { + final UserSession session = getUserSessionFromCallContext(callContext); + flightWorkManager.runPreparedStatement(preparedStatementHandle, serverStreamListener, allocator, session); + } + /** * Helper method to retrieve CallHeaders from the CallContext. * @@ -349,6 +433,21 @@ private CallHeaders retrieveHeadersFromCallContext(CallContext callContext) { return callContext.getMiddleware(FlightConstants.HEADER_KEY).headers(); } + private UserSession getUserSessionFromCallContext(CallContext callContext) { + final CallHeaders headers = retrieveHeadersFromCallContext(callContext); + return sessionsManager.getUserSession(callContext.peerIdentity(), headers); + } + + private FlightInfo getFlightInfoForFlightSqlCommands( + T commandPreparedStatementQuery, FlightDescriptor flightDescriptor, FlightPreparedStatement preparedStatement) { + Schema schema = preparedStatement.getSchema(); + + final Ticket ticket = new Ticket(pack(commandPreparedStatementQuery).toByteArray()); + + final FlightEndpoint flightEndpoint = new FlightEndpoint(ticket, location); + return new FlightInfo(schema, flightDescriptor, ImmutableList.of(flightEndpoint), -1, -1); + } + private boolean isFlightSqlCommand(Any command) { return command.is(CommandStatementQuery.class) || command.is(CommandPreparedStatementQuery.class) || command.is(CommandGetCatalogs.class) || command.is(CommandGetSchemas.class) || diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java index 4558468eb2..f8dbc22e93 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java @@ -15,6 +15,8 @@ */ package com.dremio.service.flight.impl; +import static org.apache.arrow.flight.sql.impl.FlightSql.*; + import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; import org.apache.arrow.flight.FlightInfo; @@ -56,7 +58,7 @@ public FlightInfo getFlightInfo(Location location) { final PreparedStatementTicket preparedStatementTicketContent = PreparedStatementTicket.newBuilder() .setQuery(query) - .setHandle(createPreparedStatementResp.getPreparedStatement().getServerHandle()) + .setHandle(getServerHandle()) .build(); final Ticket ticket = new Ticket(preparedStatementTicketContent.toByteArray()); @@ -75,6 +77,22 @@ public Schema getSchema() { return buildSchema(resp.getPreparedStatement().getArrowSchema()); } + public ActionCreatePreparedStatementResult createAction() { + final UserProtos.CreatePreparedStatementArrowResp createPreparedStatementResp = responseHandler.get(); + final Schema schema = buildSchema(createPreparedStatementResp.getPreparedStatement().getArrowSchema()); + + return ActionCreatePreparedStatementResult.newBuilder() + .setDatasetSchema(ByteString.copyFrom(schema.toByteArray())) + .setParameterSchema(ByteString.EMPTY) + .setPreparedStatementHandle(getServerHandle().toByteString()) + .build(); + } + + public UserProtos.PreparedStatementHandle getServerHandle() { + UserProtos.CreatePreparedStatementArrowResp createPreparedStatementResp = responseHandler.get(); + return createPreparedStatementResp.getPreparedStatement().getServerHandle(); + } + private static Schema buildSchema(ByteString arrowSchema) { return Schema.deserialize(arrowSchema.asReadOnlyByteBuffer()); } diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java index 578df179b1..85e915acd1 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java @@ -88,9 +88,9 @@ public FlightPreparedStatement createPreparedStatement(FlightDescriptor flightDe return new FlightPreparedStatement(flightDescriptor, query, createPreparedStatementResponseHandler); } - public void runPreparedStatement(TicketContent.PreparedStatementTicket ticket, FlightProducer.ServerStreamListener listener, - BufferAllocator allocator, UserSession userSession) { - + public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStatementHandle, + FlightProducer.ServerStreamListener listener, BufferAllocator allocator, + UserSession userSession) { final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId(); final UserRequest userRequest = new UserRequest(UserProtos.RpcType.RUN_QUERY, @@ -100,7 +100,7 @@ public void runPreparedStatement(TicketContent.PreparedStatementTicket ticket, F .setWorkloadType(UserBitShared.WorkloadType.FLIGHT) .setWorkloadClass(UserBitShared.WorkloadClass.GENERAL)) .setSource(UserProtos.SubmissionSource.FLIGHT) - .setPreparedStatementHandle(ticket.getHandle()) + .setPreparedStatementHandle(preparedStatementHandle) .build()); final UserResponseHandler responseHandler = runQueryResponseHandlerFactory.getHandler(runExternalId, userSession, diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java index f88b2c4d75..30c2362def 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java @@ -28,7 +28,6 @@ /** * Test FlightServer with basic authentication using FlightSql producer. */ -@Ignore public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightServer { @BeforeClass public static void setup() throws Exception { diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java index 1a643a668f..4ed9229a50 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java @@ -28,7 +28,6 @@ /** * Test FlightServer with bearer token authentication using FlightSql producer. */ -@Ignore public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightServer { @BeforeClass public static void setup() throws Exception { From fb74615563456b44d742cd66fb90056a51fd0174 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 16:33:52 -0300 Subject: [PATCH 06/13] Fix minor style issues --- .../service/flight/DremioFlightProducer.java | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index 76f26b760d..dddcb96101 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -92,7 +92,6 @@ public DremioFlightProducer(Location location, DremioFlightSessionsManager sessi this.allocator = allocator; flightWorkManager = new FlightWorkManager(workerProvider, optionManagerProvider, runQueryResponseHandlerFactory); - flightPreparedStatementCache = CacheBuilder.newBuilder() .maximumSize(1024) .expireAfterAccess(30, TimeUnit.MINUTES) @@ -187,7 +186,8 @@ public FlightInfo getFlightInfoPreparedStatement( throw CallStatus.NOT_FOUND.withDescription("PreparedStatement not found.").toRuntimeException(); } - return getFlightInfoForFlightSqlCommands(commandPreparedStatementQuery, flightDescriptor, preparedStatement); + Schema schema = preparedStatement.getSchema(); + return getFlightInfoForFlightSqlCommands(commandPreparedStatementQuery, flightDescriptor, schema); } @Override @@ -245,11 +245,11 @@ public void closePreparedStatement( try { preparedStatementHandle = UserProtos.PreparedStatementHandle.parseFrom(actionClosePreparedStatementRequest.getPreparedStatementHandle()); + + flightPreparedStatementCache.invalidate(preparedStatementHandle); } catch (InvalidProtocolBufferException e) { throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid PreparedStatementHandle").toRuntimeException(); } - - flightPreparedStatementCache.invalidate(preparedStatementHandle); } @Override @@ -439,10 +439,8 @@ private UserSession getUserSessionFromCallContext(CallContext callContext) { } private FlightInfo getFlightInfoForFlightSqlCommands( - T commandPreparedStatementQuery, FlightDescriptor flightDescriptor, FlightPreparedStatement preparedStatement) { - Schema schema = preparedStatement.getSchema(); - - final Ticket ticket = new Ticket(pack(commandPreparedStatementQuery).toByteArray()); + T command, FlightDescriptor flightDescriptor, Schema schema) { + final Ticket ticket = new Ticket(pack(command).toByteArray()); final FlightEndpoint flightEndpoint = new FlightEndpoint(ticket, location); return new FlightInfo(schema, flightDescriptor, ImmutableList.of(flightEndpoint), -1, -1); From fdb9e11b93fa8fa7514ee10a90d97689201f1188 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Fri, 13 Aug 2021 16:28:57 -0300 Subject: [PATCH 07/13] Implement CommandGetCatalogs --- .../service/flight/DremioFlightProducer.java | 7 +- .../flight/impl/FlightWorkManager.java | 12 ++++ .../impl/GetCatalogsResponseHandler.java | 69 +++++++++++++++++++ .../flight/AbstractTestFlightServer.java | 25 ++++--- .../flight/AbstractTestFlightSqlServer.java | 59 ++++++++++++++++ .../flight/TestFlightServerWithBasicAuth.java | 8 +-- .../flight/TestFlightServerWithTokenAuth.java | 9 ++- .../TestFlightSqlServerWithBasicAuth.java | 15 ++-- .../TestFlightSqlServerWithTokenAuth.java | 13 ++-- 9 files changed, 186 insertions(+), 31 deletions(-) create mode 100644 services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java create mode 100644 services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightSqlServer.java diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index dddcb96101..2bb2d5d2f7 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -317,13 +317,16 @@ public void getStreamSqlInfo(CommandGetSqlInfo commandGetSqlInfo, public FlightInfo getFlightInfoCatalogs( CommandGetCatalogs commandGetCatalogs, CallContext callContext, FlightDescriptor flightDescriptor) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + return getFlightInfoForFlightSqlCommands(commandGetCatalogs, flightDescriptor, getSchemaCatalogs().getSchema()); } @Override public void getStreamCatalogs(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener) { - throw CallStatus.UNIMPLEMENTED.withDescription("CommandGetCatalogs not supported.").toRuntimeException(); + final CallHeaders headers = retrieveHeadersFromCallContext(callContext); + final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); + + flightWorkManager.getCatalogs(serverStreamListener, allocator, session); } @Override diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java index 85e915acd1..856bebaf94 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java @@ -109,6 +109,18 @@ public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStat workerProvider.get().submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); } + public void getCatalogs(FlightProducer.ServerStreamListener listener, BufferAllocator allocator, + UserSession userSession) { + final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId(); + final UserRequest userRequest = + new UserRequest(UserProtos.RpcType.GET_CATALOGS, UserProtos.GetCatalogsReq.newBuilder().build()); + + final UserResponseHandler responseHandler = new GetCatalogsResponseHandler(allocator, listener); + + workerProvider.get() + .submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); + } + @VisibleForTesting static String getQuery(FlightDescriptor descriptor) { if (!descriptor.isCommand()) { diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java new file mode 100644 index 0000000000..0de8670b82 --- /dev/null +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2017-2019 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.service.flight.impl; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.util.Text; + +import com.dremio.common.utils.protos.QueryWritableBatch; +import com.dremio.exec.proto.GeneralRPCProtos; +import com.dremio.exec.proto.UserProtos; +import com.dremio.exec.rpc.RpcOutcomeListener; +import com.dremio.exec.work.protector.UserResponseHandler; +import com.dremio.exec.work.protector.UserResult; + +class GetCatalogsResponseHandler implements UserResponseHandler { + private final BufferAllocator allocator; + private final FlightProducer.ServerStreamListener listener; + + public GetCatalogsResponseHandler(BufferAllocator allocator, FlightProducer.ServerStreamListener listener) { + this.allocator = allocator; + this.listener = listener; + } + + @Override + public void sendData(RpcOutcomeListener outcomeListener, + QueryWritableBatch result) { + } + + @Override + public void completed(UserResult result) { + UserProtos.GetCatalogsResp catalogsResp = result.unwrap(UserProtos.GetCatalogsResp.class); + + try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, + allocator)) { + listener.start(vectorSchemaRoot); + + vectorSchemaRoot.allocateNew(); + VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name"); + + int i = 0; + for (UserProtos.CatalogMetadata catalogMetadata : catalogsResp.getCatalogsList()) { + catalogNameVector.setSafe(i, new Text(catalogMetadata.getCatalogName())); + i++; + } + + vectorSchemaRoot.setRowCount(catalogsResp.getCatalogsCount()); + listener.putNext(); + listener.completed(); + } + } +} diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java index fedb5d4553..9bbabce5f3 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java @@ -19,15 +19,19 @@ import static org.junit.Assert.assertTrue; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; +import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; @@ -198,17 +202,13 @@ public void describeTo(Description description) { } } - public abstract FlightInfo getFlightInfo(String query) throws IOException, SQLException; - - private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query) - throws IOException, SQLException { + private FlightStream executeQuery(FlightClientUtils.FlightClientWrapper wrapper, String query) throws SQLException { // Assumption is that we have exactly one endpoint returned. - return (DremioFlightService.FLIGHT_LEGACY_AUTH_MODE.equals(wrapper.getAuthMode()))? - wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket()): - wrapper.getClient().getStream(getFlightInfo(query).getEndpoints().get(0).getTicket(), wrapper.getTokenCallOption()); + Ticket ticket = getFlightInfo(query).getEndpoints().get(0).getTicket(); + return wrapper.getClient().getStream(ticket, getCallOptions()); } - private FlightStream executeQuery(String query) throws IOException, SQLException { + private FlightStream executeQuery(String query) throws SQLException { // Assumption is that we have exactly one endpoint returned. return executeQuery(getFlightClientWrapper(), query); } @@ -235,4 +235,13 @@ private List executeQueryWithStringResults(String query) throws Exceptio return actualStringResults; } } + + abstract CallOption[] getCallOptions(); + + public FlightInfo getFlightInfo(String query) throws SQLException { + final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); + + final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); + return wrapper.getClient().getInfo(command, getCallOptions()); + } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightSqlServer.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightSqlServer.java new file mode 100644 index 0000000000..c1dea2b950 --- /dev/null +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightSqlServer.java @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2017-2019 Dremio Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.dremio.service.flight; + +import java.sql.SQLException; + +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.Assert; +import org.junit.Test; + +public abstract class AbstractTestFlightSqlServer extends AbstractTestFlightServer { + + @Override + public FlightInfo getFlightInfo(String query) throws SQLException { + final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper(); + + final FlightSqlClient.PreparedStatement preparedStatement = + clientWrapper.getSqlClient().prepare(query, getCallOptions()); + + return preparedStatement.execute(); + } + + @Test + public void testGetCatalogs() { + FlightSqlClient flightSqlClient = getFlightClientWrapper().getSqlClient(); + CallOption[] callOptions = getCallOptions(); + + FlightInfo flightInfo = flightSqlClient.getCatalogs(callOptions); + try (FlightStream stream = flightSqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket(), callOptions)) { + Assert.assertTrue(stream.next()); + VectorSchemaRoot root = stream.getRoot(); + Assert.assertEquals(1, root.getRowCount()); + + String catalogName = ((VarCharVector) root.getVector("catalog_name")).getObject(0).toString(); + Assert.assertEquals("DREMIO", catalogName); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java index 80f06b9fdf..5273fc921e 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithBasicAuth.java @@ -18,6 +18,7 @@ import java.nio.charset.StandardCharsets; +import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.junit.BeforeClass; @@ -44,10 +45,7 @@ protected String getAuthMode() { } @Override - public FlightInfo getFlightInfo(String query) { - final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); - - final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); - return wrapper.getClient().getInfo(command); + CallOption[] getCallOptions() { + return new CallOption[0]; } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java index 5e8f4d9613..144894af26 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightServerWithTokenAuth.java @@ -18,6 +18,7 @@ import java.nio.charset.StandardCharsets; +import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.junit.BeforeClass; @@ -43,10 +44,14 @@ protected String getAuthMode() { } @Override - public FlightInfo getFlightInfo(String query) { + CallOption[] getCallOptions() { final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); + return new CallOption[] { wrapper.getTokenCallOption() }; + } + @Override + public FlightInfo getFlightInfo(String query) { final FlightDescriptor command = FlightDescriptor.command(query.getBytes(StandardCharsets.UTF_8)); - return wrapper.getClient().getInfo(command, wrapper.getTokenCallOption()); + return getFlightClientWrapper().getClient().getInfo(command, getCallOptions()); } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java index 30c2362def..0da66e2c15 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java @@ -18,17 +18,23 @@ import java.sql.SQLException; +import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Ignore; +import org.junit.Test; import com.dremio.service.flight.impl.FlightWorkManager; /** * Test FlightServer with basic authentication using FlightSql producer. */ -public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightServer { +public class TestFlightSqlServerWithBasicAuth extends AbstractTestFlightSqlServer { @BeforeClass public static void setup() throws Exception { setupBaseFlightQueryTest( @@ -45,10 +51,7 @@ protected String getAuthMode() { } @Override - public FlightInfo getFlightInfo(String query) throws SQLException { - final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper(); - - final FlightSqlClient.PreparedStatement preparedStatement = clientWrapper.getSqlClient().prepare(query); - return preparedStatement.execute(); + CallOption[] getCallOptions() { + return new CallOption[0]; } } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java index 4ed9229a50..228f187a1d 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java @@ -18,6 +18,7 @@ import java.sql.SQLException; +import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.sql.FlightSqlClient; import org.junit.BeforeClass; @@ -28,7 +29,7 @@ /** * Test FlightServer with bearer token authentication using FlightSql producer. */ -public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightServer { +public class TestFlightSqlServerWithTokenAuth extends AbstractTestFlightSqlServer { @BeforeClass public static void setup() throws Exception { setupBaseFlightQueryTest( @@ -44,12 +45,8 @@ protected String getAuthMode() { } @Override - public FlightInfo getFlightInfo(String query) throws SQLException { - final FlightClientUtils.FlightClientWrapper clientWrapper = getFlightClientWrapper(); - - final FlightSqlClient.PreparedStatement preparedStatement = - clientWrapper.getSqlClient().prepare(query, clientWrapper.getTokenCallOption()); - - return preparedStatement.execute(); + CallOption[] getCallOptions() { + final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); + return new CallOption[] { wrapper.getTokenCallOption() }; } } From 55b37fcc9542a78b7a911150c94d806de0f86fb0 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Mon, 16 Aug 2021 13:48:29 -0300 Subject: [PATCH 08/13] Remove unused imports from TestFLightSqlServer*.java --- .../flight/TestFlightSqlServerWithBasicAuth.java | 10 ---------- .../flight/TestFlightSqlServerWithTokenAuth.java | 7 +------ 2 files changed, 1 insertion(+), 16 deletions(-) diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java index 0da66e2c15..2035cfc887 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithBasicAuth.java @@ -16,18 +16,8 @@ package com.dremio.service.flight; -import java.sql.SQLException; - import org.apache.arrow.flight.CallOption; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.sql.FlightSqlClient; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.junit.Assert; import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; import com.dremio.service.flight.impl.FlightWorkManager; diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java index 228f187a1d..750487d61c 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/TestFlightSqlServerWithTokenAuth.java @@ -16,13 +16,8 @@ package com.dremio.service.flight; -import java.sql.SQLException; - import org.apache.arrow.flight.CallOption; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.sql.FlightSqlClient; import org.junit.BeforeClass; -import org.junit.Ignore; import com.dremio.service.flight.impl.FlightWorkManager; @@ -47,6 +42,6 @@ protected String getAuthMode() { @Override CallOption[] getCallOptions() { final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); - return new CallOption[] { wrapper.getTokenCallOption() }; + return new CallOption[] {wrapper.getTokenCallOption()}; } } From c679c3a758212fe7c4486fcd6561f173804857fa Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Mon, 16 Aug 2021 14:28:43 -0300 Subject: [PATCH 09/13] Add missing JavaDoc --- .../com/dremio/service/flight/impl/FlightWorkManager.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java index 856bebaf94..793e912f42 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java @@ -109,6 +109,13 @@ public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStat workerProvider.get().submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); } + /** + * Submits a GET_CATALOGS job to a worker and sends the response to given ServerStreamListener. + * + * @param listener ServerStreamListener listening to the job result. + * @param allocator BufferAllocator used to allocate the response VectorSchemaRoot. + * @param userSession The session for the user which made the request. + */ public void getCatalogs(FlightProducer.ServerStreamListener listener, BufferAllocator allocator, UserSession userSession) { final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId(); From 6644a61a365671162691ee8ee44d4c70674896e2 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 17 Aug 2021 13:57:55 -0300 Subject: [PATCH 10/13] Add missing JavaDoc --- .../dremio/service/flight/impl/GetCatalogsResponseHandler.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java index 0de8670b82..8b4bd0b458 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java @@ -30,6 +30,9 @@ import com.dremio.exec.work.protector.UserResponseHandler; import com.dremio.exec.work.protector.UserResult; +/** + * {@link UserResponseHandler} implementation for {@link FlightWorkManager#getCatalogs}. + */ class GetCatalogsResponseHandler implements UserResponseHandler { private final BufferAllocator allocator; private final FlightProducer.ServerStreamListener listener; From defb39a3b0a80e0c67ab17494a37dd42841b73bb Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 17 Aug 2021 16:11:04 -0300 Subject: [PATCH 11/13] Refactor CancellableUserResponseHandler and its previous subclasses --- .../service/flight/DremioFlightProducer.java | 2 +- ...reatePreparedStatementResponseHandler.java | 87 ------------ .../flight/impl/FlightPreparedStatement.java | 5 +- .../flight/impl/FlightWorkManager.java | 48 +++++-- .../impl/GetCatalogsResponseHandler.java | 72 ---------- .../CancellableUserResponseHandler.java | 50 ++++++- ...reatePreparedStatementResponseHandler.java | 124 ------------------ .../impl/TestFlightPreparedStatement.java | 5 +- .../TestCancellableUserResponseHandler.java | 60 +++------ 9 files changed, 115 insertions(+), 338 deletions(-) delete mode 100644 services/arrow-flight/src/main/java/com/dremio/service/flight/impl/CreatePreparedStatementResponseHandler.java delete mode 100644 services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java delete mode 100644 services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestCreatePreparedStatementResponseHandler.java diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java index 2bb2d5d2f7..4fb40b1a13 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/DremioFlightProducer.java @@ -326,7 +326,7 @@ public void getStreamCatalogs(CallContext callContext, Ticket ticket, final CallHeaders headers = retrieveHeadersFromCallContext(callContext); final UserSession session = sessionsManager.getUserSession(callContext.peerIdentity(), headers); - flightWorkManager.getCatalogs(serverStreamListener, allocator, session); + flightWorkManager.getCatalogs(serverStreamListener, allocator, callContext::isCancelled, session); } @Override diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/CreatePreparedStatementResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/CreatePreparedStatementResponseHandler.java deleted file mode 100644 index 515c5b002a..0000000000 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/CreatePreparedStatementResponseHandler.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (C) 2017-2019 Dremio Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.dremio.service.flight.impl; - -import java.util.function.Supplier; - -import javax.inject.Provider; - -import org.apache.arrow.flight.CallStatus; - -import com.dremio.common.utils.protos.QueryWritableBatch; -import com.dremio.exec.proto.GeneralRPCProtos; -import com.dremio.exec.proto.UserBitShared; -import com.dremio.exec.proto.UserProtos; -import com.dremio.exec.rpc.RpcOutcomeListener; -import com.dremio.exec.work.protector.UserResult; -import com.dremio.exec.work.protector.UserWorker; -import com.dremio.sabot.rpc.user.UserSession; -import com.dremio.service.flight.error.mapping.DremioFlightErrorMapper; -import com.dremio.service.flight.protector.CancellableUserResponseHandler; - -/** - * The UserResponseHandler that consumes a CreatePreparedStatementResponse. - */ -public class CreatePreparedStatementResponseHandler extends - CancellableUserResponseHandler { - - static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(CreatePreparedStatementResponseHandler.class); - - public CreatePreparedStatementResponseHandler(UserBitShared.ExternalId prepareExternalId, - UserSession userSession, - Provider workerProvider, - Supplier isRequestCancelled) { - super(prepareExternalId, userSession, workerProvider, isRequestCancelled); - } - - @Override - public void sendData(RpcOutcomeListener outcomeListener, QueryWritableBatch result) { - throw new UnsupportedOperationException("A response sender based implementation should send no data to end users."); - } - - @Override - public void completed(UserResult result) { - switch (result.getState()) { - case COMPLETED: - getCompletableFuture().complete(result.unwrap(UserProtos.CreatePreparedStatementArrowResp.class)); - break; - case FAILED: - getCompletableFuture().completeExceptionally( - DremioFlightErrorMapper.toFlightRuntimeException(result.getException())); - break; - case CANCELED: - final Exception canceledException = result.getException(); - getCompletableFuture().completeExceptionally( - CallStatus.CANCELLED - .withCause(canceledException) - .withDescription(canceledException.getMessage()) - .toRuntimeException()); - break; - - case STARTING: - case RUNNING: - case NO_LONGER_USED_1: - case ENQUEUED: - default: - getCompletableFuture().completeExceptionally( - CallStatus.INTERNAL - .withCause(new IllegalStateException()) - .withDescription("Internal Error: Invalid planning state.") - .toRuntimeException()); - break; - } - } -} diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java index f8dbc22e93..283cd0488b 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightPreparedStatement.java @@ -26,6 +26,7 @@ import com.dremio.exec.proto.UserProtos; import com.dremio.service.flight.TicketContent.PreparedStatementTicket; +import com.dremio.service.flight.protector.CancellableUserResponseHandler; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; @@ -37,10 +38,10 @@ public class FlightPreparedStatement { private final FlightDescriptor flightDescriptor; private final String query; - private final CreatePreparedStatementResponseHandler responseHandler; + private final CancellableUserResponseHandler responseHandler; public FlightPreparedStatement(FlightDescriptor flightDescriptor, String query, - CreatePreparedStatementResponseHandler responseHandler) { + CancellableUserResponseHandler responseHandler) { this.flightDescriptor = flightDescriptor; this.query = query; this.responseHandler = responseHandler; diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java index 793e912f42..c4b7608e92 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/FlightWorkManager.java @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.dremio.service.flight.impl; import java.nio.charset.StandardCharsets; @@ -23,7 +24,11 @@ import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.sql.FlightSqlProducer; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.util.Text; import com.dremio.common.utils.protos.ExternalIdHelper; import com.dremio.exec.proto.UserBitShared; @@ -35,9 +40,9 @@ import com.dremio.options.OptionManager; import com.dremio.sabot.rpc.user.UserSession; import com.dremio.service.flight.DremioFlightServiceOptions; -import com.dremio.service.flight.TicketContent; import com.dremio.service.flight.impl.RunQueryResponseHandler.BackpressureHandlingResponseHandler; import com.dremio.service.flight.impl.RunQueryResponseHandler.BasicResponseHandler; +import com.dremio.service.flight.protector.CancellableUserResponseHandler; import com.google.common.annotations.VisibleForTesting; /** @@ -67,7 +72,8 @@ public FlightWorkManager(Provider workerProvider, * @return A FlightPreparedStatement which consumes the result of the job. */ public FlightPreparedStatement createPreparedStatement(FlightDescriptor flightDescriptor, - Supplier isRequestCancelled, UserSession userSession) { + Supplier isRequestCancelled, + UserSession userSession) { final String query = getQuery(flightDescriptor); final UserProtos.CreatePreparedStatementArrowReq createPreparedStatementReq = @@ -79,8 +85,10 @@ public FlightPreparedStatement createPreparedStatement(FlightDescriptor flightDe final UserRequest userRequest = new UserRequest(UserProtos.RpcType.CREATE_PREPARED_STATEMENT_ARROW, createPreparedStatementReq); - final CreatePreparedStatementResponseHandler createPreparedStatementResponseHandler = - new CreatePreparedStatementResponseHandler(prepareExternalId, userSession, workerProvider, isRequestCancelled); + final CancellableUserResponseHandler + createPreparedStatementResponseHandler = + new CancellableUserResponseHandler<>(prepareExternalId, userSession, + workerProvider, isRequestCancelled, UserProtos.CreatePreparedStatementArrowResp.class); workerProvider.get().submitWork(prepareExternalId, userSession, createPreparedStatementResponseHandler, userRequest, TerminationListenerRegistry.NOOP); @@ -89,8 +97,8 @@ public FlightPreparedStatement createPreparedStatement(FlightDescriptor flightDe } public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStatementHandle, - FlightProducer.ServerStreamListener listener, BufferAllocator allocator, - UserSession userSession) { + FlightProducer.ServerStreamListener listener, BufferAllocator allocator, + UserSession userSession) { final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId(); final UserRequest userRequest = new UserRequest(UserProtos.RpcType.RUN_QUERY, @@ -106,7 +114,8 @@ public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStat final UserResponseHandler responseHandler = runQueryResponseHandlerFactory.getHandler(runExternalId, userSession, workerProvider, optionManagerProvider, listener, allocator); - workerProvider.get().submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); + workerProvider.get() + .submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); } /** @@ -117,15 +126,36 @@ public void runPreparedStatement(UserProtos.PreparedStatementHandle preparedStat * @param userSession The session for the user which made the request. */ public void getCatalogs(FlightProducer.ServerStreamListener listener, BufferAllocator allocator, - UserSession userSession) { + Supplier isRequestCancelled, UserSession userSession) { final UserBitShared.ExternalId runExternalId = ExternalIdHelper.generateExternalId(); final UserRequest userRequest = new UserRequest(UserProtos.RpcType.GET_CATALOGS, UserProtos.GetCatalogsReq.newBuilder().build()); - final UserResponseHandler responseHandler = new GetCatalogsResponseHandler(allocator, listener); + final CancellableUserResponseHandler responseHandler = + new CancellableUserResponseHandler<>(runExternalId, userSession, workerProvider, isRequestCancelled, + UserProtos.GetCatalogsResp.class); workerProvider.get() .submitWork(runExternalId, userSession, responseHandler, userRequest, TerminationListenerRegistry.NOOP); + + UserProtos.GetCatalogsResp response = responseHandler.get(); + try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, + allocator)) { + listener.start(vectorSchemaRoot); + + vectorSchemaRoot.allocateNew(); + VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name"); + + int i = 0; + for (UserProtos.CatalogMetadata catalogMetadata : response.getCatalogsList()) { + catalogNameVector.setSafe(i, new Text(catalogMetadata.getCatalogName())); + i++; + } + + vectorSchemaRoot.setRowCount(response.getCatalogsCount()); + listener.putNext(); + listener.completed(); + } } @VisibleForTesting diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java deleted file mode 100644 index 8b4bd0b458..0000000000 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/impl/GetCatalogsResponseHandler.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (C) 2017-2019 Dremio Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.dremio.service.flight.impl; - -import org.apache.arrow.flight.FlightProducer; -import org.apache.arrow.flight.sql.FlightSqlProducer; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.util.Text; - -import com.dremio.common.utils.protos.QueryWritableBatch; -import com.dremio.exec.proto.GeneralRPCProtos; -import com.dremio.exec.proto.UserProtos; -import com.dremio.exec.rpc.RpcOutcomeListener; -import com.dremio.exec.work.protector.UserResponseHandler; -import com.dremio.exec.work.protector.UserResult; - -/** - * {@link UserResponseHandler} implementation for {@link FlightWorkManager#getCatalogs}. - */ -class GetCatalogsResponseHandler implements UserResponseHandler { - private final BufferAllocator allocator; - private final FlightProducer.ServerStreamListener listener; - - public GetCatalogsResponseHandler(BufferAllocator allocator, FlightProducer.ServerStreamListener listener) { - this.allocator = allocator; - this.listener = listener; - } - - @Override - public void sendData(RpcOutcomeListener outcomeListener, - QueryWritableBatch result) { - } - - @Override - public void completed(UserResult result) { - UserProtos.GetCatalogsResp catalogsResp = result.unwrap(UserProtos.GetCatalogsResp.class); - - try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA, - allocator)) { - listener.start(vectorSchemaRoot); - - vectorSchemaRoot.allocateNew(); - VarCharVector catalogNameVector = (VarCharVector) vectorSchemaRoot.getVector("catalog_name"); - - int i = 0; - for (UserProtos.CatalogMetadata catalogMetadata : catalogsResp.getCatalogsList()) { - catalogNameVector.setSafe(i, new Text(catalogMetadata.getCatalogName())); - i++; - } - - vectorSchemaRoot.setRowCount(catalogsResp.getCatalogsCount()); - listener.putNext(); - listener.completed(); - } - } -} diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java index 91d3eb91d6..a3fdacda6b 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java @@ -27,8 +27,12 @@ import org.apache.arrow.flight.FlightRuntimeException; import com.dremio.common.exceptions.UserException; +import com.dremio.common.utils.protos.QueryWritableBatch; +import com.dremio.exec.proto.GeneralRPCProtos; import com.dremio.exec.proto.UserBitShared; +import com.dremio.exec.rpc.RpcOutcomeListener; import com.dremio.exec.work.protector.UserResponseHandler; +import com.dremio.exec.work.protector.UserResult; import com.dremio.exec.work.protector.UserWorker; import com.dremio.sabot.rpc.user.UserSession; import com.dremio.service.flight.error.mapping.DremioFlightErrorMapper; @@ -39,21 +43,62 @@ * * @param The response type. */ -public abstract class CancellableUserResponseHandler implements UserResponseHandler { +public class CancellableUserResponseHandler implements UserResponseHandler { private final CompletableFuture future = new CompletableFuture<>(); private final Supplier isRequestCancelled; private final UserBitShared.ExternalId externalId; private final UserSession userSession; private final Provider workerProvider; + private final Class responseType; public CancellableUserResponseHandler(UserBitShared.ExternalId externalId, UserSession userSession, Provider workerProvider, - Supplier isRequestCancelled) { + Supplier isRequestCancelled, + Class responseType) { this.externalId = externalId; this.userSession = userSession; this.workerProvider = workerProvider; this.isRequestCancelled = isRequestCancelled; + this.responseType = responseType; + } + + @Override + public final void sendData(RpcOutcomeListener outcomeListener, QueryWritableBatch result) { + throw new UnsupportedOperationException("A response sender based implementation should send no data to end users."); + } + + @Override + public final void completed(UserResult result) { + switch (result.getState()) { + case COMPLETED: + getCompletableFuture().complete(result.unwrap(responseType)); + break; + case FAILED: + getCompletableFuture().completeExceptionally( + DremioFlightErrorMapper.toFlightRuntimeException(result.getException())); + break; + case CANCELED: + final Exception canceledException = result.getException(); + getCompletableFuture().completeExceptionally( + CallStatus.CANCELLED + .withCause(canceledException) + .withDescription(canceledException.getMessage()) + .toRuntimeException()); + break; + + case STARTING: + case RUNNING: + case NO_LONGER_USED_1: + case ENQUEUED: + default: + getCompletableFuture().completeExceptionally( + CallStatus.INTERNAL + .withCause(new IllegalStateException()) + .withDescription("Internal Error: Invalid planning state.") + .toRuntimeException()); + break; + } } public T get() { @@ -104,4 +149,5 @@ public void cancelJob() { protected CompletableFuture getCompletableFuture() { return future; } + } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestCreatePreparedStatementResponseHandler.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestCreatePreparedStatementResponseHandler.java deleted file mode 100644 index 49ae2a6489..0000000000 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestCreatePreparedStatementResponseHandler.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright (C) 2017-2019 Dremio Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.dremio.service.flight.impl; - -import static com.dremio.exec.proto.UserBitShared.ExternalId; -import static com.dremio.exec.proto.UserBitShared.QueryId; -import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.mock; - -import javax.inject.Provider; - -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.FlightRuntimeException; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - -import com.dremio.common.exceptions.UserException; -import com.dremio.common.utils.protos.ExternalIdHelper; -import com.dremio.exec.proto.UserBitShared.QueryProfile; -import com.dremio.exec.proto.UserBitShared.QueryResult.QueryState; -import com.dremio.exec.proto.UserProtos.CreatePreparedStatementArrowResp; -import com.dremio.exec.work.protector.UserResult; -import com.dremio.exec.work.protector.UserWorker; -import com.dremio.sabot.rpc.user.UserSession; - -/** - * Tests for CreatePreparedStatementResponseHandler. - */ -public class TestCreatePreparedStatementResponseHandler { - - private final QueryId queryId = QueryId.getDefaultInstance(); - private final QueryProfile resultProfile = QueryProfile.getDefaultInstance(); - private final ExternalId externalId = ExternalIdHelper.generateExternalId(); - private final UserSession userSession = mock(UserSession.class); - private final Provider mockedUserWorkerProvider = mock(Provider.class); - - @Rule - public ExpectedException thrown = ExpectedException.none(); - - @Test - public void testSuccessfulResultPropagation() { - // Arrange - final CreatePreparedStatementArrowResp expected = CreatePreparedStatementArrowResp.newBuilder().build(); - - final CreatePreparedStatementResponseHandler cancellableUserResponseHandler = - new CreatePreparedStatementResponseHandler(externalId, userSession, mockedUserWorkerProvider, () -> false); - - final UserResult result = new UserResult(expected, queryId, - QueryState.COMPLETED, - resultProfile, - null, // UserException - null, // cancelReason - false // clientCancelled - ); - - // Act - cancellableUserResponseHandler.completed(result); - - // Assert - CreatePreparedStatementArrowResp actual = cancellableUserResponseHandler.get(); - assertEquals(expected, actual); - } - - @Test - public void testCancelledExceptionResultPropagation() { - testExceptionResultPropagation(CallStatus.CANCELLED, QueryState.CANCELED); - } - - @Test - public void testFailedExceptionResultPropagation() { - testExceptionResultPropagation(CallStatus.INTERNAL, QueryState.FAILED); - } - - public void testExceptionResultPropagation(CallStatus callStatus, QueryState queryState) { - // Arrange - final CreatePreparedStatementResponseHandler cancellableUserResponseHandler = - new CreatePreparedStatementResponseHandler(externalId, userSession, mockedUserWorkerProvider, () -> false); - - final Exception original = new TestException("Dummy Exception"); - final UserException cause = UserException.parseError(original).buildSilently(); - - final Throwable expected = callStatus - .withCause(original) - .withDescription(original.getLocalizedMessage()) - .toRuntimeException(); - - thrown.expectMessage(expected.getLocalizedMessage()); - thrown.expect(FlightRuntimeException.class); - - final UserResult result = new UserResult( - expected, - queryId, - queryState, - resultProfile, - cause, // UserException - null, // cancelReason - false // clientCancelled - ); - cancellableUserResponseHandler.completed(result); - - // Act - cancellableUserResponseHandler.get(); - } - - private static class TestException extends Exception { - public TestException(String message) { - super(message); - } - } -} diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestFlightPreparedStatement.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestFlightPreparedStatement.java index 20c3b3d94d..1a09fd7d6d 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestFlightPreparedStatement.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/impl/TestFlightPreparedStatement.java @@ -40,6 +40,7 @@ import com.dremio.exec.proto.UserProtos; import com.dremio.exec.proto.UserProtos.PreparedStatementHandle; import com.dremio.service.flight.TicketContent; +import com.dremio.service.flight.protector.CancellableUserResponseHandler; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; @@ -81,7 +82,7 @@ public class TestFlightPreparedStatement { .build()) .build(); - private static CreatePreparedStatementResponseHandler mockHandler; + private static CancellableUserResponseHandler mockHandler; private static Location mockLocation; @Rule @@ -89,7 +90,7 @@ public class TestFlightPreparedStatement { @Before public void setup() { - mockHandler = mock(CreatePreparedStatementResponseHandler.class); + mockHandler = mock(CancellableUserResponseHandler.class); mockLocation = mock(Location.class); } diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java index 5937948a6d..a4d4e361f8 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package com.dremio.service.flight.protector; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -35,11 +35,9 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +import com.dremio.common.exceptions.UserException; import com.dremio.common.utils.protos.ExternalIdHelper; -import com.dremio.common.utils.protos.QueryWritableBatch; -import com.dremio.exec.proto.GeneralRPCProtos; import com.dremio.exec.proto.UserBitShared; -import com.dremio.exec.rpc.RpcOutcomeListener; import com.dremio.exec.work.protector.UserResult; import com.dremio.exec.work.protector.UserWorker; import com.dremio.sabot.rpc.user.UserSession; @@ -55,6 +53,8 @@ public class TestCancellableUserResponseHandler { private final UserSession userSession = mock(UserSession.class); private final UserWorker userWorker = mock(UserWorker.class); private final Provider mockedUserWorkerProvider = mock(Provider.class); + private final UserBitShared.QueryId queryId = UserBitShared.QueryId.getDefaultInstance(); + private final UserBitShared.QueryProfile resultProfile = UserBitShared.QueryProfile.getDefaultInstance(); @Rule public ExpectedException thrown = ExpectedException.none(); @@ -69,20 +69,14 @@ public void setup() { public void testSuccessfulResultPropagation() { final BigDecimal expected = new BigDecimal(1); final CancellableUserResponseHandler cancellableUserResponseHandler = - new CancellableUserResponseHandler(externalId, userSession, mockedUserWorkerProvider, () -> false) { - @Override - public void sendData(RpcOutcomeListener outcomeListener, QueryWritableBatch result) { - fail(); - } - - @Override - public void completed(UserResult result) { - getCompletableFuture().complete(expected); - } - }; + new CancellableUserResponseHandler<>(externalId, userSession, mockedUserWorkerProvider, () -> false, + BigDecimal.class); // Act - cancellableUserResponseHandler.completed(null); + UserResult userResult = + new UserResult(expected, queryId, UserBitShared.QueryResult.QueryState.COMPLETED, resultProfile, null, null, + false); + cancellableUserResponseHandler.completed(userResult); // Assert final BigDecimal actual = cancellableUserResponseHandler.get(); @@ -104,18 +98,15 @@ public void testExceptionalPropagation() { thrown.expect(FlightRuntimeException.class); final CancellableUserResponseHandler cancellableUserResponseHandler = - new CancellableUserResponseHandler(externalId, userSession, mockedUserWorkerProvider, () -> false) { - @Override - public void sendData(RpcOutcomeListener outcomeListener, QueryWritableBatch result) { - fail(); - } - - @Override - public void completed(UserResult result) { - getCompletableFuture().completeExceptionally(thrownRootException); - } - }; - cancellableUserResponseHandler.completed(null); + new CancellableUserResponseHandler<>(externalId, userSession, mockedUserWorkerProvider, () -> false, + BigDecimal.class); + + UserException userException = UserException.parseError(expected).buildSilently(); + UserResult userResult = + new UserResult(null, queryId, UserBitShared.QueryResult.QueryState.FAILED, resultProfile, userException, null, + false); + cancellableUserResponseHandler.completed( + userResult); // Act cancellableUserResponseHandler.get(); @@ -141,17 +132,8 @@ private void testClientCancelCaughtAndPropagatedToServer(Supplier isCan thrown.expect(FlightRuntimeException.class); final CancellableUserResponseHandler cancellableUserResponseHandler = - new CancellableUserResponseHandler(externalId, userSession, mockedUserWorkerProvider, isCancelled) { - @Override - public void sendData(RpcOutcomeListener outcomeListener, QueryWritableBatch result) { - fail(); - } - - @Override - public void completed(UserResult result) { - fail(); - } - }; + new CancellableUserResponseHandler<>(externalId, userSession, mockedUserWorkerProvider, isCancelled, + BigDecimal.class); try { // Act From d748158c0bd7f01329acbdc21310e3c48a8e34f8 Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 17 Aug 2021 16:45:47 -0300 Subject: [PATCH 12/13] Add missing JavaDocs --- .../com/dremio/service/flight/AbstractTestFlightServer.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java index 9bbabce5f3..7e3d1080b0 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/AbstractTestFlightServer.java @@ -236,8 +236,14 @@ private List executeQueryWithStringResults(String query) throws Exceptio } } + /** + * Return an array of {@link CallOption} used in all calls to Flight Server (getFlightInfo, getStream, etc.). + */ abstract CallOption[] getCallOptions(); + /** + * Returns a FlightInfo for executing given query. + */ public FlightInfo getFlightInfo(String query) throws SQLException { final FlightClientUtils.FlightClientWrapper wrapper = getFlightClientWrapper(); From cfac6b11631c022f06f681aa7fa024cc662f573c Mon Sep 17 00:00:00 2001 From: Rafael Telles Date: Tue, 17 Aug 2021 17:34:47 -0300 Subject: [PATCH 13/13] Update TestCancellableUserResponseHandler --- .../protector/CancellableUserResponseHandler.java | 2 +- .../protector/TestCancellableUserResponseHandler.java | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java b/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java index a3fdacda6b..e592c351e9 100644 --- a/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java +++ b/services/arrow-flight/src/main/java/com/dremio/service/flight/protector/CancellableUserResponseHandler.java @@ -69,7 +69,7 @@ public final void sendData(RpcOutcomeListener outcomeListe } @Override - public final void completed(UserResult result) { + public void completed(UserResult result) { switch (result.getState()) { case COMPLETED: getCompletableFuture().complete(result.unwrap(responseType)); diff --git a/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java b/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java index a4d4e361f8..7b6c313861 100644 --- a/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java +++ b/services/arrow-flight/src/test/java/com/dremio/service/flight/protector/TestCancellableUserResponseHandler.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -132,8 +133,13 @@ private void testClientCancelCaughtAndPropagatedToServer(Supplier isCan thrown.expect(FlightRuntimeException.class); final CancellableUserResponseHandler cancellableUserResponseHandler = - new CancellableUserResponseHandler<>(externalId, userSession, mockedUserWorkerProvider, isCancelled, - BigDecimal.class); + new CancellableUserResponseHandler(externalId, userSession, mockedUserWorkerProvider, isCancelled, + BigDecimal.class) { + @Override + public void completed(UserResult result) { + fail(); + } + }; try { // Act