From eb981ab5bc9b579b2ed29b60be3102cce4f15639 Mon Sep 17 00:00:00 2001 From: Vanathi Ganeshraj Date: Wed, 24 Apr 2024 14:53:55 +0530 Subject: [PATCH] Add schema handling to remote task execution --- .../wrangler/api/RemoteDirectiveResponse.java | 42 +++++++ .../io/cdap/wrangler/api/TransientStore.java | 5 + .../aggregates/DefaultTransientStore.java | 11 ++ .../DefaultTransientStoreTypeAdapter.java | 110 ++++++++++++++++++ .../DefaultTransientStoreTypeAdapterTest.java | 85 ++++++++++++++ .../aggregates/SetTransientVariableTest.java | 10 ++ .../cdap/wrangler/utils/ObjectSerDeTest.java | 18 +++ .../directive/RemoteDirectiveRequest.java | 9 +- .../directive/RemoteExecutionTask.java | 29 ++++- .../service/directive/WorkspaceHandler.java | 13 ++- 10 files changed, 326 insertions(+), 6 deletions(-) create mode 100644 wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java create mode 100644 wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapter.java create mode 100644 wrangler-core/src/test/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapterTest.java diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java new file mode 100644 index 000000000..57670ac4c --- /dev/null +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/RemoteDirectiveResponse.java @@ -0,0 +1,42 @@ +/* + * Copyright © 2024 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.cdap.wrangler.api; + +import io.cdap.cdap.api.data.schema.Schema; + +import java.io.Serializable; +import java.util.List; + +/** + * Response after executing directives remotely + */ +public class RemoteDirectiveResponse implements Serializable { + private final List rows; + private final Schema outputSchema; + + public RemoteDirectiveResponse(List rows, Schema outputSchema) { + this.rows = rows; + this.outputSchema = outputSchema; + } + + public List getRows() { + return rows; + } + + public Schema getOutputSchema() { + return outputSchema; + } +} diff --git a/wrangler-api/src/main/java/io/cdap/wrangler/api/TransientStore.java b/wrangler-api/src/main/java/io/cdap/wrangler/api/TransientStore.java index bf3fb8171..accb71cda 100644 --- a/wrangler-api/src/main/java/io/cdap/wrangler/api/TransientStore.java +++ b/wrangler-api/src/main/java/io/cdap/wrangler/api/TransientStore.java @@ -17,6 +17,7 @@ package io.cdap.wrangler.api; import java.io.Serializable; +import java.util.Map; import java.util.Set; /** @@ -61,4 +62,8 @@ public interface TransientStore extends Serializable { * @return list of all the variables. */ Set getVariables(); + + Map getGlobalVariables(); + + Map getLocalVariables(); } diff --git a/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStore.java b/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStore.java index c515c868d..bd199afc9 100644 --- a/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStore.java +++ b/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStore.java @@ -16,6 +16,7 @@ package io.cdap.directives.aggregates; +import com.google.common.collect.ImmutableMap; import io.cdap.wrangler.api.TransientStore; import io.cdap.wrangler.api.TransientVariableScope; @@ -114,4 +115,14 @@ public void set(TransientVariableScope scope, String name, Object value) { local.put(name, value); } } + + @Override + public Map getGlobalVariables() { + return ImmutableMap.copyOf(global); + } + + @Override + public Map getLocalVariables() { + return ImmutableMap.copyOf(local); + } } diff --git a/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapter.java b/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapter.java new file mode 100644 index 000000000..32e741c74 --- /dev/null +++ b/wrangler-core/src/main/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapter.java @@ -0,0 +1,110 @@ +/* + * Copyright © 2021 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.directives.aggregates; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.TypeAdapter; +import com.google.gson.stream.JsonReader; +import com.google.gson.stream.JsonWriter; +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.cdap.internal.io.SchemaTypeAdapter; +import io.cdap.wrangler.api.TransientVariableScope; + +import java.io.IOException; +import java.util.Map; + +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + +/** + * GSON - JSON type adapter class for {@link DefaultTransientStore} + */ +public class DefaultTransientStoreTypeAdapter extends TypeAdapter { + private static final Gson GSON = new GsonBuilder() + .registerTypeAdapter(Schema.class, new SchemaTypeAdapter()) + .create(); + private static final String GLOBAL_MAP_KEY = "global"; + private static final String LOCAL_MAP_KEY = "local"; + + @Override + public void write(JsonWriter writer, DefaultTransientStore store) throws IOException { + if (store == null) { + writer.nullValue(); + return; + } + + writer.beginObject(); + writeMaps(writer, store); + writer.endObject(); + } + + @Override + public DefaultTransientStore read(JsonReader reader) throws IOException { + DefaultTransientStore store = new DefaultTransientStore(); + reader.beginObject(); + while (reader.hasNext()) { + String name = reader.nextName(); + if (name.equals(GLOBAL_MAP_KEY)) { + readMap(reader, store, TransientVariableScope.GLOBAL); + } else if (name.equals(LOCAL_MAP_KEY)) { + readMap(reader, store, TransientVariableScope.LOCAL); + } + } + reader.endObject(); + return store; + } + + private void writeMaps(JsonWriter writer, DefaultTransientStore store) throws IOException { + writer.name(GLOBAL_MAP_KEY).beginObject(); + writeMap(writer, store.getGlobalVariables()); + writer.endObject(); + + + writer.name(LOCAL_MAP_KEY).beginObject(); + writeMap(writer, store.getLocalVariables()); + writer.endObject(); + } + + private void writeMap(JsonWriter writer, Map map) throws IOException { + for (Map.Entry entry : map.entrySet()) { + String key = entry.getKey(); + writer.name(key); + if (key.equals(INPUT_SCHEMA) || key.equals(OUTPUT_SCHEMA)) { + GSON.toJson(entry.getValue(), Schema.class, writer); + } else { + GSON.toJson(entry.getValue(), Object.class, writer); + } + } + } + + private void readMap(JsonReader reader, DefaultTransientStore store, + TransientVariableScope scope) throws IOException { + reader.beginObject(); + while (reader.hasNext()) { + String key = reader.nextName(); + if (key.equals(INPUT_SCHEMA) || key.equals(OUTPUT_SCHEMA)) { + Schema schemaValue = GSON.fromJson(reader, Schema.class); + store.set(scope, key, schemaValue); + } else { + Object value = GSON.fromJson(reader, Object.class); + store.set(scope, key, value); + } + } + reader.endObject(); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapterTest.java b/wrangler-core/src/test/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapterTest.java new file mode 100644 index 000000000..e6bb14d76 --- /dev/null +++ b/wrangler-core/src/test/java/io/cdap/directives/aggregates/DefaultTransientStoreTypeAdapterTest.java @@ -0,0 +1,85 @@ +/* + * Copyright © 2021 Cask Data, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package io.cdap.directives.aggregates; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.cdap.internal.io.SchemaTypeAdapter; +import io.cdap.wrangler.api.TransientStore; +import io.cdap.wrangler.api.TransientVariableScope; +import org.junit.BeforeClass; +import org.junit.Test; + +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static org.junit.Assert.assertEquals; + +public class DefaultTransientStoreTypeAdapterTest { + private static final TransientStore EXPECTED_STORE = new DefaultTransientStore(); + private static final String EXPECTED_JSON = "{" + + "\"global\":{" + + "\"num_val\":5.0" + + "}," + + "\"local\":{" + + "\"string_val\":\"hello\"" + + "}" + + "}"; + private static final Gson GSON = new GsonBuilder() + .registerTypeAdapter(TransientStore.class, new DefaultTransientStoreTypeAdapter()) + .registerTypeAdapter(Schema.class, new SchemaTypeAdapter()) + .create(); + + @BeforeClass + public static void setup() { + EXPECTED_STORE.set(TransientVariableScope.GLOBAL, "num_val", 5.0); + EXPECTED_STORE.set(TransientVariableScope.LOCAL, "string_val", "hello"); + } + + @Test + public void testWrite() { + TransientStore store = new DefaultTransientStore(); + store.set(TransientVariableScope.GLOBAL, "num_val", 5.0); + store.set(TransientVariableScope.LOCAL, "string_val", "hello"); + + String actualJson = GSON.toJson(store); + + assertEquals(EXPECTED_JSON, actualJson); + } + + @Test + public void testRead() { + TransientStore actualStore = GSON.fromJson(EXPECTED_JSON, TransientStore.class); + + assertEquals(EXPECTED_STORE.getGlobalVariables(), actualStore.getGlobalVariables()); + assertEquals(EXPECTED_STORE.getLocalVariables(), actualStore.getLocalVariables()); + } + + @Test + public void testSchema() { + TransientStore store = new DefaultTransientStore(); + Schema inputSchema = Schema.recordOf( + "inputSchema", + Schema.Field.of("int_col", Schema.of(Schema.Type.INT)) + ); + store.set(TransientVariableScope.LOCAL, INPUT_SCHEMA, inputSchema); + + String actualJson = GSON.toJson(store); + TransientStore actualStore = GSON.fromJson(actualJson, TransientStore.class); + + assertEquals(inputSchema, actualStore.get(INPUT_SCHEMA)); + } +} diff --git a/wrangler-core/src/test/java/io/cdap/directives/aggregates/SetTransientVariableTest.java b/wrangler-core/src/test/java/io/cdap/directives/aggregates/SetTransientVariableTest.java index bd83225b8..0833a29c1 100644 --- a/wrangler-core/src/test/java/io/cdap/directives/aggregates/SetTransientVariableTest.java +++ b/wrangler-core/src/test/java/io/cdap/directives/aggregates/SetTransientVariableTest.java @@ -108,6 +108,16 @@ public void increment(TransientVariableScope scope, String name, long value) { public Set getVariables() { return s.keySet(); } + + @Override + public Map getGlobalVariables() { + return null; + } + + @Override + public Map getLocalVariables() { + return null; + } }; } diff --git a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java index 689668444..f297b23e8 100644 --- a/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java +++ b/wrangler-core/src/test/java/io/cdap/wrangler/utils/ObjectSerDeTest.java @@ -17,6 +17,8 @@ package io.cdap.wrangler.utils; import com.google.common.base.Charsets; +import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import org.junit.Assert; import org.junit.Test; @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception { actualRows = objectSerDe.toObject(bytes); Assert.assertEquals(expectedRows.size(), actualRows.size()); } + @Test + public void testRemoteDirectiveResponseSerDe() throws Exception { + List expectedRows = new ArrayList<>(); + Row firstRow = new Row(); + firstRow.add("id", 1); + expectedRows.add(firstRow); + Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT))); + RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema); + ObjectSerDe objectSerDe = new ObjectSerDe<>(); + + byte[] bytes = objectSerDe.toByteArray(expectedResponse); + RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes); + + Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size()); + Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema()); + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java index 9b77f23f3..d3e6d959f 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteDirectiveRequest.java @@ -15,6 +15,7 @@ */ package io.cdap.wrangler.service.directive; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.wrangler.parser.DirectiveClass; import java.util.HashMap; @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest { private final Map systemDirectives; private final String pluginNameSpace; private final byte[] data; + private final Schema inputSchema; RemoteDirectiveRequest(String recipe, Map systemDirectives, - String pluginNameSpace, byte[] data) { + String pluginNameSpace, byte[] data, Schema inputSchema) { this.recipe = recipe; this.systemDirectives = new HashMap<>(systemDirectives); this.pluginNameSpace = pluginNameSpace; this.data = data; + this.inputSchema = inputSchema; } public String getRecipe() { @@ -53,4 +56,8 @@ public byte[] getData() { public String getPluginNameSpace() { return pluginNameSpace; } + + public Schema getInputSchema() { + return inputSchema; + } } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java index 27216247f..674c82649 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/RemoteExecutionTask.java @@ -16,10 +16,14 @@ package io.cdap.wrangler.service.directive; import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.cdap.cdap.api.data.schema.Schema; import io.cdap.cdap.api.service.worker.RunnableTask; import io.cdap.cdap.api.service.worker.RunnableTaskContext; import io.cdap.cdap.api.service.worker.SystemAppTaskContext; +import io.cdap.cdap.internal.io.SchemaTypeAdapter; import io.cdap.directives.aggregates.DefaultTransientStore; +import io.cdap.directives.aggregates.DefaultTransientStoreTypeAdapter; import io.cdap.wrangler.api.Arguments; import io.cdap.wrangler.api.CompileException; import io.cdap.wrangler.api.Directive; @@ -29,7 +33,10 @@ import io.cdap.wrangler.api.ErrorRecordBase; import io.cdap.wrangler.api.ExecutorContext; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; +import io.cdap.wrangler.api.TransientStore; +import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.api.parser.UsageDefinition; import io.cdap.wrangler.executor.RecipePipelineExecutor; import io.cdap.wrangler.expression.EL; @@ -50,12 +57,18 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * Task for remote execution of directives */ public class RemoteExecutionTask implements RunnableTask { - private static final Gson GSON = new Gson(); + private static final Gson GSON = new GsonBuilder() + .registerTypeAdapter(TransientStore.class, new DefaultTransientStoreTypeAdapter()) + .registerTypeAdapter(Schema.class, new SchemaTypeAdapter()) + .create(); @Override public void run(RunnableTaskContext runnableTaskContext) throws Exception { @@ -102,12 +115,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { ObjectSerDe> objectSerDe = new ObjectSerDe<>(); List rows = objectSerDe.toObject(directiveRequest.getData()); + Schema inputSchema = directiveRequest.getInputSchema(); + TransientStore transientStore = new DefaultTransientStore(); + if (inputSchema != null) { + transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema); + } + try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives, new ServicePipelineContext( namespace, ExecutorContext.Environment.SERVICE, systemAppContext, - new DefaultTransientStore()))) { + transientStore))) { rows = executor.execute(rows); List errors = executor.errors().stream() .filter(ErrorRecordBase::isShownInWrangler) @@ -120,8 +139,12 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception { throw new BadRequestException(e.getMessage(), e); } + Schema outputSchema = transientStore.get(OUTPUT_SCHEMA); + RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema); + ObjectSerDe responseSerDe = new ObjectSerDe<>(); + runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed()); - runnableTaskContext.writeResult(objectSerDe.toByteArray(rows)); + runnableTaskContext.writeResult(responseSerDe.toByteArray(response)); } catch (DirectiveParseException | ClassNotFoundException | CompileException e) { throw new BadRequestException(e.getMessage(), e); } diff --git a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java index 82b45521e..a6165549f 100644 --- a/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java +++ b/wrangler-service/src/main/java/io/cdap/wrangler/service/directive/WorkspaceHandler.java @@ -48,6 +48,7 @@ import io.cdap.wrangler.api.DirectiveParseException; import io.cdap.wrangler.api.GrammarMigrator; import io.cdap.wrangler.api.RecipeException; +import io.cdap.wrangler.api.RemoteDirectiveResponse; import io.cdap.wrangler.api.Row; import io.cdap.wrangler.api.TransientVariableScope; import io.cdap.wrangler.parser.ConfigDirectiveContext; @@ -101,6 +102,9 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; +import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA; +import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA; + /** * V2 endpoints for workspace */ @@ -618,13 +622,18 @@ private List executeRemotely(String namespace, List>().toObject(bytes); + RemoteDirectiveResponse response = new ObjectSerDe().toObject(bytes); + if (response.getOutputSchema() != null) { + TRANSIENT_STORE.set(TransientVariableScope.GLOBAL, OUTPUT_SCHEMA, response.getOutputSchema()); + } + return response.getRows(); } private List getSample(SampleResponse sampleResponse) {