Skip to content

Commit

Permalink
Add schema handling to remote task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
vanathi-g committed Apr 29, 2024
1 parent b4a1d06 commit eb981ab
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright © 2024 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package io.cdap.wrangler.api;

import io.cdap.cdap.api.data.schema.Schema;

import java.io.Serializable;
import java.util.List;

/**
* Response after executing directives remotely
*/
public class RemoteDirectiveResponse implements Serializable {
private final List<Row> rows;
private final Schema outputSchema;

public RemoteDirectiveResponse(List<Row> rows, Schema outputSchema) {
this.rows = rows;
this.outputSchema = outputSchema;
}

public List<Row> getRows() {
return rows;
}

public Schema getOutputSchema() {
return outputSchema;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.cdap.wrangler.api;

import java.io.Serializable;
import java.util.Map;
import java.util.Set;

/**
Expand Down Expand Up @@ -61,4 +62,8 @@ public interface TransientStore extends Serializable {
* @return list of all the variables.
*/
Set<String> getVariables();

Map<String, Object> getGlobalVariables();

Map<String, Object> getLocalVariables();
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.cdap.directives.aggregates;

import com.google.common.collect.ImmutableMap;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.api.TransientVariableScope;

Expand Down Expand Up @@ -114,4 +115,14 @@ public void set(TransientVariableScope scope, String name, Object value) {
local.put(name, value);
}
}

@Override
public Map<String, Object> getGlobalVariables() {
return ImmutableMap.copyOf(global);
}

@Override
public Map<String, Object> getLocalVariables() {
return ImmutableMap.copyOf(local);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright © 2021 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.directives.aggregates;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.TypeAdapter;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import io.cdap.wrangler.api.TransientVariableScope;

import java.io.IOException;
import java.util.Map;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static io.cdap.wrangler.schema.TransientStoreKeys.OUTPUT_SCHEMA;

/**
* GSON - JSON type adapter class for {@link DefaultTransientStore}
*/
public class DefaultTransientStoreTypeAdapter extends TypeAdapter<DefaultTransientStore> {
private static final Gson GSON = new GsonBuilder()
.registerTypeAdapter(Schema.class, new SchemaTypeAdapter())
.create();
private static final String GLOBAL_MAP_KEY = "global";
private static final String LOCAL_MAP_KEY = "local";

@Override
public void write(JsonWriter writer, DefaultTransientStore store) throws IOException {
if (store == null) {
writer.nullValue();
return;
}

writer.beginObject();
writeMaps(writer, store);
writer.endObject();
}

@Override
public DefaultTransientStore read(JsonReader reader) throws IOException {
DefaultTransientStore store = new DefaultTransientStore();
reader.beginObject();
while (reader.hasNext()) {
String name = reader.nextName();
if (name.equals(GLOBAL_MAP_KEY)) {
readMap(reader, store, TransientVariableScope.GLOBAL);
} else if (name.equals(LOCAL_MAP_KEY)) {
readMap(reader, store, TransientVariableScope.LOCAL);
}
}
reader.endObject();
return store;
}

private void writeMaps(JsonWriter writer, DefaultTransientStore store) throws IOException {
writer.name(GLOBAL_MAP_KEY).beginObject();
writeMap(writer, store.getGlobalVariables());
writer.endObject();


writer.name(LOCAL_MAP_KEY).beginObject();
writeMap(writer, store.getLocalVariables());
writer.endObject();
}

private void writeMap(JsonWriter writer, Map<String, Object> map) throws IOException {
for (Map.Entry<String, Object> entry : map.entrySet()) {
String key = entry.getKey();
writer.name(key);
if (key.equals(INPUT_SCHEMA) || key.equals(OUTPUT_SCHEMA)) {
GSON.toJson(entry.getValue(), Schema.class, writer);
} else {
GSON.toJson(entry.getValue(), Object.class, writer);
}
}
}

private void readMap(JsonReader reader, DefaultTransientStore store,
TransientVariableScope scope) throws IOException {
reader.beginObject();
while (reader.hasNext()) {
String key = reader.nextName();
if (key.equals(INPUT_SCHEMA) || key.equals(OUTPUT_SCHEMA)) {
Schema schemaValue = GSON.fromJson(reader, Schema.class);
store.set(scope, key, schemaValue);
} else {
Object value = GSON.fromJson(reader, Object.class);
store.set(scope, key, value);
}
}
reader.endObject();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright © 2021 Cask Data, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

package io.cdap.directives.aggregates;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.cdap.internal.io.SchemaTypeAdapter;
import io.cdap.wrangler.api.TransientStore;
import io.cdap.wrangler.api.TransientVariableScope;
import org.junit.BeforeClass;
import org.junit.Test;

import static io.cdap.wrangler.schema.TransientStoreKeys.INPUT_SCHEMA;
import static org.junit.Assert.assertEquals;

public class DefaultTransientStoreTypeAdapterTest {
private static final TransientStore EXPECTED_STORE = new DefaultTransientStore();
private static final String EXPECTED_JSON = "{"
+ "\"global\":{"
+ "\"num_val\":5.0"
+ "},"
+ "\"local\":{"
+ "\"string_val\":\"hello\""
+ "}"
+ "}";
private static final Gson GSON = new GsonBuilder()
.registerTypeAdapter(TransientStore.class, new DefaultTransientStoreTypeAdapter())
.registerTypeAdapter(Schema.class, new SchemaTypeAdapter())
.create();

@BeforeClass
public static void setup() {
EXPECTED_STORE.set(TransientVariableScope.GLOBAL, "num_val", 5.0);
EXPECTED_STORE.set(TransientVariableScope.LOCAL, "string_val", "hello");
}

@Test
public void testWrite() {
TransientStore store = new DefaultTransientStore();
store.set(TransientVariableScope.GLOBAL, "num_val", 5.0);
store.set(TransientVariableScope.LOCAL, "string_val", "hello");

String actualJson = GSON.toJson(store);

assertEquals(EXPECTED_JSON, actualJson);
}

@Test
public void testRead() {
TransientStore actualStore = GSON.fromJson(EXPECTED_JSON, TransientStore.class);

assertEquals(EXPECTED_STORE.getGlobalVariables(), actualStore.getGlobalVariables());
assertEquals(EXPECTED_STORE.getLocalVariables(), actualStore.getLocalVariables());
}

@Test
public void testSchema() {
TransientStore store = new DefaultTransientStore();
Schema inputSchema = Schema.recordOf(
"inputSchema",
Schema.Field.of("int_col", Schema.of(Schema.Type.INT))
);
store.set(TransientVariableScope.LOCAL, INPUT_SCHEMA, inputSchema);

String actualJson = GSON.toJson(store);
TransientStore actualStore = GSON.fromJson(actualJson, TransientStore.class);

assertEquals(inputSchema, actualStore.get(INPUT_SCHEMA));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ public void increment(TransientVariableScope scope, String name, long value) {
public Set<String> getVariables() {
return s.keySet();
}

@Override
public Map<String, Object> getGlobalVariables() {
return null;
}

@Override
public Map<String, Object> getLocalVariables() {
return null;
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package io.cdap.wrangler.utils;

import com.google.common.base.Charsets;
import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.api.RemoteDirectiveResponse;
import io.cdap.wrangler.api.Row;
import org.junit.Assert;
import org.junit.Test;
Expand Down Expand Up @@ -85,4 +87,20 @@ public void testLogicalTypeSerDe() throws Exception {
actualRows = objectSerDe.toObject(bytes);
Assert.assertEquals(expectedRows.size(), actualRows.size());
}
@Test
public void testRemoteDirectiveResponseSerDe() throws Exception {
List<Row> expectedRows = new ArrayList<>();
Row firstRow = new Row();
firstRow.add("id", 1);
expectedRows.add(firstRow);
Schema expectedSchema = Schema.recordOf(Schema.Field.of("id", Schema.of(Schema.Type.INT)));
RemoteDirectiveResponse expectedResponse = new RemoteDirectiveResponse(expectedRows, expectedSchema);
ObjectSerDe<RemoteDirectiveResponse> objectSerDe = new ObjectSerDe<>();

byte[] bytes = objectSerDe.toByteArray(expectedResponse);
RemoteDirectiveResponse actualResponse = objectSerDe.toObject(bytes);

Assert.assertEquals(expectedResponse.getRows().size(), actualResponse.getRows().size());
Assert.assertEquals(expectedResponse.getOutputSchema(), actualResponse.getOutputSchema());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.cdap.wrangler.service.directive;

import io.cdap.cdap.api.data.schema.Schema;
import io.cdap.wrangler.parser.DirectiveClass;

import java.util.HashMap;
Expand All @@ -29,13 +30,15 @@ public class RemoteDirectiveRequest {
private final Map<String, DirectiveClass> systemDirectives;
private final String pluginNameSpace;
private final byte[] data;
private final Schema inputSchema;

RemoteDirectiveRequest(String recipe, Map<String, DirectiveClass> systemDirectives,
String pluginNameSpace, byte[] data) {
String pluginNameSpace, byte[] data, Schema inputSchema) {
this.recipe = recipe;
this.systemDirectives = new HashMap<>(systemDirectives);
this.pluginNameSpace = pluginNameSpace;
this.data = data;
this.inputSchema = inputSchema;
}

public String getRecipe() {
Expand All @@ -53,4 +56,8 @@ public byte[] getData() {
public String getPluginNameSpace() {
return pluginNameSpace;
}

public Schema getInputSchema() {
return inputSchema;
}
}
Loading

0 comments on commit eb981ab

Please sign in to comment.