From 6ad28ae25d2786fb195603eae4de783224f42723 Mon Sep 17 00:00:00 2001 From: andrew-coleman Date: Tue, 27 Aug 2024 10:44:32 +0100 Subject: [PATCH] feat: add ExpandRel support to core and spark Signed-off-by: Andrew Coleman --- .../io/substrait/dsl/SubstraitBuilder.java | 33 ++++++++ .../relation/AbstractRelVisitor.java | 5 ++ .../java/io/substrait/relation/Expand.java | 64 ++++++++++++++++ .../substrait/relation/ProtoRelConverter.java | 57 +++++++++++++- .../main/java/io/substrait/relation/Rel.java | 3 + .../relation/RelCopyOnWriteVisitor.java | 5 ++ .../substrait/relation/RelProtoConverter.java | 76 ++++++++++++++----- .../io/substrait/relation/RelVisitor.java | 2 + .../type/proto/ExpandRelRoundtripTest.java | 70 +++++++++++++++++ .../type/proto/ExtensionRoundtripTest.java | 17 +++++ gradle.properties | 2 +- .../substrait/debug/RelToVerboseString.scala | 9 +++ .../spark/expression/FunctionMappings.scala | 1 + .../spark/expression/ToSparkExpression.scala | 8 +- .../spark/logical/ToLogicalPlan.scala | 30 ++++++-- .../spark/logical/ToSubstraitRel.scala | 31 ++++++-- .../scala/io/substrait/spark/TPCDSPlan.scala | 2 +- .../scala/io/substrait/spark/TPCHPlan.scala | 7 +- 18 files changed, 385 insertions(+), 37 deletions(-) create mode 100644 core/src/main/java/io/substrait/relation/Expand.java create mode 100644 core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 1a03c3027..da07fcfc2 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -19,8 +19,10 @@ import io.substrait.plan.ImmutablePlan; import io.substrait.plan.ImmutableRoot; import io.substrait.plan.Plan; +import io.substrait.proto.RelCommon; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; +import io.substrait.relation.Expand; import io.substrait.relation.Fetch; import io.substrait.relation.Filter; import io.substrait.relation.Join; @@ -313,6 +315,37 @@ private Project project( return Project.builder().input(input).expressions(expressions).remap(remap).build(); } + public Expand expand(Function> fieldsFn, Rel input) { + return expand(fieldsFn, Optional.empty(), Optional.empty(), input); + } + + public Expand expand( + Function> fieldsFn, + List outputNames, + Rel input) { + return expand(fieldsFn, Optional.empty(), Optional.of(outputNames), input); + } + + public Expand expand( + Function> fieldsFn, + Rel.Remap remap, + List outputNames, + Rel input) { + return expand(fieldsFn, Optional.of(remap), Optional.of(outputNames), input); + } + + private Expand expand( + Function> fieldsFn, + Optional remap, + Optional> outputNames, + Rel input) { + var fields = fieldsFn.apply(input); + var expand = Expand.builder().input(input).fields(fields).remap(remap); + outputNames.ifPresent( + names -> expand.hint(RelCommon.Hint.newBuilder().addAllOutputNames(names).build())); + return expand.build(); + } + public Set set(Set.SetOp op, Rel... inputs) { return set(op, Optional.empty(), inputs); } diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 088eb6c9a..582814b7a 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION { return visitFallback(project); } + @Override + public OUTPUT visit(Expand expand) throws EXCEPTION { + return visitFallback(expand); + } + @Override public OUTPUT visit(Sort sort) throws EXCEPTION { return visitFallback(sort); diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java new file mode 100644 index 000000000..94f57663e --- /dev/null +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -0,0 +1,64 @@ +package io.substrait.relation; + +import io.substrait.expression.Expression; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Enclosing +@Value.Immutable +public abstract class Expand extends SingleInputRel { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class); + + public abstract List getFields(); + + @Override + public Type.Struct deriveRecordType() { + Type.Struct initial = getInput().getRecordType(); + return TypeCreator.of(initial.nullable()) + .struct( + Stream.concat( + initial.fields().stream(), + getFields().stream() + .map( + f -> { + if (f.getSwitchingField().isPresent()) { + return f.getSwitchingField().get().getDuplicates().get(0).getType(); + } else { + return f.getConsistentField().get().getType(); + } + }))); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableExpand.Builder builder() { + return ImmutableExpand.builder(); + } + + @Value.Immutable + public abstract static class ExpandField { + public abstract Optional getSwitchingField(); + + public abstract Optional getConsistentField(); + + public static ImmutableExpand.ExpandField.Builder builder() { + return ImmutableExpand.ExpandField.builder(); + } + } + + @Value.Immutable + public abstract static class SwitchingField { + public abstract List getDuplicates(); + + public static ImmutableExpand.SwitchingField.Builder builder() { + return ImmutableExpand.SwitchingField.builder(); + } + } +} diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index f728988a6..5110eefbd 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -10,6 +10,7 @@ import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; +import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; import io.substrait.proto.ExtensionSingleRel; @@ -21,6 +22,7 @@ import io.substrait.proto.NestedLoopJoinRel; import io.substrait.proto.ProjectRel; import io.substrait.proto.ReadRel; +import io.substrait.proto.RelCommon; import io.substrait.proto.SetRel; import io.substrait.proto.SortRel; import io.substrait.relation.extensions.EmptyDetail; @@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) { case PROJECT -> { return newProject(rel.getProject()); } + case EXPAND -> { + return newExpand(rel.getExpand()); + } case CROSS -> { return newCross(rel.getCross()); } @@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) { } protected NamedStruct newNamedStruct(ReadRel rel) { - var namedStruct = rel.getBaseSchema(); + return newNamedStruct(rel.getBaseSchema()); + } + + protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) @@ -389,6 +397,43 @@ protected Project newProject(ProjectRel rel) { return builder.build(); } + protected Expand newExpand(ExpandRel rel) { + var input = from(rel.getInput()); + var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + var builder = + Expand.builder() + .input(input) + .fields( + rel.getFieldsList().stream() + .map( + expandField -> + switch (expandField.getFieldTypeCase()) { + case CONSISTENT_FIELD -> Expand.ExpandField.builder() + .consistentField(converter.from(expandField.getConsistentField())) + .build(); + case SWITCHING_FIELD -> Expand.ExpandField.builder() + .switchingField( + Expand.SwitchingField.builder() + .duplicates( + expandField + .getSwitchingField() + .getDuplicatesList() + .stream() + .map(converter::from) + .collect(java.util.stream.Collectors.toList())) + .build()) + .build(); + case FIELDTYPE_NOT_SET -> Expand.ExpandField.builder().build(); + }) + .collect(java.util.stream.Collectors.toList())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + return builder.build(); + } + protected Aggregate newAggregate(AggregateRel rel) { var input = from(rel.getInput()); var protoExprConverter = @@ -647,6 +692,16 @@ protected static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); } + protected static Optional optionalHint(io.substrait.proto.RelCommon relCommon) { + return Optional.ofNullable( + relCommon.hasHint() + ? RelCommon.Hint.newBuilder() + .setAlias(relCommon.getHint().getAlias()) + .addAllOutputNames(relCommon.getHint().getOutputNamesList()) + .build() + : null); + } + protected Optional optionalAdvancedExtension( io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( diff --git a/core/src/main/java/io/substrait/relation/Rel.java b/core/src/main/java/io/substrait/relation/Rel.java index 159feb3d5..6acd4df85 100644 --- a/core/src/main/java/io/substrait/relation/Rel.java +++ b/core/src/main/java/io/substrait/relation/Rel.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.extension.AdvancedExtension; +import io.substrait.proto.RelCommon; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; @@ -21,6 +22,8 @@ public interface Rel { List getInputs(); + Optional getHint(); + @Value.Immutable public abstract static class Remap { public abstract List indices(); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 2514eafe0..01b8f88bc 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -201,6 +201,11 @@ public Optional visit(Project project) throws EXCEPTION { .build()); } + @Override + public Optional visit(Expand expand) throws EXCEPTION { + throw new UnsupportedOperationException(); + } + @Override public Optional visit(Sort sort) throws EXCEPTION { var input = sort.getInput().accept(this); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 44dcc681c..6912fa42d 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -10,6 +10,7 @@ import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; +import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; import io.substrait.proto.ExtensionSingleRel; @@ -50,7 +51,7 @@ public RelProtoConverter(ExtensionCollector functionCollector) { } private List toProto(Collection expressions) { - return expressions.stream().map(this::toProto).collect(java.util.stream.Collectors.toList()); + return expressions.stream().map(this::toProto).collect(Collectors.toList()); } private io.substrait.proto.Expression toProto(Expression expression) { @@ -74,7 +75,7 @@ private List toProtoS(Collection sorts) { .setExpr(toProto(s.expr())) .build(); }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); } private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) { @@ -88,13 +89,9 @@ public Rel visit(Aggregate aggregate) throws RuntimeException { .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) .addAllGroupings( - aggregate.getGroupings().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())) + aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) .addAllMeasures( - aggregate.getMeasures().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())); + aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); aggregate.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setAggregate(builder).build(); @@ -113,14 +110,14 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .addAllArguments( IntStream.range(0, args.size()) .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .addAllSorts(toProtoS(measure.getFunction().sort())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) .addAllOptions( measure.getFunction().options().stream() .map(ExpressionProtoConverter::from) - .collect(java.util.stream.Collectors.toList())); + .collect(Collectors.toList())); var builder = AggregateRel.Measure.newBuilder().setMeasure(func); @@ -226,7 +223,7 @@ public Rel visit(LocalFiles localFiles) throws RuntimeException { .addAllItems( localFiles.getItems().stream() .map(FileOrFiles::toProto) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .build()) .setBaseSchema(localFiles.getInitialSchema().toProto(typeProtoConverter)); localFiles.getFilter().ifPresent(t -> builder.setFilter(toProto(t))); @@ -350,7 +347,7 @@ private List toProtoWindowRelFun var options = f.options().stream() .map(ExpressionProtoConverter::from) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); return ConsistentPartitionWindowRel.WindowRelFunction.newBuilder() .setInvocation(f.invocation().toProto()) @@ -364,7 +361,7 @@ private List toProtoWindowRelFun .setUpperBound(BoundConverter.convert(f.upperBound())) .build(); }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); } @Override @@ -374,14 +371,50 @@ public Rel visit(Project project) throws RuntimeException { .setCommon(common(project)) .setInput(toProto(project.getInput())) .addAllExpressions( - project.getExpressions().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())); + project.getExpressions().stream().map(this::toProto).collect(Collectors.toList())); project.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setProject(builder).build(); } + @Override + public Rel visit(Expand expand) throws RuntimeException { + var builder = + ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); + + expand + .getFields() + .forEach( + expandField -> { + expandField + .getConsistentField() + .ifPresent( + expression -> + builder.addFields( + ExpandRel.ExpandField.newBuilder() + .setConsistentField(toProto(expression)) + .build())); + expandField + .getSwitchingField() + .ifPresent( + switchingField -> + builder.addFields( + ExpandRel.ExpandField.newBuilder() + .setSwitchingField( + ExpandRel.SwitchingField.newBuilder() + .addAllDuplicates( + expandField + .getSwitchingField() + .get() + .getDuplicates() + .stream() + .map(this::toProto) + .collect(Collectors.toList()))) + .build())); + }); + return Rel.newBuilder().setExpand(builder).build(); + } + @Override public Rel visit(Sort sort) throws RuntimeException { var builder = @@ -417,7 +450,7 @@ public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException { virtualTableScan.getRows().stream() .map(this::toProto) .map(t -> t.getLiteral().getStruct()) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .build()) .setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter)); @@ -469,6 +502,15 @@ private RelCommon common(io.substrait.relation.Rel rel) { } else { builder.setDirect(RelCommon.Direct.getDefaultInstance()); } + + rel.getHint() + .ifPresent( + md -> + builder.setHint( + RelCommon.Hint.newBuilder() + .addAllOutputNames(md.getOutputNamesList()) + .build())); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 28c5fe0c8..799e58bc1 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -23,6 +23,8 @@ public interface RelVisitor { OUTPUT visit(Project project) throws EXCEPTION; + OUTPUT visit(Expand expand) throws EXCEPTION; + OUTPUT visit(Sort sort) throws EXCEPTION; OUTPUT visit(Cross cross) throws EXCEPTION; diff --git a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java new file mode 100644 index 000000000..016827f7f --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java @@ -0,0 +1,70 @@ +package io.substrait.type.proto; + +import io.substrait.TestBase; +import io.substrait.expression.*; +import io.substrait.proto.RelCommon; +import io.substrait.relation.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; + +public class ExpandRelRoundtripTest extends TestBase { + final Rel input = + b.namedScan( + Stream.of("a_table").collect(Collectors.toList()), + Stream.of("column1", "column2").collect(Collectors.toList()), + Stream.of(R.I64, R.I64).collect(Collectors.toList())); + + private Expand.ExpandField getConsistentField(int index) { + return Expand.ExpandField.builder().consistentField(b.fieldReference(input, index)).build(); + } + + private Expand.ExpandField getSwitchingField(List indexes) { + return Expand.ExpandField.builder() + .switchingField( + Expand.SwitchingField.builder() + .addAllDuplicates( + indexes.stream() + .map(index -> b.fieldReference(input, index)) + .collect(Collectors.toList())) + .build()) + .build(); + } + + @Test + void expandConsistent() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), input)) + .hint( + RelCommon.Hint.newBuilder() + .addAllOutputNames(Arrays.asList("name1", "name2")) + .build()) + .fields( + Stream.of(getConsistentField(0), getConsistentField(1)) + .collect(Collectors.toList())) + .build(); + verifyRoundTrip(rel); + } + + @Test + void expandSwitching() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), input)) + .hint( + RelCommon.Hint.newBuilder() + .addAllOutputNames(Arrays.asList("name1", "name2")) + .build()) + .fields( + Stream.of( + getSwitchingField(Arrays.asList(0, 1)), + getSwitchingField(Arrays.asList(1, 0))) + .collect(Collectors.toList())) + .build(); + verifyRoundTrip(rel); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 06ed71dec..912fb2aea 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -5,8 +5,10 @@ import io.substrait.TestBase; import io.substrait.expression.Expression; import io.substrait.extension.AdvancedExtension; +import io.substrait.proto.RelCommon; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; +import io.substrait.relation.Expand; import io.substrait.relation.ExtensionLeaf; import io.substrait.relation.ExtensionMulti; import io.substrait.relation.ExtensionSingle; @@ -30,6 +32,7 @@ import io.substrait.type.NamedStruct; import io.substrait.type.Type; import io.substrait.type.TypeCreator; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.stream.Collectors; @@ -233,6 +236,20 @@ void project() { verifyRoundTrip(rel); } + @Test + void expand() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), commonTable)) + .commonExtension(commonExtension) + .hint( + RelCommon.Hint.newBuilder() + .addAllOutputNames(Arrays.asList("name1", "name2")) + .build()) + .build(); + verifyRoundTrip(rel); + } + @Test void set() { Rel rel = diff --git a/gradle.properties b/gradle.properties index 2dcd60a56..b21824932 100644 --- a/gradle.properties +++ b/gradle.properties @@ -20,7 +20,7 @@ guava.version=32.1.3-jre immutables.version=2.10.1 jackson.version=2.16.1 junit.version=5.8.1 -protobuf.version=3.25.3 +protobuf.version=3.25.5 slf4j.version=2.0.13 sparkbundle.version=3.4 spark.version=3.4.2 diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 9f4f5c9f8..0ba749b9e 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -131,6 +131,15 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } + override def visit(expand: Expand): String = { + withBuilder(expand, 8)( + builder => { + builder + .append("fields=") + .append(expand.getFields) + }) + } + override def visit(aggregate: Aggregate): String = { withBuilder(aggregate, 10)( builder => { diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 08326454e..4bcae8dd9 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -67,6 +67,7 @@ class FunctionMappings { s[Count]("count"), s[Min]("min"), s[Max]("max"), + s[First]("any_value"), s[HyperLogLogPlusPlus]("approx_count_distinct") ) diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index a4ee9aaee..e928689fa 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -18,11 +18,9 @@ package io.substrait.spark.expression import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType} import io.substrait.spark.logical.ToLogicalPlan - import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery} -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{Decimal, NullType} import org.apache.spark.unsafe.types.UTF8String - import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} @@ -77,6 +75,10 @@ class ToSparkExpression( Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.NullLiteral): Expression = { + Literal(null, ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.Cast): Expression = { val childExp = expr.input().accept(this) Cast(childExp, ToSubstraitType.convert(expr.getType)) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 908a8aa0d..886e71b70 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -18,7 +18,6 @@ package io.substrait.spark.logical import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ - import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ @@ -29,9 +28,8 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.types.{DataTypes, IntegerType, StructType} - -import io.substrait.`type`.{StringTypeVisitor, Type} +import org.apache.spark.sql.types.{DataTypes, IntegerType, Metadata, StructField, StructType} +import io.substrait.`type`.{NamedStruct, StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} import io.substrait.plan.Plan @@ -82,11 +80,15 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] ) throw new IllegalArgumentException(msg) }) + + val filter = Option(measure.getPreMeasureFilter.orElse(null)) + .map(_.accept(expressionConverter)) + AggregateExpression( aggregateFunction, ToAggregateFunction.toSpark(function.aggregationPhase()), ToAggregateFunction.toSpark(function.invocation()), - None + filter ) } @@ -193,6 +195,24 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } } + override def visit(expand: relation.Expand): LogicalPlan = { + val child = expand.getInput.accept(this) + val names = expand.getHint.get().getOutputNamesList.asScala + + withChild(child) { + val projections = expand.getFields.asScala + .map(field => field.getSwitchingField.get.getDuplicates.asScala + .map(expr => expr.accept(expressionConverter)) + .map(toNamedExpression)) + + val output = projections.head.zip(names) + .map { case (t, name) => StructField(name, t.dataType, t.nullable) } + .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + Expand(projections, output, child) + } + } + override def visit(filter: relation.Filter): LogicalPlan = { val child = filter.getInput.accept(this) withChild(child) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 4085860f4..801d38c09 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -18,7 +18,6 @@ package io.substrait.spark.logical import io.substrait.spark.{SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ @@ -29,18 +28,18 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRela import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.types.StructType import ToSubstraitType.toNamedStruct - +import io.substrait.`type`.{NamedStruct, Type} import io.substrait.{proto, relation} import io.substrait.debug.TreePrinter -import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.expression.{ExpressionCreator, Expression => SExpression} import io.substrait.extension.ExtensionCollector import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} +import io.substrait.proto.{DerivationExpression, RelCommon} import io.substrait.relation.RelProtoConverter import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} import io.substrait.relation.files.FileOrFiles.PathType -import java.util.Collections - +import java.util.{Collections, Optional} import scala.collection.JavaConverters.asJavaIterableConverter import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -72,7 +71,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val substraitExps = expression.aggregateFunction.children.map(toExpression(output)) val invocation = SparkExtension.toAggregateFunction.apply(expression, substraitExps) - relation.Aggregate.Measure.builder.function(invocation).build() + val filter = expression.filter map toExpression(output) + relation.Aggregate.Measure.builder.function(invocation).preMeasureFilter(Optional.ofNullable(filter.orNull)).build() } private def collectAggregates( @@ -236,6 +236,25 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .build() } + override def visitExpand(p: Expand): relation.Rel = { + val fields = p.projections.map(proj => { + relation.Expand.ExpandField.builder.switchingField( + relation.Expand.SwitchingField.builder.duplicates( + proj.map(toExpression(p.child.output)).asJava + ).build() + ).build() + }) + + val names = p.output.map(_.name) + + relation.Expand.builder + .remap(relation.Rel.Remap.offset(p.child.output.size, names.size)) + .fields(fields.asJava) + .hint(RelCommon.Hint.newBuilder.addAllOutputNames(names.asJava).build()) + .input(visit(p.child)) + .build() + } + private def toSortField(output: Seq[Attribute] = Nil)(order: SortOrder): SExpression.SortField = { val direction = (order.direction, order.nullOrdering) match { case (Ascending, NullsFirst) => SExpression.SortDirection.ASC_NULLS_FIRST diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 7cfb3cd2d..fd35b551a 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,7 +32,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // "q9" failed in spark 3.3 - val successfulSQL: Set[String] = Set("q41", "q62", "q93", "q96", "q99") + val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99") tpcdsQueries.foreach { q => diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index df5ca4f81..224ac2e8d 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -16,6 +16,7 @@ */ package io.substrait.spark +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} import org.apache.spark.sql.TPCHBase class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { @@ -101,7 +102,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { " from lineitem group by l_partkey + l_orderkey") } - ignore("avg(distinct)") { + test("avg(distinct)") { assertSqlSubstraitRelRoundTrip( "select l_partkey, sum(l_tax), sum(distinct l_tax)," + " avg(l_discount), avg(distinct l_discount) from lineitem group by l_partkey") @@ -112,7 +113,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { "select l_partkey, sum(l_extendedprice * (1.0-l_discount)) from lineitem group by l_partkey") } - ignore("simpleTestAggFilter") { + test("simpleTestAggFilter") { assertSqlSubstraitRelRoundTrip( "select sum(l_tax) filter(WHERE l_orderkey > l_partkey) from lineitem") // cast is added to avoid the difference by implicit cast @@ -149,7 +150,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { " where l_shipdate < date '1998-01-01' ") } - ignore("simpleTestGroupingSets [has Expand]") { + test("simpleTestGroupingSets [has Expand]") { assertSqlSubstraitRelRoundTrip( "select sum(l_discount) from lineitem group by grouping sets " + "((l_orderkey, L_COMMITDATE), l_shipdate)")