Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[🍒] Add schema handling to remote task execution #712

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.wrangler.api;

import io.cdap.cdap.api.data.schema.Schema;

import java.io.Serializable;
import java.util.List;

/**
* Response after executing directives remotely
* Please make sure all fields are registered with {@link io.cdap.wrangler.utils.KryoSerializer}
*/
public class RemoteDirectiveResponse implements Serializable {
private final List<Row> rows;
private final Schema outputSchema;

/**
* Only used by {@link io.cdap.wrangler.utils.KryoSerializer}
**/
private RemoteDirectiveResponse() {
this(null, null);
}

public RemoteDirectiveResponse(List<Row> rows, Schema outputSchema) {
this.rows = rows;
this.outputSchema = outputSchema;
}

public List<Row> getRows() {
return rows;
}

public Schema getOutputSchema() {
return outputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

/**
* TransientStoreKeys for storing Workspace schema in TransientStore
* NOTE: Please add any needed value in {@link io.cdap.wrangler.api.RemoteDirectiveResponse}
*/
public final class TransientStoreKeys {
public static final String INPUT_SCHEMA = "ws_input_schema";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.esotericsoftware.kryo.serializers.JavaSerializer;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import java.sql.Time;
import java.sql.Timestamp;
Expand All @@ -33,20 +36,24 @@
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;

/**
* A helper class with allows Serialization and Deserialization using Kryo
* We should register all schema classes present in {@link SchemaConverter}
* and {@link RemoteDirectiveResponse}
**/
public class RowSerializer {
public class KryoSerializer {

private final Kryo kryo;
private static final Gson GSON = new Gson();

public RowSerializer() {
public KryoSerializer() {
kryo = new Kryo();
// Register all classes from RemoteDirectiveResponse
kryo.register(RemoteDirectiveResponse.class);
// Schema does not have no-arg constructor but implements Serializable
kryo.register(Schema.class, new JavaSerializer());
// Register all classes from SchemaConverter
kryo.register(Row.class);
kryo.register(ArrayList.class);
Expand All @@ -56,7 +63,7 @@ public RowSerializer() {
kryo.register(Map.class);
kryo.register(JsonNull.class);
// JsonPrimitive does not have no-arg constructor hence we need a
// custom serializer
// custom serializer as it is not serializable by JavaSerializer
kryo.register(JsonPrimitive.class, new JsonSerializer());
kryo.register(JsonArray.class);
kryo.register(JsonObject.class);
Expand All @@ -67,16 +74,15 @@ public RowSerializer() {
kryo.register(Timestamp.class);
}

public byte[] fromRows(List<Row> rows) {
public byte[] fromRemoteDirectiveResponse(RemoteDirectiveResponse response) {
Output output = new Output(1024, -1);
kryo.writeClassAndObject(output, rows);
kryo.writeClassAndObject(output, response);
return output.getBuffer();
}

public List<Row> toRows(byte[] bytes) {
public RemoteDirectiveResponse toRemoteDirectiveResponse(byte[] bytes) {
Input input = new Input(bytes);
List<Row> result = (List<Row>) kryo.readClassAndObject(input);
return result;
return (RemoteDirectiveResponse) kryo.readClassAndObject(input);
}

static class JsonSerializer extends Serializer<JsonElement> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.data.schema.Schema.Field;
Expand Down Expand Up @@ -100,7 +99,7 @@ public Schema getSchema(Object value, String name) throws RecordConvertorExcepti
* @param name name of the field
* @param recordPrefix prefix to append at the beginning of a custom record
* @return the schema of this object
* NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link RowSerializer}
* NOTE: ANY NEWLY SUPPORTED DATATYPE SHOULD ALSO BE REGISTERED IN {@link KryoSerializer}
*/
@Nullable
public Schema getSchema(Object value, String name, @Nullable String recordPrefix) throws RecordConvertorException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

import com.google.common.collect.Lists;
import com.google.gson.JsonParser;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.TestingRig;
import io.cdap.wrangler.api.RecipePipeline;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -37,7 +39,7 @@
import java.util.Map;
import java.util.Set;

public class RowSerializerTest {
public class KryoSerializerTest {

private static final String[] TESTS = new String[]{
JsonTestData.BASIC,
Expand Down Expand Up @@ -67,8 +69,9 @@ public void testJsonTypes() throws Exception {
Row row = new Row("body", test);

List<Row> expectedRows = executor.execute(Lists.newArrayList(row));
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}
}
Expand All @@ -84,8 +87,9 @@ public void testLogicalTypes() throws Exception {
testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5));
testRow.add("datetime", LocalDateTime.now());
List<Row> expectedRows = Collections.singletonList(testRow);
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}

Expand All @@ -110,8 +114,33 @@ public void testCollectionTypes() throws Exception {
testRow.add("map", map);

List<Row> expectedRows = Collections.singletonList(testRow);
byte[] serializedRows = new RowSerializer().fromRows(expectedRows);
List<Row> gotRows = new RowSerializer().toRows(serializedRows);
byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, null));
List<Row> gotRows = new KryoSerializer().toRemoteDirectiveResponse(serializedRows).getRows();
Assert.assertArrayEquals(expectedRows.toArray(), gotRows.toArray());
}

@Test
public void testWithSchema() throws Exception {
Row testRow = new Row();
testRow.add("id", 1);
testRow.add("name", "abc");
testRow.add("date", LocalDate.of(2018, 11, 11));
testRow.add("time", LocalTime.of(11, 11, 11));
testRow.add("timestamp", ZonedDateTime.of(2018, 11, 11, 11, 11, 11, 0, ZoneId.of("UTC")));
testRow.add("bigdecimal", new BigDecimal(new BigInteger("123456"), 5));
testRow.add("datetime", LocalDateTime.now());
List<Row> expectedRows = Collections.singletonList(testRow);

SchemaConverter converter = new SchemaConverter();
Schema expectedSchema = converter.toSchema("myrecord", expectedRows.get(0));

byte[] serializedRows = new KryoSerializer().fromRemoteDirectiveResponse(
new RemoteDirectiveResponse(expectedRows, expectedSchema));
RemoteDirectiveResponse response = new KryoSerializer().toRemoteDirectiveResponse(
serializedRows);

Assert.assertArrayEquals(expectedRows.toArray(), response.getRows().toArray());
Assert.assertEquals(expectedSchema, response.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.cdap.wrangler.utils;

import com.google.common.base.Charsets;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception {
actualRows = objectSerDe.toObject(bytes);
Assert.assertEquals(expectedRows.size(), actualRows.size());
}
@Test
public void testRemoteDirectiveResponseSerDe() throws Exception {
List<Row> expectedRows = new ArrayList<>();
Row firstRow = new Row();
firstRow.add("id", 1);
expectedRows.add(firstRow);
Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT)));
RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema);
ObjectSerDe<RemoteDirectiveResponse> objectSerDe = new ObjectSerDe<>();

byte[] bytes = objectSerDe.toByteArray(expectedResponse);
RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes);

Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size());
Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.cdap.wrangler.service.directive;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.parser.DirectiveClass;

import java.util.HashMap;
Expand All @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest {
private final Map<String, DirectiveClass> systemDirectives;
private final String pluginNameSpace;
private final byte[] data;
private final Schema inputSchema;

RemoteDirectiveRequest(String recipe, Map<String, DirectiveClass> systemDirectives,
String pluginNameSpace, byte[] data) {
String pluginNameSpace, byte[] data, Schema inputSchema) {
this.recipe = recipe;
this.systemDirectives = new HashMap<>(systemDirectives);
this.pluginNameSpace = pluginNameSpace;
this.data = data;
this.inputSchema = inputSchema;
}

public String getRecipe() {
Expand All @@ -53,4 +56,8 @@ public byte[] getData() {
public String getPluginNameSpace() {
return pluginNameSpace;
}

public Schema getInputSchema() {
return inputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
package io.cdap.wrangler.service.directive;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.api.service.worker.RunnableTask;
import io.cdap.cdap.api.service.worker.RunnableTaskContext;
import io.cdap.cdap.api.service.worker.SystemAppTaskContext;
import io.cdap.cdap.features.Feature;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import io.cdap.directives.aggregates.DefaultTransientStore;
import io.cdap.wrangler.api.Arguments;
import io.cdap.wrangler.api.CompileException;
Expand All @@ -30,7 +33,10 @@
import io.cdap.wrangler.api.ErrorRecordBase;
import io.cdap.wrangler.api.ExecutorContext;
import io.cdap.wrangler.api.RecipeException;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.api.TransientVariableScope;
import io.cdap.wrangler.api.parser.UsageDefinition;
import io.cdap.wrangler.executor.RecipePipelineExecutor;
import io.cdap.wrangler.expression.EL;
Expand All @@ -43,22 +49,25 @@
import io.cdap.wrangler.proto.ErrorRecordsException;
import io.cdap.wrangler.registry.DirectiveInfo;
import io.cdap.wrangler.registry.UserDirectiveRegistry;
import io.cdap.wrangler.utils.KryoSerializer;
import io.cdap.wrangler.utils.ObjectSerDe;

import io.cdap.wrangler.utils.RowSerializer;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA;

/**
* Task for remote execution of directives
*/
public class RemoteExecutionTask implements RunnableTask {

private static final Gson GSON = new Gson();

private static final Gson GSON = new GsonBuilder()
.registerTypeAdapter(Schema.class, new SchemaTypeAdapter())
.create();

@Override
public void run(RunnableTaskContext runnableTaskContext) throws Exception {
Expand Down Expand Up @@ -105,12 +114,18 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
ObjectSerDe<List<Row>> objectSerDe = new ObjectSerDe<>();
List<Row> rows = objectSerDe.toObject(directiveRequest.getData());

Schema inputSchema = directiveRequest.getInputSchema();
TransientStore transientStore = new DefaultTransientStore();
if (inputSchema != null) {
transientStore.set(TransientVariableScope.GLOBAL, INPUT_SCHEMA, inputSchema);
}

try (RecipePipelineExecutor executor = new RecipePipelineExecutor(() -> directives,
new ServicePipelineContext(
namespace,
ExecutorContext.Environment.SERVICE,
systemAppContext,
new DefaultTransientStore()))) {
transientStore))) {
rows = executor.execute(rows);
List<ErrorRecordBase> errors = executor.errors().stream()
.filter(ErrorRecordBase::isShownInWrangler)
Expand All @@ -123,12 +138,16 @@ public void run(RunnableTaskContext runnableTaskContext) throws Exception {
throw new BadRequestException(e.getMessage(), e);
}

Schema outputSchema = transientStore.get(OUTPUT_SCHEMA);
RemoteDirectiveResponse response = new RemoteDirectiveResponse(rows, outputSchema);
ObjectSerDe<RemoteDirectiveResponse> responseSerDe = new ObjectSerDe<>();

runnableTaskContext.setTerminateOnComplete(hasUDD.get() || EL.isUsed());

if (Feature.WRANGLER_KRYO_SERIALIZATION.isEnabled(systemAppContext)) {
runnableTaskContext.writeResult(new RowSerializer().fromRows(rows));
runnableTaskContext.writeResult(new KryoSerializer().fromRemoteDirectiveResponse(response));
} else {
runnableTaskContext.writeResult(objectSerDe.toByteArray(rows));
runnableTaskContext.writeResult(responseSerDe.toByteArray(response));
}
} catch (DirectiveParseException | ClassNotFoundException | CompileException e) {
throw new BadRequestException(e.getMessage(), e);
Expand Down
Loading