Skip to content

Commit

Permalink
[Substrait] Implement CallOp for ScalarFunction message.
Browse files Browse the repository at this point in the history
This allows using the externally declared functions from #853 as scalar
function calls (which have the semantics of a typical function call).
The current design anticipates that aggregate and window functions can
be modelled with the same op, but a future PR will need to show whether
and how that is possible.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Jul 22, 2024
1 parent 662a78d commit f512dab
Show file tree
Hide file tree
Showing 7 changed files with 319 additions and 5 deletions.
37 changes: 37 additions & 0 deletions include/structured/Dialect/Substrait/IR/SubstraitOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,43 @@ def Substrait_LiteralOp : Substrait_ExpressionOp<"literal", [
let assemblyFormat = "$value attr-dict";
}

// Expression op that calls a function referenced by symbol. The callee symbol
// is checked by `CallOp::verifySymbolUses`, which requires it to resolve to an
// `extension_function` op.
def Substrait_CallOp : Substrait_ExpressionOp<"call", [
// Declares the `SymbolUserOpInterface` methods on the op class; their
// definitions live in the dialect's C++ sources.
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
]> {
let summary = "Function call expression";
let description = [{
Represents a `ScalarFunction` message (or, in the future, other `*Function`
messages) together with all messages it contains and the `Expression`
message it is contained in.

Currently, the specification of the function, which is in an external YAML
file, is not taken into account, for example, to verify whether a matching
overload exists or to verify/compute the result type.

Example:

```mlir
extension_uri @extension at "http://some.url/with/extensions.yml"
extension_function @function at @extension["somefunc"]
relation {
// ...
%1 = call @function(%0) : (tuple<si32>) -> si1
// ...
}
```
}];
// TODO(ingomueller): Add `FunctionOptions`.
// TODO(ingomueller): Add support for `enum` and `type` argument types.
// `callee` is a flat symbol reference naming the called function; `args` is
// the variadic list of operands passed to it.
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<Substrait_FieldType>:$args
);
// The op produces exactly one result. Its type is given explicitly in the
// assembly rather than derived from the callee (see the description above).
let results = (outs Substrait_FieldType:$result);
// Custom syntax: `@callee(args) attr-dict : (arg-types) -> result-type`.
let assemblyFormat = [{
$callee `(` $args `)` attr-dict `:` `(` type($args) `)` `->` type($result)
}];
}

