Skip to content

Commit

Permalink
fix(naming): rename variant to implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
chaojun-zhang committed Feb 2, 2023
1 parent 717bb01 commit 0a67c5b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 81 deletions.
34 changes: 17 additions & 17 deletions include/substrait/function/Extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct TypeVariant {

using TypeVariantPtr = std::shared_ptr<TypeVariant>;

using FunctionVariantMap =
using FunctionImplMap =
std::unordered_map<std::string, std::vector<FunctionImplementationPtr>>;

using TypeVariantMap = std::unordered_map<std::string, TypeVariantPtr>;
Expand All @@ -40,40 +40,40 @@ class Extension {
static std::shared_ptr<Extension> load(
const std::vector<std::string>& extensionFiles);

/// Add a scalar function variant.
void addScalarFunctionVariant(const FunctionImplementationPtr& functionVariant);
/// Add a scalar function implementation.
void addScalarFunctionImpl(const FunctionImplementationPtr& functionImpl);

/// Add a aggregate function variant.
void addAggregateFunctionVariant(const FunctionImplementationPtr& functionVariant);
/// Add a aggregate function implementation.
void addAggregateFunctionImpl(const FunctionImplementationPtr& functionImpl);

/// Add a window function variant.
void addWindowFunctionVariant(const FunctionImplementationPtr& functionVariant);
/// Add a window function implementation.
void addWindowFunctionImpl(const FunctionImplementationPtr& functionImpl);

/// Add a type variant.
void addTypeVariant(const TypeVariantPtr& functionVariant);
void addTypeVariant(const TypeVariantPtr& typeVariant);

/// Lookup type variant by given type name.
/// @return matched type variant
TypeVariantPtr lookupType(const std::string& typeName) const;

const FunctionVariantMap& scalaFunctionVariantMap() const {
return scalarFunctionVariantMap_;
const FunctionImplMap& scalaFunctionImplMap() const {
return scalarFunctionImplMap_;
}

const FunctionVariantMap& windowFunctionVariantMap() const {
return windowFunctionVariantMap_;
const FunctionImplMap& windowFunctionImplMap() const {
return windowFunctionImplMap_;
}

const FunctionVariantMap& aggregateFunctionVariantMap() const {
return aggregateFunctionVariantMap_;
const FunctionImplMap& aggregateFunctionImplMap() const {
return aggregateFunctionImplMap_;
}

private:
FunctionVariantMap scalarFunctionVariantMap_;
FunctionImplMap scalarFunctionImplMap_;

FunctionVariantMap aggregateFunctionVariantMap_;
FunctionImplMap aggregateFunctionImplMap_;

FunctionVariantMap windowFunctionVariantMap_;
FunctionImplMap windowFunctionImplMap_;

TypeVariantMap typeVariantMap_;
};
Expand Down
2 changes: 1 addition & 1 deletion include/substrait/function/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct FunctionImplementation {
ParameterizedTypePtr returnType;
std::optional<FunctionVariadic> variadic;

/// Test if the actual types matched with this function variant.
/// Test if the actual types matched with this function implement.
virtual bool tryMatch(const FunctionSignature& signature);

/// Create function signature by function name and arguments.
Expand Down
14 changes: 7 additions & 7 deletions include/substrait/function/FunctionLookup.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class FunctionLookup {
virtual ~FunctionLookup() = default;

protected:
[[nodiscard]] virtual FunctionVariantMap getFunctionVariants() const = 0;
[[nodiscard]] virtual FunctionImplMap getFunctionImpls() const = 0;

ExtensionPtr extension_{};
};
Expand All @@ -31,8 +31,8 @@ class ScalarFunctionLookup : public FunctionLookup {
: FunctionLookup(extension) {}

protected:
[[nodiscard]] FunctionVariantMap getFunctionVariants() const override {
return extension_->scalaFunctionVariantMap();
[[nodiscard]] FunctionImplMap getFunctionImpls() const override {
return extension_->scalaFunctionImplMap();
}
};

Expand All @@ -42,8 +42,8 @@ class AggregateFunctionLookup : public FunctionLookup {
: FunctionLookup(extension) {}

protected:
[[nodiscard]] FunctionVariantMap getFunctionVariants() const override {
return extension_->aggregateFunctionVariantMap();
[[nodiscard]] FunctionImplMap getFunctionImpls() const override {
return extension_->aggregateFunctionImplMap();
}
};

Expand All @@ -53,8 +53,8 @@ class WindowFunctionLookup : public FunctionLookup {
: FunctionLookup(extension) {}

protected:
[[nodiscard]] FunctionVariantMap getFunctionVariants() const override {
return extension_->windowFunctionVariantMap();
[[nodiscard]] FunctionImplMap getFunctionImpls() const override {
return extension_->windowFunctionImplMap();
}
};

Expand Down
105 changes: 53 additions & 52 deletions substrait/function/Extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include <yaml-cpp/yaml.h>
#include "substrait/function/Extension.h"

bool decodeFunctionVariant(
bool decodeFunctionImpl(
const YAML::Node& node,
io::substrait::FunctionImplementation& function) {
const auto& returnType = node["return"];
Expand Down Expand Up @@ -103,7 +103,7 @@ struct YAML::convert<io::substrait::ScalarFunctionImplementation> {
static bool decode(
const Node& node,
io::substrait::ScalarFunctionImplementation& function) {
return decodeFunctionVariant(node, function);
return decodeFunctionImpl(node, function);
};
};

Expand All @@ -112,7 +112,7 @@ struct YAML::convert<io::substrait::AggregateFunctionImplementation> {
static bool decode(
const Node& node,
io::substrait::AggregateFunctionImplementation& function) {
const auto& res = decodeFunctionVariant(node, function);
const auto& res = decodeFunctionImpl(node, function);
if (res) {
const auto& intermediate = node["intermediate"];
if (intermediate) {
Expand Down Expand Up @@ -178,13 +178,14 @@ std::shared_ptr<Extension> Extension::load(
if (scalarFunctions && scalarFunctions.IsSequence()) {
for (auto& scalarFunctionNode : scalarFunctions) {
const auto functionName = scalarFunctionNode["name"].as<std::string>();
for (auto& scalaFunctionVariantNode : scalarFunctionNode["impls"]) {
auto scalarFunctionVariant =
scalaFunctionVariantNode.as<ScalarFunctionImplementation>();
scalarFunctionVariant.name = functionName;
scalarFunctionVariant.uri = extensionUri;
extension->addScalarFunctionVariant(
std::make_shared<ScalarFunctionImplementation>(scalarFunctionVariant));
for (auto& scalaFunctionImplNode : scalarFunctionNode["impls"]) {
auto scalarFunctionImpl =
scalaFunctionImplNode.as<ScalarFunctionImplementation>();
scalarFunctionImpl.name = functionName;
scalarFunctionImpl.uri = extensionUri;
extension->addScalarFunctionImpl(
std::make_shared<ScalarFunctionImplementation>(
scalarFunctionImpl));
}
}
}
Expand All @@ -194,15 +195,15 @@ std::shared_ptr<Extension> Extension::load(
for (auto& aggregateFunctionNode : aggregateFunctions) {
const auto functionName =
aggregateFunctionNode["name"].as<std::string>();
for (auto& aggregateFunctionVariantNode :
for (auto& aggregateFunctionImplNode :
aggregateFunctionNode["impls"]) {
auto aggregateFunctionVariant =
aggregateFunctionVariantNode.as<AggregateFunctionImplementation>();
aggregateFunctionVariant.name = functionName;
aggregateFunctionVariant.uri = extensionUri;
extension->addAggregateFunctionVariant(
auto aggregateFunctionImpl =
aggregateFunctionImplNode.as<AggregateFunctionImplementation>();
aggregateFunctionImpl.name = functionName;
aggregateFunctionImpl.uri = extensionUri;
extension->addAggregateFunctionImpl(
std::make_shared<AggregateFunctionImplementation>(
aggregateFunctionVariant));
aggregateFunctionImpl));
}
}
}
Expand All @@ -219,23 +220,23 @@ std::shared_ptr<Extension> Extension::load(
return extension;
}

void Extension::addWindowFunctionVariant(
const FunctionImplementationPtr& functionVariant) {
const auto& functionVariants =
windowFunctionVariantMap_.find(functionVariant->name);
if (functionVariants != windowFunctionVariantMap_.end()) {
auto& variants = functionVariants->second;
variants.emplace_back(functionVariant);
void Extension::addWindowFunctionImpl(
const FunctionImplementationPtr& functionImpl) {
const auto& functionImpls =
windowFunctionImplMap_.find(functionImpl->name);
if (functionImpls != windowFunctionImplMap_.end()) {
auto& impls = functionImpls->second;
impls.emplace_back(functionImpl);
} else {
std::vector<FunctionImplementationPtr> variants;
variants.emplace_back(functionVariant);
windowFunctionVariantMap_.insert(
{functionVariant->name, std::move(variants)});
std::vector<FunctionImplementationPtr> impls;
impls.emplace_back(functionImpl);
windowFunctionImplMap_.insert(
{functionImpl->name, std::move(impls)});
}
}

void Extension::addTypeVariant(const TypeVariantPtr& functionVariant) {
typeVariantMap_.insert({functionVariant->name, functionVariant});
void Extension::addTypeVariant(const TypeVariantPtr& typeVariant) {
typeVariantMap_.insert({typeVariant->name, typeVariant});
}

TypeVariantPtr Extension::lookupType(const std::string& typeName) const {
Expand All @@ -246,33 +247,33 @@ TypeVariantPtr Extension::lookupType(const std::string& typeName) const {
return nullptr;
}

void Extension::addScalarFunctionVariant(
const FunctionImplementationPtr& functionVariant) {
const auto& functionVariants =
scalarFunctionVariantMap_.find(functionVariant->name);
if (functionVariants != scalarFunctionVariantMap_.end()) {
auto& variants = functionVariants->second;
variants.emplace_back(functionVariant);
void Extension::addScalarFunctionImpl(
const FunctionImplementationPtr& functionImpl) {
const auto& functionImpls =
scalarFunctionImplMap_.find(functionImpl->name);
if (functionImpls != scalarFunctionImplMap_.end()) {
auto& impls = functionImpls->second;
impls.emplace_back(functionImpl);
} else {
std::vector<FunctionImplementationPtr> variants;
variants.emplace_back(functionVariant);
scalarFunctionVariantMap_.insert(
{functionVariant->name, std::move(variants)});
std::vector<FunctionImplementationPtr> impls;
impls.emplace_back(functionImpl);
scalarFunctionImplMap_.insert(
{functionImpl->name, std::move(impls)});
}
}

void Extension::addAggregateFunctionVariant(
const FunctionImplementationPtr& functionVariant) {
const auto& functionVariants =
aggregateFunctionVariantMap_.find(functionVariant->name);
if (functionVariants != aggregateFunctionVariantMap_.end()) {
auto& variants = functionVariants->second;
variants.emplace_back(functionVariant);
void Extension::addAggregateFunctionImpl(
const FunctionImplementationPtr& functionImpl) {
const auto& functionImpls =
aggregateFunctionImplMap_.find(functionImpl->name);
if (functionImpls != aggregateFunctionImplMap_.end()) {
auto& impls = functionImpls->second;
impls.emplace_back(functionImpl);
} else {
std::vector<FunctionImplementationPtr> variants;
variants.emplace_back(functionVariant);
aggregateFunctionVariantMap_.insert(
{functionVariant->name, std::move(variants)});
std::vector<FunctionImplementationPtr> impls;
impls.emplace_back(functionImpl);
aggregateFunctionImplMap_.insert(
{functionImpl->name, std::move(impls)});
}
}

Expand Down
8 changes: 4 additions & 4 deletions substrait/function/FunctionLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ namespace io::substrait {
FunctionImplementationPtr FunctionLookup::lookupFunction(
const FunctionSignature& signature) const {

const auto& functionImpls = getFunctionVariants();
const auto& functionImpls = getFunctionImpls();
auto functionImplsIter = functionImpls.find(signature.name);
if (functionImplsIter != functionImpls.end()) {
for (const auto& candidateFunctionVariant : functionImplsIter->second) {
if (candidateFunctionVariant->tryMatch(signature)) {
return candidateFunctionVariant;
for (const auto& candidateFunctionImpl : functionImplsIter->second) {
if (candidateFunctionImpl->tryMatch(signature)) {
return candidateFunctionImpl;
}
}
}
Expand Down

0 comments on commit 0a67c5b

Please sign in to comment.