diff --git a/include/structured/Dialect/Substrait/IR/SubstraitOps.td b/include/structured/Dialect/Substrait/IR/SubstraitOps.td index 9729a25cd6b5..35968e00996d 100644 --- a/include/structured/Dialect/Substrait/IR/SubstraitOps.td +++ b/include/structured/Dialect/Substrait/IR/SubstraitOps.td @@ -17,6 +17,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" class Substrait_Op traits = []> : Op { @@ -49,6 +50,86 @@ def StringArrayAttr : let storageType = [{ ::mlir::ArrayAttr }]; } +//===----------------------------------------------------------------------===// +// Extensions +//===----------------------------------------------------------------------===// +// The definitions in this section are related to the extension messages. +// See https://substrait.io/serialization/binary_serialization/ and +// https://github.com/substrait-io/substrait/blob/main/proto/substrait/extensions/extensions.proto. +//===----------------------------------------------------------------------===// + +def Substrait_ExtensionUriOp : Substrait_Op<"extension_uri", [ + Symbol + ]> { + let summary = "Declares a simple extension URI"; + let description = [{ + This op represents the `SimpleExtensionURI` message type of Substrait. It is + a `Symbol` op, so it can be looked up in the symbol table of the plan it is + contained in. + + Example code: + + ```mlir + substrait.plan version 0 : 42 : 1 { + extension_uri @uri at "http://some.url/with/extensions.yml" + extension_function @function at @uri["func1"] + // ... 
+ } + ``` + }]; + let arguments = (ins + SymbolNameAttr:$sym_name, // corresponds to `anchor` + StrAttr:$uri + ); + let assemblyFormat = "$sym_name `at` $uri attr-dict"; +} + +class Substrait_ExtensionOp traits = []> : + Substrait_Op<"extension_" # mnemonic, traits # [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods + ]> { + let description = [{ + This op represents the `SimpleExtensionDeclaration` message type of + Substrait along with the `Extension}] + # snakeCaseToCamelCase.ret # + [{` message type in the `mapping_type` case. It is a `Symbol` op, so it + can be looked up in the symbol table of the plan it is contained in. + Conversely, its symbol reference `uri` must refer to an extension URI op + in the nearest symbol table. + }]; + let arguments = (ins + SymbolNameAttr:$sym_name, // corresponds to `anchor` + FlatSymbolRefAttr:$uri, + StrAttr:$name + ); + let assemblyFormat = "$sym_name `at` $uri `[` $name `]` attr-dict"; + let extraClassDefinition = [{ + /// Implement `SymbolUserOpInterface`.
+ ::mlir::LogicalResult $cppClass::verifySymbolUses( + mlir::SymbolTableCollection &symbolTables) { + if (!symbolTables.lookupNearestSymbolFrom(*this, + getUriAttr())) + return emitOpError() << "refers to " << getUriAttr() + << ", which is not a valid 'uri' op"; + return success(); + } + }]; +} + +def Substrait_ExtensionFunctionOp : Substrait_ExtensionOp<"function"> { + let summary = "Declares a simple extension function"; +} + +def Substrait_ExtensionTypeOp : Substrait_ExtensionOp<"type"> { + let summary = "Declares a simple extension type"; +} + +def Substrait_ExtensionTypeVariationOp : + Substrait_ExtensionOp<"type_variation"> { + let summary = "Declares a simple extension type variation"; +} + //===----------------------------------------------------------------------===// // Plan //===----------------------------------------------------------------------===// @@ -58,20 +139,23 @@ def StringArrayAttr : //===----------------------------------------------------------------------===// def PlanBodyOp : AnyOf<[ - IsOp<"::mlir::substrait::PlanRelOp"> + IsOp<"::mlir::substrait::PlanRelOp">, + IsOp<"::mlir::substrait::ExtensionUriOp">, + IsOp<"::mlir::substrait::ExtensionFunctionOp">, + IsOp<"::mlir::substrait::ExtensionTypeOp">, + IsOp<"::mlir::substrait::ExtensionTypeVariationOp">, ]>; def Substrait_PlanOp : Substrait_Op<"plan", [ DeclareOpInterfaceMethods, - NoTerminator, NoRegionArguments, SingleBlock + NoTerminator, NoRegionArguments, SingleBlock, SymbolTable ]> { let summary = "Represents a Substrait plan"; let description = [{ This op represents the `Plan` message type of Substrait. It carries the version information inline as attributes, so it also subsumes the `Version` - message type. The body of the op consists of the `relation`s and (once - implemented) the extensions and types as well as their URLs defined in the - plan. + message type. The body of the op consists of the `relation`s and the + function and type extensions defined in the plan. 
}]; let arguments = (ins UI32Attr:$major_number, diff --git a/lib/Target/SubstraitPB/Export.cpp b/lib/Target/SubstraitPB/Export.cpp index 475894b0902d..b6d25ee39506 100644 --- a/lib/Target/SubstraitPB/Export.cpp +++ b/lib/Target/SubstraitPB/Export.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -29,33 +30,63 @@ namespace pb = google::protobuf; namespace { -// Forward declaration for the export function of the given operation type. +/// Main structure to drive export from the dialect to protobuf. This class +/// holds the visitor functions for the various ops etc. from the dialect as +/// well as state and utilities around the state that is built up during export. +class SubstraitExporter { +public: +// Declaration for the export function of the given operation type. // // We need one such function for most op type that we want to export. The -// forward declarations are necessary such all export functions are available -// for the definitions indepedently of the order of these definitions. The // `MESSAGE_TYPE` argument corresponds to the protobuf message type returned // by the function. 
#define DECLARE_EXPORT_FUNC(OP_TYPE, MESSAGE_TYPE) \ - static FailureOr> exportOperation(OP_TYPE op); - -DECLARE_EXPORT_FUNC(CrossOp, Rel) -DECLARE_EXPORT_FUNC(EmitOp, Rel) -DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression) -DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression) -DECLARE_EXPORT_FUNC(FilterOp, Rel) -DECLARE_EXPORT_FUNC(LiteralOp, Expression) -DECLARE_EXPORT_FUNC(ModuleOp, Plan) -DECLARE_EXPORT_FUNC(NamedTableOp, Rel) -DECLARE_EXPORT_FUNC(PlanOp, Plan) -DECLARE_EXPORT_FUNC(ProjectOp, Rel) -DECLARE_EXPORT_FUNC(RelOpInterface, Rel) - -FailureOr> exportOperation(Operation *op); -FailureOr> exportOperation(RelOpInterface op); - -FailureOr> exportType(Location loc, - mlir::Type mlirType) { + FailureOr> exportOperation(OP_TYPE op); + + DECLARE_EXPORT_FUNC(CrossOp, Rel) + DECLARE_EXPORT_FUNC(EmitOp, Rel) + DECLARE_EXPORT_FUNC(ExpressionOpInterface, Expression) + DECLARE_EXPORT_FUNC(FieldReferenceOp, Expression) + DECLARE_EXPORT_FUNC(FilterOp, Rel) + DECLARE_EXPORT_FUNC(LiteralOp, Expression) + DECLARE_EXPORT_FUNC(ModuleOp, Plan) + DECLARE_EXPORT_FUNC(NamedTableOp, Rel) + DECLARE_EXPORT_FUNC(PlanOp, Plan) + DECLARE_EXPORT_FUNC(ProjectOp, Rel) + DECLARE_EXPORT_FUNC(RelOpInterface, Rel) + + FailureOr> exportOperation(Operation *op); + FailureOr> exportType(Location loc, + mlir::Type mlirType); + +private: + /// Returns the nearest symbol table to op. The symbol table is cached in + /// `this` such that repeated calls that request the same symbol do not + /// rebuild that table. + SymbolTable &getSymbolTableFor(Operation *op) { + Operation *nearestSymbolTableOp = SymbolTable::getNearestSymbolTable(op); + if (!symbolTable || symbolTable->getOp() != nearestSymbolTableOp) { + symbolTable = std::make_unique(nearestSymbolTableOp); + } + return *symbolTable; + } + + /// Looks up the anchor value corresponding to the given symbol name in the + /// context of the given op. The op is used to determine which symbol table + /// was used to assign anchors. 
+ template + int32_t lookupAnchor(Operation *contextOp, const SymNameType &symName) { + SymbolTable &symbolTable = getSymbolTableFor(contextOp); + Operation *calleeOp = symbolTable.lookup(symName); + return anchorsByOp.at(calleeOp); + } + + DenseMap anchorsByOp{}; // Maps ops to anchors. + std::unique_ptr symbolTable; // Symbol table cache. +}; + +FailureOr> +SubstraitExporter::exportType(Location loc, mlir::Type mlirType) { MLIRContext *context = mlirType.getContext(); // Handle SI1. @@ -103,7 +134,7 @@ FailureOr> exportType(Location loc, return emitError(loc) << "could not export unsupported type " << mlirType; } -FailureOr> exportOperation(CrossOp op) { +FailureOr> SubstraitExporter::exportOperation(CrossOp op) { // Build `RelCommon` message. auto relCommon = std::make_unique(); auto direct = std::make_unique(); @@ -144,7 +175,7 @@ FailureOr> exportOperation(CrossOp op) { return rel; } -FailureOr> exportOperation(EmitOp op) { +FailureOr> SubstraitExporter::exportOperation(EmitOp op) { auto inputOp = dyn_cast_if_present(op.getInput().getDefiningOp()); if (!inputOp) @@ -181,7 +212,7 @@ FailureOr> exportOperation(EmitOp op) { } FailureOr> -exportOperation(ExpressionOpInterface op) { +SubstraitExporter::exportOperation(ExpressionOpInterface op) { return llvm::TypeSwitch>>( op) .Case( @@ -190,7 +221,8 @@ exportOperation(ExpressionOpInterface op) { [](auto op) { return op->emitOpError("not supported for export"); }); } -FailureOr> exportOperation(FieldReferenceOp op) { +FailureOr> +SubstraitExporter::exportOperation(FieldReferenceOp op) { using FieldReference = Expression::FieldReference; using ReferenceSegment = Expression::ReferenceSegment; @@ -248,7 +280,8 @@ FailureOr> exportOperation(FieldReferenceOp op) { return expression; } -FailureOr> exportOperation(FilterOp op) { +FailureOr> +SubstraitExporter::exportOperation(FilterOp op) { // Build `RelCommon` message.
auto relCommon = std::make_unique(); auto direct = std::make_unique(); @@ -293,7 +326,8 @@ FailureOr> exportOperation(FilterOp op) { return rel; } -FailureOr> exportOperation(LiteralOp op) { +FailureOr> +SubstraitExporter::exportOperation(LiteralOp op) { // Build `Literal` message depending on type. auto value = llvm::cast(op.getValue()); mlir::Type literalType = value.getType(); @@ -324,7 +358,8 @@ FailureOr> exportOperation(LiteralOp op) { return expression; } -FailureOr> exportOperation(ModuleOp op) { +FailureOr> +SubstraitExporter::exportOperation(ModuleOp op) { if (!op->getAttrs().empty()) { op->emitOpError("has attributes"); return failure(); @@ -343,7 +378,8 @@ FailureOr> exportOperation(ModuleOp op) { return failure(); } -FailureOr> exportOperation(NamedTableOp op) { +FailureOr> +SubstraitExporter::exportOperation(NamedTableOp op) { Location loc = op.getLoc(); // Build `NamedTable` message. @@ -390,7 +426,100 @@ FailureOr> exportOperation(NamedTableOp op) { return rel; } -FailureOr> exportOperation(PlanOp op) { +/// Helper for creating unique anchors from symbol names. While in MLIR, symbol +/// names and their references are strings, in Substrait they are integer +/// numbers. In order to preserve the anchor values through an import/export +/// process (without modifications), the symbol names generated during import +/// have the form `.` such that the `anchor` value can be +/// recovered. During assigning of anchors, the uniquer fills a map mapping the +/// symbol ops to the assigned anchor values such that uses of the symbol can +/// look them up. +class AnchorUniquer { +public: + AnchorUniquer(StringRef prefix, DenseMap &anchorsByOp) + : prefix(prefix), anchorsByOp(anchorsByOp) {} + + /// Assign a unique anchor to the given op and register the result in the + /// mapping. + template + int32_t assignAnchor(OpTy op) { + StringRef symName = op.getSymName(); + int32_t anchor; + { + // Attempt to recover the anchor from the symbol name. 
+ if (!symName.starts_with(prefix) || + symName.drop_front(prefix.size()).getAsInteger(10, anchor)) { + // If that fails, find one that isn't used yet. + anchor = nextAnchor; + } + // Ensure uniqueness either way. + while (anchors.contains(anchor)) + anchor = nextAnchor++; + } + anchors.insert(anchor); + auto [_, hasInserted] = anchorsByOp.try_emplace(op, anchor); + assert(hasInserted && "op had already been assigned an anchor"); + return anchor; + } + +private: + StringRef prefix; + DenseMap &anchorsByOp; // Maps ops to anchor values. + DenseSet anchors; // Already assigned anchors. + int32_t nextAnchor{0}; // Next anchor candidate. +}; + +/// Traits for common handling of `ExtensionFunctionOp`, `ExtensionTypeOp`, and +/// `ExtensionTypeVariationOp`. While their corresponding protobuf message types +/// are structurally the same, they are (1) different classes and (2) have +/// different field names. The Trait thus provides the message type class as +/// well as accessors for that class for each of the op types. 
+template +struct ExtensionOpTraits; + +template <> +struct ExtensionOpTraits { + using ExtensionMessageType = + extensions::SimpleExtensionDeclaration::ExtensionFunction; + static void setAnchor(ExtensionMessageType &ext, int32_t anchor) { + ext.set_function_anchor(anchor); + } + static ExtensionMessageType * + getMutableExtension(extensions::SimpleExtensionDeclaration &decl) { + return decl.mutable_extension_function(); + } +}; + +template <> +struct ExtensionOpTraits { + using ExtensionMessageType = + extensions::SimpleExtensionDeclaration::ExtensionType; + static void setAnchor(ExtensionMessageType &ext, int32_t anchor) { + ext.set_type_anchor(anchor); + } + static ExtensionMessageType * + getMutableExtension(extensions::SimpleExtensionDeclaration &decl) { + return decl.mutable_extension_type(); + } +}; + +template <> +struct ExtensionOpTraits { + using ExtensionMessageType = + extensions::SimpleExtensionDeclaration::ExtensionTypeVariation; + static void setAnchor(ExtensionMessageType &ext, int32_t anchor) { + ext.set_type_variation_anchor(anchor); + } + static ExtensionMessageType * + getMutableExtension(extensions::SimpleExtensionDeclaration &decl) { + return decl.mutable_extension_type_variation(); + } +}; + +FailureOr> SubstraitExporter::exportOperation(PlanOp op) { + using extensions::SimpleExtensionDeclaration; + using extensions::SimpleExtensionURI; + // Build `Version` message. auto version = std::make_unique(); version->set_major_number(op.getMajorNumber()); @@ -403,6 +532,61 @@ FailureOr> exportOperation(PlanOp op) { auto plan = std::make_unique(); plan->set_allocated_version(version.release()); + // Add `extension_uris` to plan. + { + AnchorUniquer anchorUniquer("extension_uri.", anchorsByOp); + for (auto uriOp : op.getOps()) { + int32_t anchor = anchorUniquer.assignAnchor(uriOp); + + // Create `SimpleExtensionURI` message. 
+ SimpleExtensionURI *uri = plan->add_extension_uris(); + uri->set_uri(uriOp.getUri().str()); + uri->set_extension_uri_anchor(anchor); + } + } + + // Add `extensions` to plan. This requires the URIs to exist. + { + // Each extension type has its own anchor uniquer. + AnchorUniquer funcUniquer("extension_function.", anchorsByOp); + AnchorUniquer typeUniquer("extension_type.", anchorsByOp); + AnchorUniquer typeVarUniquer("extension_type_variation.", anchorsByOp); + + // Export an op of a given type using the corresponding uniquer. + auto exportExtensionOperation = [&](AnchorUniquer *uniquer, auto extOp) { + using OpTy = decltype(extOp); + using OpTraits = ExtensionOpTraits; + + // Compute URI reference and anchor value. + int32_t uriReference = lookupAnchor(op, extOp.getUri()); + int32_t anchor = uniquer->assignAnchor(extOp); + + // Create `SimpleExtensionDeclaration` and extension-specific messages. + typename OpTraits::ExtensionMessageType ext; + OpTraits::setAnchor(ext, anchor); + ext.set_extension_uri_reference(uriReference); + ext.set_name(extOp.getName().str()); + SimpleExtensionDeclaration *decl = plan->add_extensions(); + *OpTraits::getMutableExtension(*decl) = ext; + }; + + // Iterate over the different types of extension ops. This must be a single + // loop in order to preserve the order, which allows for interleaving of + // different types in both the protobuf and the MLIR form. + for (Operation &extOp : op.getOps()) { + TypeSwitch(extOp) + .Case([&](auto extOp) { + exportExtensionOperation(&funcUniquer, extOp); + }) + .Case([&](auto extOp) { + exportExtensionOperation(&typeUniquer, extOp); + }) + .Case([&](auto extOp) { + exportExtensionOperation(&typeVarUniquer, extOp); + }); + } + } + // Add `relation`s to plan. 
for (auto relOp : op.getOps()) { Operation *terminator = relOp.getBody().front().getTerminator(); @@ -433,7 +617,8 @@ FailureOr> exportOperation(PlanOp op) { return std::move(plan); } -FailureOr> exportOperation(ProjectOp op) { +FailureOr> +SubstraitExporter::exportOperation(ProjectOp op) { // Build `RelCommon` message. auto relCommon = std::make_unique(); auto direct = std::make_unique(); @@ -483,7 +668,8 @@ FailureOr> exportOperation(ProjectOp op) { return rel; } -FailureOr> exportOperation(RelOpInterface op) { +FailureOr> +SubstraitExporter::exportOperation(RelOpInterface op) { return llvm::TypeSwitch>>(op) .Case< // clang-format off @@ -501,7 +687,8 @@ FailureOr> exportOperation(RelOpInterface op) { }); } -FailureOr> exportOperation(Operation *op) { +FailureOr> +SubstraitExporter::exportOperation(Operation *op) { return llvm::TypeSwitch>>( op) .Case( @@ -525,7 +712,8 @@ namespace substrait { LogicalResult translateSubstraitToProtobuf(Operation *op, llvm::raw_ostream &output, substrait::ImportExportOptions options) { - FailureOr> result = exportOperation(op); + SubstraitExporter exporter; + FailureOr> result = exporter.exportOperation(op); if (failed(result)) return failure(); diff --git a/lib/Target/SubstraitPB/Import.cpp b/lib/Target/SubstraitPB/Import.cpp index 97dec0055fb0..17432f164bf3 100644 --- a/lib/Target/SubstraitPB/Import.cpp +++ b/lib/Target/SubstraitPB/Import.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -58,6 +59,32 @@ DECLARE_IMPORT_FUNC(ProjectRel, Rel, ProjectOp) DECLARE_IMPORT_FUNC(ReadRel, Rel, RelOpInterface) DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface) +// Helpers to build symbol names from anchors deterministically. This allows +// creating symbol references from anchors without a look-up structure. Also, +// the format is exploited by the export logic to recover the original anchor +// values of (unmodified) imported plans. + +/// Builds a deterministic symbol name for a URI with the given anchor.
+static std::string buildUriSymName(int32_t anchor) { + return ("extension_uri." + Twine(anchor)).str(); +} + +/// Builds a deterministic symbol name for a function with the given anchor. +static std::string buildFuncSymName(int32_t anchor) { + return ("extension_function." + Twine(anchor)).str(); +} + +/// Builds a deterministic symbol name for a type with the given anchor. +static std::string buildTypeSymName(int32_t anchor) { + return ("extension_type." + Twine(anchor)).str(); +} + +/// Builds a deterministic symbol name for a type variation with the given +/// anchor. +static std::string buildTypeVarSymName(int32_t anchor) { + return ("extension_type_variation." + Twine(anchor)).str(); +} + static mlir::FailureOr importType(MLIRContext *context, const proto::Type &type) { @@ -302,15 +329,81 @@ importNamedTable(ImplicitLocOpBuilder builder, const Rel &message) { static FailureOr importPlan(ImplicitLocOpBuilder builder, const Plan &message) { + using extensions::SimpleExtensionDeclaration; + using extensions::SimpleExtensionURI; + using ExtensionFunction = SimpleExtensionDeclaration::ExtensionFunction; + using ExtensionType = SimpleExtensionDeclaration::ExtensionType; + using ExtensionTypeVariation = + SimpleExtensionDeclaration::ExtensionTypeVariation; + + MLIRContext *context = builder.getContext(); + Location loc = UnknownLoc::get(context); + const Version &version = message.version(); auto planOp = builder.create( version.major_number(), version.minor_number(), version.patch_number(), version.git_hash(), version.producer()); planOp.getBody().push_back(new Block()); - for (const auto &relation : message.relations()) { - OpBuilder::InsertionGuard insertGuard(builder); - builder.setInsertionPointToEnd(&planOp.getBody().front()); + OpBuilder::InsertionGuard insertGuard(builder); + builder.setInsertionPointToEnd(&planOp.getBody().front()); + + // Import `extension_uris` creating symbol names deterministically. 
+ for (const SimpleExtensionURI &extUri : message.extension_uris()) { + int32_t anchor = extUri.extension_uri_anchor(); + StringRef uri = extUri.uri(); + std::string symName = buildUriSymName(anchor); + builder.create(symName, uri); + } + + // Import `extension`s reconstructing symbol references to URI ops from the + // corresponding anchors using the same method as above. + for (const SimpleExtensionDeclaration &ext : message.extensions()) { + SimpleExtensionDeclaration::MappingTypeCase mappingCase = + ext.mapping_type_case(); + switch (mappingCase) { + case SimpleExtensionDeclaration::kExtensionFunction: { + const ExtensionFunction &func = ext.extension_function(); + int32_t anchor = func.function_anchor(); + int32_t uriRef = func.extension_uri_reference(); + const std::string &funcName = func.name(); + std::string symName = buildFuncSymName(anchor); + std::string uriSymName = buildUriSymName(uriRef); + builder.create(symName, uriSymName, funcName); + break; + } + case SimpleExtensionDeclaration::kExtensionType: { + const ExtensionType &type = ext.extension_type(); + int32_t anchor = type.type_anchor(); + int32_t uriRef = type.extension_uri_reference(); + const std::string &typeName = type.name(); + std::string symName = buildTypeSymName(anchor); + std::string uriSymName = buildUriSymName(uriRef); + builder.create(symName, uriSymName, typeName); + break; + } + case SimpleExtensionDeclaration::kExtensionTypeVariation: { + const ExtensionTypeVariation &typeVar = ext.extension_type_variation(); + int32_t anchor = typeVar.type_variation_anchor(); + int32_t uriRef = typeVar.extension_uri_reference(); + const std::string &typeVarName = typeVar.name(); + std::string symName = buildTypeVarSymName(anchor); + std::string uriSymName = buildUriSymName(uriRef); + builder.create(symName, uriSymName, + typeVarName); + break; + } + default: + const pb::FieldDescriptor *desc = + SimpleExtensionDeclaration::GetDescriptor()->FindFieldByNumber( + mappingCase); + return 
emitError(loc) + << Twine("unsupported SimpleExtensionDeclaration type: ") + + desc->name(); + } + } + + for (const PlanRel &relation : message.relations()) { if (failed(importPlanRel(builder, relation))) return failure(); } diff --git a/test/Dialect/Substrait/plan-invalid.mlir b/test/Dialect/Substrait/plan-invalid.mlir new file mode 100644 index 000000000000..b9dfd2cfd2ea --- /dev/null +++ b/test/Dialect/Substrait/plan-invalid.mlir @@ -0,0 +1,42 @@ +// RUN: structured-opt -verify-diagnostics -split-input-file %s + +// Test error if no symbol was found for `extension_function` op. +substrait.plan version 0 : 42 : 1 { + // expected-error@+1 {{'substrait.extension_function' op refers to @extension, which is not a valid 'uri' op}} + extension_function @function at @extension["somefunc"] +} + +// ----- + +// Test error if no symbol was found for `extension_type` op. +substrait.plan version 0 : 42 : 1 { + // expected-error@+1 {{'substrait.extension_type' op refers to @extension, which is not a valid 'uri' op}} + extension_type @type at @extension["sometype"] +} + +// ----- + +// Test error if no symbol was found for `extension_type_variation` op. +substrait.plan version 0 : 42 : 1 { + // expected-error@+1 {{'substrait.extension_type_variation' op refers to @extension, which is not a valid 'uri' op}} + extension_type_variation @type_var at @extension["sometypevar"] +} + +// ----- + +// Test error if symbol was in the wrong scope. +substrait.extension_uri @extension at "http://some.url/with/extensions.yml" +substrait.plan version 0 : 42 : 1 { + // expected-error@+1 {{'substrait.extension_function' op refers to @extension, which is not a valid 'uri' op}} + extension_function @function at @extension["somefunc"] +} + +// ----- + +// Test error if the symbol refers to an op of the wrong type.
+substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function.1 at @extension["somefunc"] + // expected-error@+1 {{'substrait.extension_function' op refers to @function.1, which is not a valid 'uri' op}} + extension_function @function.2 at @function.1["somefunc"] +} diff --git a/test/Dialect/Substrait/plan-version.mlir b/test/Dialect/Substrait/plan.mlir similarity index 60% rename from test/Dialect/Substrait/plan-version.mlir rename to test/Dialect/Substrait/plan.mlir index e6a16b8a2647..81960379e1c4 100644 --- a/test/Dialect/Substrait/plan-version.mlir +++ b/test/Dialect/Substrait/plan.mlir @@ -63,3 +63,22 @@ substrait.plan version 0 : 42 : 1 { yield %0 : tuple> } } + +// ----- + +// CHECK: substrait.plan version 0 : 42 : 1 { +// CHECK-NEXT: extension_uri @extension at "http://some.url/with/extensions.yml" +// CHECK-NEXT: extension_function @function at @extension["somefunc"] +// CHECK-NEXT: extension_type @type at @extension["sometype"] +// CHECK-NEXT: extension_type_variation @type_var at @extension["sometypevar"] +// CHECK-NEXT: extension_uri @other.extension at "http://other.url/with/more/extensions.yml" +// CHECK-NEXT: extension_function @other.function at @other.extension["someotherfunc"] +// CHECK-NEXT: } +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_function @function at @extension["somefunc"] + extension_type @type at @extension["sometype"] + extension_type_variation @type_var at @extension["sometypevar"] + extension_uri @other.extension at "http://other.url/with/more/extensions.yml" + extension_function @other.function at @other.extension["someotherfunc"] +} diff --git a/test/Dialect/Substrait/plan-relation-invalid.mlir b/test/Dialect/Substrait/relation-invalid.mlir similarity index 100% rename from test/Dialect/Substrait/plan-relation-invalid.mlir rename to test/Dialect/Substrait/relation-invalid.mlir diff 
--git a/test/Target/SubstraitPB/Export/plan-simple.mlir b/test/Target/SubstraitPB/Export/plan-simple.mlir deleted file mode 100644 index 39a965b4ae1e..000000000000 --- a/test/Target/SubstraitPB/Export/plan-simple.mlir +++ /dev/null @@ -1,45 +0,0 @@ -// RUN: structured-translate -substrait-to-protobuf %s \ -// RUN: | FileCheck %s - -// RUN: structured-translate -substrait-to-protobuf %s \ -// RUN: | structured-translate -protobuf-to-substrait \ -// RUN: | structured-translate -substrait-to-protobuf \ -// RUN: | FileCheck %s - -// CHECK-LABEL: relations { -// CHECK-NEXT: rel { -// CHECK-NEXT: read { -// CHECK-NEXT: common { -// CHECK-NEXT: direct { -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: base_schema { -// CHECK-NEXT: names: "a" -// CHECK-NEXT: names: "b" -// CHECK-NEXT: struct { -// CHECK-NEXT: types { -// CHECK-NEXT: i32 { -// CHECK-NEXT: nullability: NULLABILITY_REQUIRED -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: types { -// CHECK-NEXT: i32 { -// CHECK-NEXT: nullability: NULLABILITY_REQUIRED -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: nullability: NULLABILITY_REQUIRED -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: named_table { -// CHECK-NEXT: names: "t1" -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: version { -substrait.plan version 0 : 42 : 1 { - relation { - %0 = named_table @t1 as ["a", "b"] : tuple - yield %0 : tuple - } -} diff --git a/test/Target/SubstraitPB/Export/plan-version.mlir b/test/Target/SubstraitPB/Export/plan-version.mlir deleted file mode 100644 index cdc853478684..000000000000 --- a/test/Target/SubstraitPB/Export/plan-version.mlir +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: structured-translate -substrait-to-protobuf --split-input-file %s \ -// RUN: | FileCheck %s - -// RUN: structured-translate -substrait-to-protobuf %s \ -// RUN: --split-input-file --output-split-marker="# -----" \ -// RUN: | structured-translate -protobuf-to-substrait \ -// RUN: --split-input-file="# 
-----" --output-split-marker="// ""-----" \ -// RUN: | structured-translate -substrait-to-protobuf \ -// RUN: --split-input-file --output-split-marker="# -----" \ -// RUN: | FileCheck %s - -// CHECK-LABEL: version { -// CHECK-DAG: minor_number: 42 -// CHECK-DAG: patch_number: 1 -// CHECK-DAG: git_hash: "hash" -// CHECK-DAG: producer: "producer" -// CHECK-NEXT: } -substrait.plan - version 0 : 42 : 1 - git_hash "hash" - producer "producer" - {} - -// ----- - -// CHECK: relations { -// CHECK-NEXT: root { -// CHECK-NEXT: input { -// CHECK-NEXT: read { -// CHECK: named_table { -// CHECK-NEXT: names -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: names: "x" -// CHECK-NEXT: names: "y" -// CHECK-NEXT: names: "z" -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: version - -substrait.plan version 0 : 42 : 1 { - relation as ["x", "y", "z"] { - %0 = named_table @t as ["a", "b", "c"] : tuple> - yield %0 : tuple> - } -} diff --git a/test/Target/SubstraitPB/Export/plan.mlir b/test/Target/SubstraitPB/Export/plan.mlir new file mode 100644 index 000000000000..a2f929aefb28 --- /dev/null +++ b/test/Target/SubstraitPB/Export/plan.mlir @@ -0,0 +1,127 @@ +// RUN: structured-translate -substrait-to-protobuf --split-input-file %s \ +// RUN: | FileCheck %s + +// RUN: structured-translate -substrait-to-protobuf %s \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | structured-translate -protobuf-to-substrait \ +// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \ +// RUN: | structured-translate -substrait-to-protobuf \ +// RUN: --split-input-file --output-split-marker="# -----" \ +// RUN: | FileCheck %s + +// CHECK-LABEL: version { +// CHECK-DAG: minor_number: 42 +// CHECK-DAG: patch_number: 1 +// CHECK-DAG: git_hash: "hash" +// CHECK-DAG: producer: "producer" +// CHECK-NEXT: } +substrait.plan + version 0 : 42 : 1 + git_hash "hash" + producer "producer" + {} + +// ----- + +// CHECK: relations { +// CHECK-NEXT: root { +// 
CHECK-NEXT: input { +// CHECK-NEXT: read { +// CHECK: named_table { +// CHECK-NEXT: names +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: names: "x" +// CHECK-NEXT: names: "y" +// CHECK-NEXT: names: "z" +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: version + +substrait.plan version 0 : 42 : 1 { + relation as ["x", "y", "z"] { + %0 = named_table @t as ["a", "b", "c"] : tuple> + yield %0 : tuple> + } +} + +// ----- + +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 1 +// CHECK-NEXT: uri: "http://url.1/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 2 +// CHECK-NEXT: uri: "http://url.2/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 42 +// CHECK-NEXT: uri: "http://url.42/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: uri: "http://some.url/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 3 +// CHECK-NEXT: uri: "http://url.foo/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 4 +// CHECK-NEXT: uri: "http://url.bar/with/extensions.yml" +// CHECK: extensions { +// CHECK-NEXT: extension_function { +// CHECK-NEXT: extension_uri_reference: 1 +// CHECK-NEXT: function_anchor: 1 +// CHECK-NEXT: name: "func1" +// CHECK: extensions { +// CHECK-NEXT: extension_function { +// CHECK-NEXT: extension_uri_reference: 42 +// CHECK-NEXT: function_anchor: 42 +// CHECK-NEXT: name: "func42" +// CHECK: extensions { +// CHECK-NEXT: extension_type { +// CHECK-NEXT: extension_uri_reference: 2 +// CHECK-NEXT: type_anchor: 1 +// CHECK-NEXT: name: "type1" +// CHECK: extensions { +// CHECK-NEXT: extension_type { +// CHECK-NEXT: extension_uri_reference: 2 +// CHECK-NEXT: type_anchor: 42 +// CHECK-NEXT: name: "type42" +// CHECK: extensions { +// CHECK-NEXT: extension_type_variation { +// CHECK-NEXT: extension_uri_reference: 1 +// CHECK-NEXT: type_variation_anchor: 1 +// CHECK-NEXT: name: 
"typevar1" +// CHECK: extensions { +// CHECK-NEXT: extension_type_variation { +// CHECK-NEXT: extension_uri_reference: 1 +// CHECK-NEXT: type_variation_anchor: 42 +// CHECK-NEXT: name: "typevar2" + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension_uri.1 at "http://url.1/with/extensions.yml" + extension_uri @extension_uri.2 at "http://url.2/with/extensions.yml" + extension_uri @extension_uri.42 at "http://url.42/with/extensions.yml" + extension_uri @extension at "http://some.url/with/extensions.yml" + extension_uri @extension_uri.foo at "http://url.foo/with/extensions.yml" + extension_uri @extension_uri.bar at "http://url.bar/with/extensions.yml" + extension_function @extension_function.1 at @extension_uri.1["func1"] + extension_function @extension_function.42 at @extension_uri.42["func42"] + extension_type @extension_type.1 at @extension_uri.2["type1"] + extension_type @extension_type.42 at @extension_uri.2["type42"] + extension_type_variation @extension_type_variation.1 at @extension_uri.1["typevar1"] + extension_type_variation @extension_type_variation.42 at @extension_uri.1["typevar2"] +} + +// ----- + + +// CHECK: extension_uris { +// CHECK-NEXT: uri: "http://some.url/with/extensions.yml" +// CHECK: extension_uris { +// CHECK-NEXT: extension_uri_anchor: 1 +// CHECK-NEXT: uri: "http://other.url/with/more/extensions.yml" + +substrait.plan version 0 : 42 : 1 { + extension_uri @extension at "http://some.url/with/extensions.yml" + // If not handled carefully, parsing this symbol into an anchor may clash. 
+ extension_uri @extension_uri.0 at "http://other.url/with/more/extensions.yml" +} diff --git a/test/Target/SubstraitPB/Import/plan-simple.textpb b/test/Target/SubstraitPB/Import/plan-simple.textpb deleted file mode 100644 index 86921ebd5148..000000000000 --- a/test/Target/SubstraitPB/Import/plan-simple.textpb +++ /dev/null @@ -1,46 +0,0 @@ -# RUN: structured-translate -protobuf-to-substrait %s \ -# RUN: | FileCheck %s - -# RUN: structured-translate -protobuf-to-substrait %s \ -# RUN: | structured-translate -substrait-to-protobuf \ -# RUN: | structured-translate -protobuf-to-substrait \ -# RUN: | FileCheck %s - -# CHECK-LABEL: substrait.plan -# CHECK-NEXT: relation { -# CHECK-NEXT: %[[V0:.*]] = named_table @t1 as ["a", "b"] : tuple -# CHECK-NEXT: yield %[[V0]] : tuple -relations { - rel { - read { - common { - direct { - } - } - base_schema { - names: "a" - names: "b" - struct { - types { - i32 { - nullability: NULLABILITY_REQUIRED - } - } - types { - i32 { - nullability: NULLABILITY_REQUIRED - } - } - nullability: NULLABILITY_REQUIRED - } - } - named_table { - names: "t1" - } - } - } -} -version { - minor_number: 42 - patch_number: 1 -} diff --git a/test/Target/SubstraitPB/Import/plan-version.textpb b/test/Target/SubstraitPB/Import/plan-version.textpb deleted file mode 100644 index 3e057c231817..000000000000 --- a/test/Target/SubstraitPB/Import/plan-version.textpb +++ /dev/null @@ -1,17 +0,0 @@ -# RUN: structured-translate -protobuf-to-substrait %s \ -# RUN: | FileCheck %s - -# RUN: structured-translate -protobuf-to-substrait %s \ -# RUN: | structured-translate -substrait-to-protobuf \ -# RUN: | structured-translate -protobuf-to-substrait \ -# RUN: | FileCheck %s - -# CHECK: substrait.plan version 0 : 42 : 1 -# CHECK-SAME: git_hash "hash" producer "producer" { -# CHECK-NEXT: } -version { - minor_number: 42 - patch_number: 1 - git_hash: "hash" - producer: "producer" -} diff --git a/test/Target/SubstraitPB/Import/plan.textpb 
b/test/Target/SubstraitPB/Import/plan.textpb new file mode 100644 index 000000000000..c44c95b20af1 --- /dev/null +++ b/test/Target/SubstraitPB/Import/plan.textpb @@ -0,0 +1,149 @@ +# RUN: structured-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" \ +# RUN: | FileCheck %s + +# RUN: structured-translate -protobuf-to-substrait %s \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: | structured-translate -substrait-to-protobuf \ +# RUN: --split-input-file --output-split-marker="# ""-----" \ +# RUN: | structured-translate -protobuf-to-substrait \ +# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \ +# RUN: | FileCheck %s + +# CHECK-LABEL: substrait.plan version 0 : 42 : 1 +# CHECK-SAME: git_hash "hash" producer "producer" { +# CHECK-NEXT: } +version { + minor_number: 42 + patch_number: 1 + git_hash: "hash" + producer: "producer" +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: relation { +# CHECK-NEXT: %[[V0:.*]] = named_table @t1 as ["a", "b"] : tuple +# CHECK-NEXT: yield %[[V0]] : tuple +relations { + rel { + read { + common { + direct { + } + } + base_schema { + names: "a" + names: "b" + struct { + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + types { + i32 { + nullability: NULLABILITY_REQUIRED + } + } + nullability: NULLABILITY_REQUIRED + } + } + named_table { + names: "t1" + } + } + } +} +version { + minor_number: 42 + patch_number: 1 +} + +# ----- + +# CHECK-LABEL: substrait.plan +# CHECK-NEXT: extension_uri @extension_uri.1 at "http://url.1/with/extensions.yml" +# CHECK-NEXT: extension_uri @extension_uri.2 at "http://url.2/with/extensions.yml" +# CHECK-NEXT: extension_uri @extension_uri.42 at "http://url.42/with/extensions.yml" +# CHECK-NEXT: extension_uri @extension_uri.0 at "http://some.url/with/extensions.yml" +# CHECK-NEXT: extension_uri @extension_uri.3 at "http://url.foo/with/extensions.yml" +# CHECK-NEXT: extension_uri @extension_uri.4 at 
"http://url.bar/with/extensions.yml" +# CHECK-NEXT: extension_function @extension_function.1 at @extension_uri.1["func1"] +# CHECK-NEXT: extension_function @extension_function.42 at @extension_uri.42["func42"] +# CHECK-NEXT: extension_type @extension_type.1 at @extension_uri.2["type1"] +# CHECK-NEXT: extension_type @extension_type.42 at @extension_uri.2["type42"] +# CHECK-NEXT: extension_type_variation @extension_type_variation.1 at @extension_uri.1["typevar1"] +# CHECK-NEXT: extension_type_variation @extension_type_variation.42 at @extension_uri.1["typevar2"] +# CHECK-NEXT: } + +extension_uris { + extension_uri_anchor: 1 + uri: "http://url.1/with/extensions.yml" +} +extension_uris { + extension_uri_anchor: 2 + uri: "http://url.2/with/extensions.yml" +} +extension_uris { + extension_uri_anchor: 42 + uri: "http://url.42/with/extensions.yml" +} +extension_uris { + uri: "http://some.url/with/extensions.yml" +} +extension_uris { + extension_uri_anchor: 3 + uri: "http://url.foo/with/extensions.yml" +} +extension_uris { + extension_uri_anchor: 4 + uri: "http://url.bar/with/extensions.yml" +} +extensions { + extension_function { + extension_uri_reference: 1 + function_anchor: 1 + name: "func1" + } +} +extensions { + extension_function { + extension_uri_reference: 42 + function_anchor: 42 + name: "func42" + } +} +extensions { + extension_type { + extension_uri_reference: 2 + type_anchor: 1 + name: "type1" + } +} +extensions { + extension_type { + extension_uri_reference: 2 + type_anchor: 42 + name: "type42" + } +} +extensions { + extension_type_variation { + extension_uri_reference: 1 + type_variation_anchor: 1 + name: "typevar1" + } +} +extensions { + extension_type_variation { + extension_uri_reference: 1 + type_variation_anchor: 42 + name: "typevar2" + } +} +version { + minor_number: 42 + patch_number: 1 +}