diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 1a03c3027..fb462d796 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -21,6 +21,7 @@ import io.substrait.plan.Plan; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; +import io.substrait.relation.Expand; import io.substrait.relation.Fetch; import io.substrait.relation.Filter; import io.substrait.relation.Join; @@ -313,6 +314,23 @@ private Project project( return Project.builder().input(input).expressions(expressions).remap(remap).build(); } + public Expand expand(Function> fieldsFn, Rel input) { + return expand(fieldsFn, Optional.empty(), input); + } + + public Expand expand( + Function> fieldsFn, Rel.Remap remap, Rel input) { + return expand(fieldsFn, Optional.of(remap), input); + } + + private Expand expand( + Function> fieldsFn, + Optional remap, + Rel input) { + var fields = fieldsFn.apply(input); + return Expand.builder().input(input).fields(fields).remap(remap).build(); + } + public Set set(Set.SetOp op, Rel... inputs) { return set(op, Optional.empty(), inputs); } diff --git a/core/src/main/java/io/substrait/hint/Hint.java b/core/src/main/java/io/substrait/hint/Hint.java new file mode 100644 index 000000000..238bf44e7 --- /dev/null +++ b/core/src/main/java/io/substrait/hint/Hint.java @@ -0,0 +1,23 @@ +package io.substrait.hint; + +import io.substrait.proto.RelCommon; +import java.util.List; +import java.util.Optional; +import org.immutables.value.Value; + +@Value.Immutable +public abstract class Hint { + public abstract Optional getAlias(); + + public abstract List getOutputNames(); + + public RelCommon.Hint toProto() { + var builder = RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames()); + getAlias().ifPresent(builder::setAlias); + return builder.build(); + } + + public static ImmutableHint.Builder builder() { + return ImmutableHint.builder(); + } +} diff --git a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java index 088eb6c9a..582814b7a 100644 --- a/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java +++ b/core/src/main/java/io/substrait/relation/AbstractRelVisitor.java @@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION { return visitFallback(project); } + @Override + public OUTPUT visit(Expand expand) throws EXCEPTION { + return visitFallback(expand); + } + @Override public OUTPUT visit(Sort sort) throws EXCEPTION { return visitFallback(sort); diff --git a/core/src/main/java/io/substrait/relation/Expand.java b/core/src/main/java/io/substrait/relation/Expand.java new file mode 100644 index 000000000..9efff60af --- /dev/null +++ b/core/src/main/java/io/substrait/relation/Expand.java @@ -0,0 +1,62 @@ +package io.substrait.relation; + +import io.substrait.expression.Expression; +import io.substrait.type.Type; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.stream.Stream; +import org.immutables.value.Value; + +@Value.Enclosing +@Value.Immutable +public abstract class Expand extends SingleInputRel { + static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class); + + public abstract List getFields(); + + @Override + public Type.Struct deriveRecordType() { + Type.Struct initial = getInput().getRecordType(); + return TypeCreator.of(initial.nullable()) + .struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64))); + } + + @Override + public O accept(RelVisitor visitor) throws E { + return visitor.visit(this); + } + + public static ImmutableExpand.Builder builder() { + return ImmutableExpand.builder(); + } + + public interface ExpandField { + Type getType(); + } + + @Value.Immutable + public abstract static class ConsistentField implements ExpandField { + public abstract Expression getExpression(); + + public Type getType() { + return getExpression().getType(); + } + + public static ImmutableExpand.ConsistentField.Builder builder() { + return ImmutableExpand.ConsistentField.builder(); + } + } + + @Value.Immutable + public abstract static class SwitchingField implements ExpandField { + public abstract List getDuplicates(); + + public Type getType() { + return getDuplicates().get(0).getType(); + } + + public static ImmutableExpand.SwitchingField.Builder builder() { + return ImmutableExpand.SwitchingField.builder(); + } + } +} diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index f728988a6..c9bc60705 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -7,9 +7,11 @@ import io.substrait.extension.AdvancedExtension; import io.substrait.extension.ExtensionLookup; import io.substrait.extension.SimpleExtension; +import io.substrait.hint.Hint; import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; +import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; import io.substrait.proto.ExtensionSingleRel; @@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) { case PROJECT -> { return newProject(rel.getProject()); } + case EXPAND -> { + return newExpand(rel.getExpand()); + } case CROSS -> { return newCross(rel.getCross()); } @@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) { } protected NamedStruct newNamedStruct(ReadRel rel) { - var namedStruct = rel.getBaseSchema(); + return newNamedStruct(rel.getBaseSchema()); + } + + protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) { var struct = namedStruct.getStruct(); return ImmutableNamedStruct.builder() .names(namedStruct.getNamesList()) @@ -389,6 +397,38 @@ protected Project newProject(ProjectRel rel) { return builder.build(); } + protected Expand newExpand(ExpandRel rel) { + var input = from(rel.getInput()); + var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this); + var builder = + Expand.builder() + .input(input) + .fields( + rel.getFieldsList().stream() + .map( + expandField -> + switch (expandField.getFieldTypeCase()) { + case CONSISTENT_FIELD -> Expand.ConsistentField.builder() + .expression(converter.from(expandField.getConsistentField())) + .build(); + case SWITCHING_FIELD -> Expand.SwitchingField.builder() + .duplicates( + expandField.getSwitchingField().getDuplicatesList().stream() + .map(converter::from) + .collect(java.util.stream.Collectors.toList())) + .build(); + case FIELDTYPE_NOT_SET -> throw new UnsupportedOperationException( + "Expand fields not set"); + }) + .collect(java.util.stream.Collectors.toList())); + + builder + .commonExtension(optionalAdvancedExtension(rel.getCommon())) + .remap(optionalRelmap(rel.getCommon())) + .hint(optionalHint(rel.getCommon())); + return builder.build(); + } + protected Aggregate newAggregate(AggregateRel rel) { var input = from(rel.getInput()); var protoExprConverter = @@ -647,6 +687,16 @@ protected static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); } + protected static Optional optionalHint(io.substrait.proto.RelCommon relCommon) { + if (!relCommon.hasHint()) return Optional.empty(); + var hint = relCommon.getHint(); + var builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList()); + if (!hint.getAlias().isEmpty()) { + builder.alias(hint.getAlias()); + } + return Optional.of(builder.build()); + } + protected Optional optionalAdvancedExtension( io.substrait.proto.RelCommon relCommon) { return Optional.ofNullable( diff --git a/core/src/main/java/io/substrait/relation/Rel.java b/core/src/main/java/io/substrait/relation/Rel.java index 159feb3d5..1472e9f4b 100644 --- a/core/src/main/java/io/substrait/relation/Rel.java +++ b/core/src/main/java/io/substrait/relation/Rel.java @@ -1,6 +1,7 @@ package io.substrait.relation; import io.substrait.extension.AdvancedExtension; +import io.substrait.hint.Hint; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.util.List; @@ -21,6 +22,8 @@ public interface Rel { List getInputs(); + Optional getHint(); + @Value.Immutable public abstract static class Remap { public abstract List indices(); diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java index 2514eafe0..01b8f88bc 100644 --- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java @@ -201,6 +201,11 @@ public Optional visit(Project project) throws EXCEPTION { .build()); } + @Override + public Optional visit(Expand expand) throws EXCEPTION { + throw new UnsupportedOperationException(); + } + @Override public Optional visit(Sort sort) throws EXCEPTION { var input = sort.getInput().accept(this); diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 44dcc681c..ce0aac7a0 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -10,6 +10,7 @@ import io.substrait.proto.AggregateRel; import io.substrait.proto.ConsistentPartitionWindowRel; import io.substrait.proto.CrossRel; +import io.substrait.proto.ExpandRel; import io.substrait.proto.ExtensionLeafRel; import io.substrait.proto.ExtensionMultiRel; import io.substrait.proto.ExtensionSingleRel; @@ -50,7 +51,7 @@ public RelProtoConverter(ExtensionCollector functionCollector) { } private List toProto(Collection expressions) { - return expressions.stream().map(this::toProto).collect(java.util.stream.Collectors.toList()); + return expressions.stream().map(this::toProto).collect(Collectors.toList()); } private io.substrait.proto.Expression toProto(Expression expression) { @@ -74,7 +75,7 @@ private List toProtoS(Collection sorts) { .setExpr(toProto(s.expr())) .build(); }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); } private io.substrait.proto.Expression.FieldReference toProto(FieldReference fieldReference) { @@ -88,13 +89,9 @@ public Rel visit(Aggregate aggregate) throws RuntimeException { .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) .addAllGroupings( - aggregate.getGroupings().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())) + aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) .addAllMeasures( - aggregate.getMeasures().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())); + aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); aggregate.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setAggregate(builder).build(); @@ -113,14 +110,14 @@ private AggregateRel.Measure toProto(Aggregate.Measure measure) { .addAllArguments( IntStream.range(0, args.size()) .mapToObj(i -> args.get(i).accept(aggFuncDef, i, argVisitor)) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .addAllSorts(toProtoS(measure.getFunction().sort())) .setFunctionReference( functionCollector.getFunctionReference(measure.getFunction().declaration())) .addAllOptions( measure.getFunction().options().stream() .map(ExpressionProtoConverter::from) - .collect(java.util.stream.Collectors.toList())); + .collect(Collectors.toList())); var builder = AggregateRel.Measure.newBuilder().setMeasure(func); @@ -226,7 +223,7 @@ public Rel visit(LocalFiles localFiles) throws RuntimeException { .addAllItems( localFiles.getItems().stream() .map(FileOrFiles::toProto) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .build()) .setBaseSchema(localFiles.getInitialSchema().toProto(typeProtoConverter)); localFiles.getFilter().ifPresent(t -> builder.setFilter(toProto(t))); @@ -350,7 +347,7 @@ private List toProtoWindowRelFun var options = f.options().stream() .map(ExpressionProtoConverter::from) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); return ConsistentPartitionWindowRel.WindowRelFunction.newBuilder() .setInvocation(f.invocation().toProto()) @@ -364,7 +361,7 @@ private List toProtoWindowRelFun .setUpperBound(BoundConverter.convert(f.upperBound())) .build(); }) - .collect(java.util.stream.Collectors.toList()); + .collect(Collectors.toList()); } @Override @@ -374,14 +371,45 @@ public Rel visit(Project project) throws RuntimeException { .setCommon(common(project)) .setInput(toProto(project.getInput())) .addAllExpressions( - project.getExpressions().stream() - .map(this::toProto) - .collect(java.util.stream.Collectors.toList())); + project.getExpressions().stream().map(this::toProto).collect(Collectors.toList())); project.getExtension().ifPresent(ae -> builder.setAdvancedExtension(ae.toProto())); return Rel.newBuilder().setProject(builder).build(); } + @Override + public Rel visit(Expand expand) throws RuntimeException { + var builder = + ExpandRel.newBuilder().setCommon(common(expand)).setInput(toProto(expand.getInput())); + + expand + .getFields() + .forEach( + expandField -> { + if (expandField instanceof Expand.ConsistentField cf) { + builder.addFields( + ExpandRel.ExpandField.newBuilder() + .setConsistentField(toProto(cf.getExpression())) + .build()); + + } else if (expandField instanceof Expand.SwitchingField sf) { + builder.addFields( + ExpandRel.ExpandField.newBuilder() + .setSwitchingField( + ExpandRel.SwitchingField.newBuilder() + .addAllDuplicates( + sf.getDuplicates().stream() + .map(this::toProto) + .collect(Collectors.toList()))) + .build()); + } else { + throw new RuntimeException( + "Consistent or Switching fields must be set for the Expand relation."); + } + }); + return Rel.newBuilder().setExpand(builder).build(); + } + @Override public Rel visit(Sort sort) throws RuntimeException { var builder = @@ -417,7 +445,7 @@ public Rel visit(VirtualTableScan virtualTableScan) throws RuntimeException { virtualTableScan.getRows().stream() .map(this::toProto) .map(t -> t.getLiteral().getStruct()) - .collect(java.util.stream.Collectors.toList())) + .collect(Collectors.toList())) .build()) .setBaseSchema(virtualTableScan.getInitialSchema().toProto(typeProtoConverter)); @@ -469,6 +497,9 @@ private RelCommon common(io.substrait.relation.Rel rel) { } else { builder.setDirect(RelCommon.Direct.getDefaultInstance()); } + + rel.getHint().ifPresent(md -> builder.setHint(md.toProto())); + return builder.build(); } } diff --git a/core/src/main/java/io/substrait/relation/RelVisitor.java b/core/src/main/java/io/substrait/relation/RelVisitor.java index 28c5fe0c8..799e58bc1 100644 --- a/core/src/main/java/io/substrait/relation/RelVisitor.java +++ b/core/src/main/java/io/substrait/relation/RelVisitor.java @@ -23,6 +23,8 @@ public interface RelVisitor { OUTPUT visit(Project project) throws EXCEPTION; + OUTPUT visit(Expand expand) throws EXCEPTION; + OUTPUT visit(Sort sort) throws EXCEPTION; OUTPUT visit(Cross cross) throws EXCEPTION; diff --git a/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java new file mode 100644 index 000000000..ca7c1d2a1 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/ExpandRelRoundtripTest.java @@ -0,0 +1,65 @@ +package io.substrait.type.proto; + +import io.substrait.TestBase; +import io.substrait.hint.Hint; +import io.substrait.relation.Expand; +import io.substrait.relation.Rel; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; + +public class ExpandRelRoundtripTest extends TestBase { + final Rel input = + b.namedScan( + Stream.of("a_table").collect(Collectors.toList()), + Stream.of("column1", "column2").collect(Collectors.toList()), + Stream.of(R.I64, R.I64).collect(Collectors.toList())); + + private Expand.ExpandField getConsistentField(int index) { + return Expand.ConsistentField.builder().expression(b.fieldReference(input, index)).build(); + } + + private Expand.ExpandField getSwitchingField(List indexes) { + return Expand.SwitchingField.builder() + .addAllDuplicates( + indexes.stream() + .map(index -> b.fieldReference(input, index)) + .collect(Collectors.toList())) + .build(); + } + + @Test + void expandConsistent() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), input)) + .hint( + Hint.builder() + .alias("alias1") + .addAllOutputNames(Arrays.asList("name1", "name2")) + .build()) + .fields( + Stream.of(getConsistentField(0), getConsistentField(1)) + .collect(Collectors.toList())) + .build(); + verifyRoundTrip(rel); + } + + @Test + void expandSwitching() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), input)) + .hint(Hint.builder().addAllOutputNames(Arrays.asList("name1", "name2")).build()) + .fields( + Stream.of( + getSwitchingField(Arrays.asList(0, 1)), + getSwitchingField(Arrays.asList(1, 0))) + .collect(Collectors.toList())) + .build(); + verifyRoundTrip(rel); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java index 06ed71dec..eb2c05b29 100644 --- a/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java +++ b/core/src/test/java/io/substrait/type/proto/ExtensionRoundtripTest.java @@ -7,6 +7,7 @@ import io.substrait.extension.AdvancedExtension; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; +import io.substrait.relation.Expand; import io.substrait.relation.ExtensionLeaf; import io.substrait.relation.ExtensionMulti; import io.substrait.relation.ExtensionSingle; @@ -233,6 +234,16 @@ void project() { verifyRoundTrip(rel); } + @Test + void expand() { + Rel rel = + Expand.builder() + .from(b.expand(__ -> Collections.emptyList(), commonTable)) + .commonExtension(commonExtension) + .build(); + verifyRoundTrip(rel); + } + @Test void set() { Rel rel = diff --git a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala index 9f4f5c9f8..0ba749b9e 100644 --- a/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala +++ b/spark/src/main/scala/io/substrait/debug/RelToVerboseString.scala @@ -131,6 +131,15 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] { }) } + override def visit(expand: Expand): String = { + withBuilder(expand, 8)( + builder => { + builder + .append("fields=") + .append(expand.getFields) + }) + } + override def visit(aggregate: Aggregate): String = { withBuilder(aggregate, 10)( builder => { diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index 08326454e..4bcae8dd9 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -67,6 +67,7 @@ class FunctionMappings { s[Count]("count"), s[Min]("min"), s[Max]("max"), + s[First]("any_value"), s[HyperLogLogPlusPlus]("approx_count_distinct") ) diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index a4ee9aaee..e928689fa 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -18,11 +18,9 @@ package io.substrait.spark.expression import io.substrait.spark.{DefaultExpressionVisitor, HasOutputStack, ToSubstraitType} import io.substrait.spark.logical.ToLogicalPlan - import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Expression, In, Literal, NamedExpression, ScalarSubquery} -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{Decimal, NullType} import org.apache.spark.unsafe.types.UTF8String - import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} @@ -77,6 +75,10 @@ class ToSparkExpression( Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } + override def visit(expr: SExpression.NullLiteral): Expression = { + Literal(null, ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.Cast): Expression = { val childExp = expr.input().accept(this) Cast(childExp, ToSubstraitType.convert(expr.getType)) diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala index 908a8aa0d..45b6c2205 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala @@ -18,7 +18,6 @@ package io.substrait.spark.logical import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ - import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions._ @@ -29,13 +28,13 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat -import org.apache.spark.sql.types.{DataTypes, IntegerType, StructType} - +import org.apache.spark.sql.types.{DataTypes, IntegerType, StructField, StructType} import io.substrait.`type`.{StringTypeVisitor, Type} import io.substrait.{expression => exp} import io.substrait.expression.{Expression => SExpression} import io.substrait.plan.Plan import io.substrait.relation +import io.substrait.relation.Expand.{ConsistentField, SwitchingField} import io.substrait.relation.LocalFiles import org.apache.hadoop.fs.Path @@ -82,11 +81,15 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] ) throw new IllegalArgumentException(msg) }) + + val filter = Option(measure.getPreMeasureFilter.orElse(null)) + .map(_.accept(expressionConverter)) + AggregateExpression( aggregateFunction, ToAggregateFunction.toSpark(function.aggregationPhase()), ToAggregateFunction.toSpark(function.invocation()), - None + filter ) } @@ -193,6 +196,27 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan] } } + override def visit(expand: relation.Expand): LogicalPlan = { + val child = expand.getInput.accept(this) + val names = expand.getHint.get().getOutputNames.asScala + + withChild(child) { + val projections = expand.getFields.asScala + .map { + case sf: SwitchingField => sf.getDuplicates.asScala + .map(expr => expr.accept(expressionConverter)) + .map(toNamedExpression) + case _: ConsistentField => throw new UnsupportedOperationException("ConsistentField not currently supported") + } + + val output = projections.head.zip(names) + .map { case (t, name) => StructField(name, t.dataType, t.nullable) } + .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + + Expand(projections, output, child) + } + } + override def visit(filter: relation.Filter): LogicalPlan = { val child = filter.getInput.accept(this) withChild(child) { diff --git a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala index 4085860f4..08a06c2e4 100644 --- a/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala +++ b/spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala @@ -18,7 +18,6 @@ package io.substrait.spark.logical import io.substrait.spark.{SparkExtension, ToSubstraitType} import io.substrait.spark.expression._ - import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.expressions._ @@ -29,18 +28,17 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRela import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation} import org.apache.spark.sql.types.StructType import ToSubstraitType.toNamedStruct - import io.substrait.{proto, relation} import io.substrait.debug.TreePrinter -import io.substrait.expression.{Expression => SExpression, ExpressionCreator} +import io.substrait.expression.{ExpressionCreator, Expression => SExpression} import io.substrait.extension.ExtensionCollector +import io.substrait.hint.Hint import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan} import io.substrait.relation.RelProtoConverter import io.substrait.relation.files.{FileFormat, ImmutableFileOrFiles} import io.substrait.relation.files.FileOrFiles.PathType -import java.util.Collections - +import java.util.{Collections, Optional} import scala.collection.JavaConverters.asJavaIterableConverter import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -72,7 +70,8 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { val substraitExps = expression.aggregateFunction.children.map(toExpression(output)) val invocation = SparkExtension.toAggregateFunction.apply(expression, substraitExps) - relation.Aggregate.Measure.builder.function(invocation).build() + val filter = expression.filter map toExpression(output) + relation.Aggregate.Measure.builder.function(invocation).preMeasureFilter(Optional.ofNullable(filter.orNull)).build() } private def collectAggregates( @@ -236,6 +235,23 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging { .build() } + override def visitExpand(p: Expand): relation.Rel = { + val fields = p.projections.map(proj => { + relation.Expand.SwitchingField.builder.duplicates( + proj.map(toExpression(p.child.output)).asJava + ).build() + }) + + val names = p.output.map(_.name) + + relation.Expand.builder + .remap(relation.Rel.Remap.offset(p.child.output.size, names.size)) + .fields(fields.asJava) + .hint(Hint.builder.addAllOutputNames(names.asJava).build()) + .input(visit(p.child)) + .build() + } + private def toSortField(output: Seq[Attribute] = Nil)(order: SortOrder): SExpression.SortField = { val direction = (order.direction, order.nullOrdering) match { case (Ascending, NullsFirst) => SExpression.SortDirection.ASC_NULLS_FIRST diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index 7cfb3cd2d..fd35b551a 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,7 +32,7 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // "q9" failed in spark 3.3 - val successfulSQL: Set[String] = Set("q41", "q62", "q93", "q96", "q99") + val successfulSQL: Set[String] = Set("q4", "q7", "q18", "q22", "q26", "q28", "q29", "q37", "q41", "q48", "q50", "q62", "q69", "q82", "q85", "q88", "q90", "q93", "q96", "q97", "q99") tpcdsQueries.foreach { q => diff --git a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala index df5ca4f81..224ac2e8d 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCHPlan.scala @@ -16,6 +16,7 @@ */ package io.substrait.spark +import io.substrait.spark.logical.{ToLogicalPlan, ToSubstraitRel} import org.apache.spark.sql.TPCHBase class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { @@ -101,7 +102,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { " from lineitem group by l_partkey + l_orderkey") } - ignore("avg(distinct)") { + test("avg(distinct)") { assertSqlSubstraitRelRoundTrip( "select l_partkey, sum(l_tax), sum(distinct l_tax)," + " avg(l_discount), avg(distinct l_discount) from lineitem group by l_partkey") @@ -112,7 +113,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { "select l_partkey, sum(l_extendedprice * (1.0-l_discount)) from lineitem group by l_partkey") } - ignore("simpleTestAggFilter") { + test("simpleTestAggFilter") { assertSqlSubstraitRelRoundTrip( "select sum(l_tax) filter(WHERE l_orderkey > l_partkey) from lineitem") // cast is added to avoid the difference by implicit cast @@ -149,7 +150,7 @@ class TPCHPlan extends TPCHBase with SubstraitPlanTestBase { " where l_shipdate < date '1998-01-01' ") } - ignore("simpleTestGroupingSets [has Expand]") { + test("simpleTestGroupingSets [has Expand]") { assertSqlSubstraitRelRoundTrip( "select sum(l_discount) from lineitem group by grouping sets " + "((l_orderkey, L_COMMITDATE), l_shipdate)")