//===----------------------------------------------------------------------===//
// Relations
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions lib/Dialect/Substrait/IR/Substrait.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ void SubstraitDialect::initialize() {
namespace mlir {
namespace substrait {

/// Implement `SymbolUserOpInterface`.
///
/// Resolves the `callee` symbol starting from this op and succeeds only if it
/// names an `ExtensionFunctionOp`. Otherwise, an op error mentioning the
/// offending symbol reference is emitted.
::mlir::LogicalResult
CallOp::verifySymbolUses(SymbolTableCollection &symbolTables) {
  FlatSymbolRefAttr calleeRef = getCalleeAttr();

  // The typed look-up yields a null op both if the symbol does not exist and
  // if it resolves to an op of a different kind.
  auto extensionFunction =
      symbolTables.lookupNearestSymbolFrom<ExtensionFunctionOp>(*this,
                                                                calleeRef);
  if (extensionFunction)
    return success();

  return emitOpError() << "refers to " << calleeRef
                       << ", which is not a valid 'extension_function' op";
}

LogicalResult
CrossOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
ValueRange operands, DictionaryAttr attributes,
Expand Down
50 changes: 49 additions & 1 deletion lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class SubstraitExporter {
#define DECLARE_EXPORT_FUNC(OP_TYPE, MESSAGE_TYPE) \
FailureOr<std::unique_ptr<MESSAGE_TYPE>> exportOperation(OP_TYPE op);

DECLARE_EXPORT_FUNC(CallOp, Expression)
DECLARE_EXPORT_FUNC(CrossOp, Rel)
DECLARE_EXPORT_FUNC(EmitOp, Rel)
DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression)
Expand Down Expand Up @@ -133,6 +134,53 @@ SubstraitExporter::exportType(Location loc, mlir::Type mlirType) {
return emitError(loc) << "could not export unsupported type " << mlirType;
}

/// Exports a `CallOp` as an `Expression` message holding a `ScalarFunction`.
///
/// The callee symbol is translated into the function anchor via
/// `lookupAnchor`, each operand is exported recursively as the `value` of one
/// `FunctionArgument`, and the result type of the op becomes the
/// `output_type`. Returns failure if an operand is not defined by an op
/// implementing `ExpressionOpInterface`, or if exporting an operand or the
/// result type fails.
FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(CallOp op) {
  using ScalarFunction = Expression::ScalarFunction;

  Location loc = op.getLoc();

  // Build `ScalarFunction` message.
  // TODO(ingomueller): Support other `*Function` messages.
  auto scalarFunction = std::make_unique<ScalarFunction>();
  int32_t anchor = lookupAnchor(op, op.getCallee());
  scalarFunction->set_function_reference(anchor);

  // Build messages for arguments.
  for (auto [i, operand] : llvm::enumerate(op->getOperands())) {
    // Build `Expression` message for operand. Block arguments have no
    // defining op, hence the `_if_present` cast variant.
    auto definingOp = llvm::dyn_cast_if_present<ExpressionOpInterface>(
        operand.getDefiningOp());
    // NOTE(review): the check is for `ExpressionOpInterface` but the message
    // says "relation op"; the wording is left as is because existing tests
    // may match on it.
    if (!definingOp)
      return op->emitOpError()
             << "with operand " << i
             << " that was not produced by Substrait relation op";

    FailureOr<std::unique_ptr<Expression>> expression =
        exportOperation(definingOp);
    if (failed(expression))
      return failure();

    // Populate the new `FunctionArgument` in place. Filling a temporary and
    // copy-assigning it into the repeated field would deep-copy the whole
    // expression subtree once per nesting level of calls.
    scalarFunction->add_arguments()->set_allocated_value(
        expression->release());
  }

  // Build message for `output_type`.
  FailureOr<std::unique_ptr<proto::Type>> outputType =
      exportType(loc, op.getResult().getType());
  if (failed(outputType))
    return failure();
  scalarFunction->set_allocated_output_type(outputType->release());

  // Build `Expression` message, which takes ownership of the function message.
  auto expression = std::make_unique<Expression>();
  expression->set_allocated_scalar_function(scalarFunction.release());

  return expression;
}

FailureOr<std::unique_ptr<Rel>> SubstraitExporter::exportOperation(CrossOp op) {
// Build `RelCommon` message.
auto relCommon = std::make_unique<RelCommon>();
Expand Down Expand Up @@ -214,7 +262,7 @@ FailureOr<std::unique_ptr<Expression>>
SubstraitExporter::exportOperation(ExpressionOpInterface op) {
return llvm::TypeSwitch<Operation *, FailureOr<std::unique_ptr<Expression>>>(
op)
.Case<FieldReferenceOp, LiteralOp>(
.Case<CallOp, FieldReferenceOp, LiteralOp>(
[&](auto op) { return exportOperation(op); })
.Default(
[](auto op) { return op->emitOpError("not supported for export"); });
Expand Down
53 changes: 49 additions & 4 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ DECLARE_IMPORT_FUNC(Plan, Plan, PlanOp)
DECLARE_IMPORT_FUNC(PlanRel, PlanRel, PlanRelOp)
DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)
DECLARE_IMPORT_FUNC(ScalarFunction, Expression::ScalarFunction, CallOp)

// Helpers to build symbol names from anchors deterministically. This allows
// creating symbol references from anchors without a look-up structure. Also,
Expand Down Expand Up @@ -145,12 +146,12 @@ importExpression(ImplicitLocOpBuilder builder, const Expression &message) {

Expression::RexTypeCase rex_type = message.rex_type_case();
switch (rex_type) {
case Expression::RexTypeCase::kLiteral: {
case Expression::kLiteral:
return importLiteral(builder, message.literal());
}
case Expression::RexTypeCase::kSelection: {
case Expression::kSelection:
return importFieldReference(builder, message.selection());
}
case Expression::kScalarFunction:
return importScalarFunction(builder, message.scalar_function());
default: {
const pb::FieldDescriptor *desc =
Expression::GetDescriptor()->FindFieldByNumber(rex_type);
Expand Down Expand Up @@ -519,6 +520,50 @@ static mlir::FailureOr<RelOpInterface> importRel(ImplicitLocOpBuilder builder,
return {emitOp};
}

/// Imports a `ScalarFunction` message as a `CallOp`.
///
/// The `output_type` becomes the result type of the op, each argument is
/// imported as an expression whose first result is used as an operand, and
/// the `function_reference` anchor is turned into the callee symbol name via
/// `buildFuncSymName`. Only arguments of the `value` kind are supported; any
/// other argument (including one with no kind set) is reported as an error.
static mlir::FailureOr<CallOp>
importScalarFunction(ImplicitLocOpBuilder builder,
                     const Expression::ScalarFunction &message) {
  MLIRContext *context = builder.getContext();
  Location loc = UnknownLoc::get(context);

  // Import `output_type`. Bind by reference to avoid copying the message.
  const proto::Type &outputType = message.output_type();
  FailureOr<mlir::Type> mlirOutputType = importType(context, outputType);
  if (failed(mlirOutputType))
    return failure();

  // Import `arguments`.
  SmallVector<Value> operands;
  for (const FunctionArgument &arg : message.arguments()) {
    // Error out on unsupported cases.
    // TODO(ingomueller): Support other function argument types.
    if (!arg.has_value()) {
      const pb::FieldDescriptor *desc =
          FunctionArgument::GetDescriptor()->FindFieldByNumber(
              arg.arg_type_case());
      // If no member of the `arg_type` oneof is set, the case value does not
      // correspond to any field and the descriptor look-up yields null.
      if (!desc)
        return emitError(loc) << "function argument has no arg type set";
      return emitError(loc) << Twine("unsupported arg type: ") + desc->name();
    }

    // Handle `value` case.
    const Expression &value = arg.value();
    FailureOr<ExpressionOpInterface> expression =
        importExpression(builder, value);
    if (failed(expression))
      return failure();
    operands.push_back((*expression)->getResult(0));
  }

  // Import `function_reference` field.
  int32_t anchor = message.function_reference();
  std::string calleeSymName = buildFuncSymName(anchor);

  // Create op.
  auto callOp =
      builder.create<CallOp>(mlirOutputType.value(), calleeSymName, operands);

  return {callOp};
}

} // namespace

namespace mlir {
Expand Down
27 changes: 27 additions & 0 deletions test/Dialect/Substrait/call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: structured-opt -split-input-file %s \
// RUN: | FileCheck %s

// Runs the plan below through `structured-opt` and verifies that the `call`
// op inside the `filter` region shows up in the output with the same callee,
// operand, and types as in the input.

// CHECK-LABEL: substrait.plan
// CHECK: relation
// CHECK: named_table
// CHECK-NEXT: filter
// CHECK-NEXT: (%[[ARG0:.*]]: tuple<si32>)
// CHECK-NEXT: %[[V0:.*]] = field_reference %[[ARG0]]
// CHECK-NEXT: %[[V1:.*]] = call @function(%[[V0]]) : (si32) -> si1
// CHECK-NEXT: yield
// CHECK-NEXT: }

substrait.plan version 0 : 42 : 1 {
extension_uri @extension at "http://some.url/with/extensions.yml"
extension_function @function at @extension["somefunc"]
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = filter %0 : tuple<si32> {
^bb0(%arg : tuple<si32>):
%2 = field_reference %arg[[0]] : tuple<si32>
%3 = call @function(%2) : (si32) -> si1
yield %3 : si1
}
yield %1 : tuple<si32>
}
}
51 changes: 51 additions & 0 deletions test/Target/SubstraitPB/Export/call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: structured-translate -substrait-to-protobuf --split-input-file %s \
// RUN: | FileCheck %s

// RUN: structured-translate -substrait-to-protobuf %s \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | structured-translate -protobuf-to-substrait \
// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \
// RUN: | structured-translate -substrait-to-protobuf \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | FileCheck %s

// The first pipeline exports the plan below to protobuf directly; the second
// one exports, imports, and exports again. Both outputs are matched against
// the same expectations, so the round trip must not change the export.
//
// Two functions are declared and the call refers to the second one (`@f2`);
// the expectations require `function_reference: 1` on the scalar function,
// matching the `function_anchor: 1` of the second declaration.

// CHECK: extension_uris {
// CHECK-NEXT: uri: "http://some.url/with/extensions.yml"
// CHECK-NEXT: }
// CHECK-NEXT: extensions {
// CHECK-NEXT: extension_function {
// CHECK-NEXT: name: "somefunc"
// CHECK-NEXT: }
// CHECK: extensions {
// CHECK-NEXT: extension_function {
// CHECK-NEXT: function_anchor: 1
// CHECK-NEXT: name: "somefunc"
// CHECK: relations {
// CHECK-NEXT: rel {
// CHECK-NEXT: filter {
// CHECK-NOT: condition
// CHECK: condition {
// CHECK-NEXT: scalar_function {
// CHECK-NEXT: function_reference: 1
// CHECK-NEXT: output_type {
// CHECK-NEXT: bool {
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK: arguments {
// CHECK-NEXT: value {
// CHECK-NEXT: selection {

substrait.plan version 0 : 42 : 1 {
extension_uri @extension at "http://some.url/with/extensions.yml"
extension_function @f1 at @extension["somefunc"]
extension_function @f2 at @extension["somefunc"]
relation {
%0 = named_table @t1 as ["a"] : tuple<si32>
%1 = filter %0 : tuple<si32> {
^bb0(%arg : tuple<si32>):
%2 = field_reference %arg[[0]] : tuple<si32>
%3 = call @f2(%2) : (si32) -> si1
yield %3 : si1
}
yield %1 : tuple<si32>
}
}
96 changes: 96 additions & 0 deletions test/Target/SubstraitPB/Import/call.textpb
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# RUN: structured-translate -protobuf-to-substrait %s \
# RUN: --split-input-file="# ""-----" \
# RUN: | FileCheck %s

# RUN: structured-translate -protobuf-to-substrait %s \
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
# RUN: | structured-translate -substrait-to-protobuf \
# RUN: --split-input-file --output-split-marker="# ""-----" \
# RUN: | structured-translate -protobuf-to-substrait \
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
# RUN: | FileCheck %s

# CHECK-LABEL: substrait.plan
# CHECK-NEXT: extension_uri @[[URI:.*]] at "http://some.url/with/extensions.yml"
# CHECK-NEXT: extension_function @[[F1:.*]] at @[[URI]]["somefunc"]
# CHECK-NEXT: extension_function @[[F2:.*]] at @[[URI]]["somefunc"]
# CHECK-NEXT: relation
# CHECK-NEXT: named_table
# CHECK-NEXT: filter
# CHECK-NEXT: (%[[V0:.*]]: tuple<si32>):
# CHECK-NEXT: %[[V1:.*]] = field_reference %[[V0]][{{\[}}0]] : tuple<si32>
# CHECK-NEXT: %[[V2:.*]] = call @[[F2]](%[[V1]]) : (si32) -> si1
# CHECK-NEXT: yield %[[V2]] : si1

extension_uris {
uri: "http://some.url/with/extensions.yml"
}
extensions {
extension_function {
name: "somefunc"
}
}
extensions {
extension_function {
function_anchor: 1
name: "somefunc"
}
}
relations {
rel {
filter {
common {
direct {
}
}
input {
read {
common {
direct {
}
}
base_schema {
names: "a"
struct {
types {
i32 {
nullability: NULLABILITY_REQUIRED
}
}
nullability: NULLABILITY_REQUIRED
}
}
named_table {
names: "t1"
}
}
}
condition {
scalar_function {
function_reference: 1
output_type {
bool {
nullability: NULLABILITY_REQUIRED
}
}
arguments {
value {
selection {
direct_reference {
struct_field {
}
}
root_reference {
}
}
}
}
}
}
}
}
}
version {
minor_number: 42
patch_number: 1
}

0 comments on commit f512dab

Please sign in to comment.