Skip to content

Commit

Permalink
feat: add ExpandRel support to core and spark
Browse files Browse the repository at this point in the history
Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Sep 20, 2024
1 parent e24ce6f commit 6ad28ae
Show file tree
Hide file tree
Showing 18 changed files with 385 additions and 37 deletions.
33 changes: 33 additions & 0 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.ImmutableRoot;
import io.substrait.plan.Plan;
import io.substrait.proto.RelCommon;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.Expand;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
Expand Down Expand Up @@ -313,6 +315,37 @@ private Project project(
return Project.builder().input(input).expressions(expressions).remap(remap).build();
}

public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel input) {
return expand(fieldsFn, Optional.empty(), Optional.empty(), input);
}

public Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
List<String> outputNames,
Rel input) {
return expand(fieldsFn, Optional.empty(), Optional.of(outputNames), input);
}

public Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
Rel.Remap remap,
List<String> outputNames,
Rel input) {
return expand(fieldsFn, Optional.of(remap), Optional.of(outputNames), input);
}

private Expand expand(
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
Optional<Rel.Remap> remap,
Optional<List<String>> outputNames,
Rel input) {
var fields = fieldsFn.apply(input);
var expand = Expand.builder().input(input).fields(fields).remap(remap);
outputNames.ifPresent(
names -> expand.hint(RelCommon.Hint.newBuilder().addAllOutputNames(names).build()));
return expand.build();
}

public Set set(Set.SetOp op, Rel... inputs) {
return set(op, Optional.empty(), inputs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION {
return visitFallback(project);
}

@Override
public OUTPUT visit(Expand expand) throws EXCEPTION {
return visitFallback(expand);
}

@Override
public OUTPUT visit(Sort sort) throws EXCEPTION {
return visitFallback(sort);
Expand Down
64 changes: 64 additions & 0 deletions core/src/main/java/io/substrait/relation/Expand.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.substrait.relation;

import io.substrait.expression.Expression;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;
import org.immutables.value.Value;

@Value.Enclosing
@Value.Immutable
public abstract class Expand extends SingleInputRel {
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class);

public abstract List<ExpandField> getFields();

@Override
public Type.Struct deriveRecordType() {
Type.Struct initial = getInput().getRecordType();
return TypeCreator.of(initial.nullable())
.struct(
Stream.concat(
initial.fields().stream(),
getFields().stream()
.map(
f -> {
if (f.getSwitchingField().isPresent()) {
return f.getSwitchingField().get().getDuplicates().get(0).getType();
} else {
return f.getConsistentField().get().getType();
}
})));
}

@Override
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
return visitor.visit(this);
}

public static ImmutableExpand.Builder builder() {
return ImmutableExpand.builder();
}

@Value.Immutable
public abstract static class ExpandField {
public abstract Optional<SwitchingField> getSwitchingField();

public abstract Optional<Expression> getConsistentField();

public static ImmutableExpand.ExpandField.Builder builder() {
return ImmutableExpand.ExpandField.builder();
}
}

@Value.Immutable
public abstract static class SwitchingField {
public abstract List<Expression> getDuplicates();

public static ImmutableExpand.SwitchingField.Builder builder() {
return ImmutableExpand.SwitchingField.builder();
}
}
}
57 changes: 56 additions & 1 deletion core/src/main/java/io/substrait/relation/ProtoRelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io.substrait.proto.AggregateRel;
import io.substrait.proto.ConsistentPartitionWindowRel;
import io.substrait.proto.CrossRel;
import io.substrait.proto.ExpandRel;
import io.substrait.proto.ExtensionLeafRel;
import io.substrait.proto.ExtensionMultiRel;
import io.substrait.proto.ExtensionSingleRel;
Expand All @@ -21,6 +22,7 @@
import io.substrait.proto.NestedLoopJoinRel;
import io.substrait.proto.ProjectRel;
import io.substrait.proto.ReadRel;
import io.substrait.proto.RelCommon;
import io.substrait.proto.SetRel;
import io.substrait.proto.SortRel;
import io.substrait.relation.extensions.EmptyDetail;
Expand Down Expand Up @@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) {
case PROJECT -> {
return newProject(rel.getProject());
}
case EXPAND -> {
return newExpand(rel.getExpand());
}
case CROSS -> {
return newCross(rel.getCross());
}
Expand Down Expand Up @@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) {
}

protected NamedStruct newNamedStruct(ReadRel rel) {
var namedStruct = rel.getBaseSchema();
return newNamedStruct(rel.getBaseSchema());
}

protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) {
var struct = namedStruct.getStruct();
return ImmutableNamedStruct.builder()
.names(namedStruct.getNamesList())
Expand Down Expand Up @@ -389,6 +397,43 @@ protected Project newProject(ProjectRel rel) {
return builder.build();
}

protected Expand newExpand(ExpandRel rel) {
var input = from(rel.getInput());
var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
var builder =
Expand.builder()
.input(input)
.fields(
rel.getFieldsList().stream()
.map(
expandField ->
switch (expandField.getFieldTypeCase()) {
case CONSISTENT_FIELD -> Expand.ExpandField.builder()
.consistentField(converter.from(expandField.getConsistentField()))
.build();
case SWITCHING_FIELD -> Expand.ExpandField.builder()
.switchingField(
Expand.SwitchingField.builder()
.duplicates(
expandField
.getSwitchingField()
.getDuplicatesList()
.stream()
.map(converter::from)
.collect(java.util.stream.Collectors.toList()))
.build())
.build();
case FIELDTYPE_NOT_SET -> Expand.ExpandField.builder().build();
})
.collect(java.util.stream.Collectors.toList()));

builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));
return builder.build();
}

protected Aggregate newAggregate(AggregateRel rel) {
var input = from(rel.getInput());
var protoExprConverter =
Expand Down Expand Up @@ -647,6 +692,16 @@ protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
}

protected static Optional<RelCommon.Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
relCommon.hasHint()
? RelCommon.Hint.newBuilder()
.setAlias(relCommon.getHint().getAlias())
.addAllOutputNames(relCommon.getHint().getOutputNamesList())
.build()
: null);
}

protected Optional<AdvancedExtension> optionalAdvancedExtension(
io.substrait.proto.RelCommon relCommon) {
return Optional.ofNullable(
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/java/io/substrait/relation/Rel.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.substrait.relation;

import io.substrait.extension.AdvancedExtension;
import io.substrait.proto.RelCommon;
import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;
Expand All @@ -21,6 +22,8 @@ public interface Rel {

List<Rel> getInputs();

Optional<RelCommon.Hint> getHint();

@Value.Immutable
public abstract static class Remap {
public abstract List<Integer> indices();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ public Optional<Rel> visit(Project project) throws EXCEPTION {
.build());
}

@Override
public Optional<Rel> visit(Expand expand) throws EXCEPTION {
throw new UnsupportedOperationException();
}

@Override
public Optional<Rel> visit(Sort sort) throws EXCEPTION {
var input = sort.getInput().accept(this);
Expand Down
Loading

0 comments on commit 6ad28ae

Please sign in to comment.