From fd1a43bc0fb502e0bbc27ff220b9f1caa94ebfdd Mon Sep 17 00:00:00 2001 From: jiao_lv Date: Tue, 11 Oct 2022 18:29:52 +0800 Subject: [PATCH] Register custom row constructor function for Gluten. (#404) --- cpp/velox/CMakeLists.txt | 1 + cpp/velox/compute/RegistrationAllFunctions.cc | 39 +++++++++++ cpp/velox/compute/RegistrationAllFunctions.h | 22 ++++++ cpp/velox/compute/RowConstructor.cc | 69 +++++++++++++++++++ cpp/velox/compute/VeloxPlanConverter.cc | 7 +- cpp/velox/jni/jni_wrapper.cc | 2 + 6 files changed, 136 insertions(+), 4 deletions(-) create mode 100644 cpp/velox/compute/RegistrationAllFunctions.cc create mode 100644 cpp/velox/compute/RegistrationAllFunctions.h create mode 100644 cpp/velox/compute/RowConstructor.cc diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index 786db8a65086..03714295740d 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -156,6 +156,7 @@ set(VELOX_SRCS compute/VeloxPlanConverter.cc compute/ArrowTypeUtils.cc compute/VeloxToRowConverter.cc + compute/RowConstructor.cc compute/DwrfDatasource.cc compute/bridge.cc memory/velox_memory_pool.cc diff --git a/cpp/velox/compute/RegistrationAllFunctions.cc b/cpp/velox/compute/RegistrationAllFunctions.cc new file mode 100644 index 000000000000..668f075f2ffc --- /dev/null +++ b/cpp/velox/compute/RegistrationAllFunctions.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RegistrationAllFunctions.h" +#include "RowConstructor.cc" +#include "velox/functions/sparksql/Register.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +namespace velox::compute { + +void registerCustomFunctions() { + exec::registerVectorFunction( + "row_constructor", + std::vector>{}, + std::make_unique()); +} + +void registerAllFunctions() { + functions::prestosql::registerAllScalarFunctions(); + functions::sparksql::registerFunctions(""); + registerCustomFunctions(); +} + +} // namespace velox::compute diff --git a/cpp/velox/compute/RegistrationAllFunctions.h b/cpp/velox/compute/RegistrationAllFunctions.h new file mode 100644 index 000000000000..d423dad297c3 --- /dev/null +++ b/cpp/velox/compute/RegistrationAllFunctions.h @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace velox::compute { + +void registerAllFunctions(); + +} // namespace velox::compute diff --git a/cpp/velox/compute/RowConstructor.cc b/cpp/velox/compute/RowConstructor.cc new file mode 100644 index 000000000000..80989fa15335 --- /dev/null +++ b/cpp/velox/compute/RowConstructor.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/VectorFunction.h" +#include "velox/type/Type.h" + +using namespace facebook::velox; +using namespace facebook::velox::exec; + +namespace velox::compute { + +namespace { +class RowConstructor : public exec::VectorFunction { + void apply( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const override { + auto argsCopy = args; + + BufferPtr nulls = AlignedBuffer::allocate( + bits::nbytes(rows.size()), context.pool(), 1); + auto* nullsPtr = nulls->asMutable(); + auto cntNull = 0; + rows.applyToSelected([&](vector_size_t i) { + bits::clearNull(nullsPtr, i); + if (!bits::isBitNull(nullsPtr, i)) { + for (size_t c = 0; c < argsCopy.size(); c++) { + auto arg = argsCopy[c].get(); + if (arg->mayHaveNulls() && arg->isNullAt(i)) { + bits::setNull(nullsPtr, i, true); + cntNull++; + break; + } + } + } + }); + + RowVectorPtr localResult = std::make_shared( + context.pool(), + outputType, + nulls, + rows.size(), + std::move(argsCopy), + cntNull /*nullCount*/); + context.moveOrCopyResult(localResult, rows, result); + } + + bool isDefaultNullBehavior() const override { + return false; + } +}; + +} // namespace +} // namespace velox::compute diff --git a/cpp/velox/compute/VeloxPlanConverter.cc b/cpp/velox/compute/VeloxPlanConverter.cc index 52bc50f5e7f0..e73a011bb7b4 100644 --- a/cpp/velox/compute/VeloxPlanConverter.cc +++ b/cpp/velox/compute/VeloxPlanConverter.cc @@ -24,6 +24,7 @@ #include #include "ArrowTypeUtils.h" +#include "RegistrationAllFunctions.cc" #include "arrow/c/Bridge.h" #include "arrow/c/bridge.h" #include "bridge.h" @@ -33,7 +34,6 @@ #include "velox/functions/prestosql/aggregates/AverageAggregate.h" #include "velox/functions/prestosql/aggregates/CountAggregate.h" #include "velox/functions/prestosql/aggregates/MinMaxAggregates.h" -#include "velox/functions/sparksql/Register.h" using namespace facebook::velox; using namespace facebook::velox::exec; @@ -95,9 +95,8 @@ void VeloxInitializer::Init() { parquet::registerParquetReaderFactory(ParquetReaderType::NATIVE); // parquet::registerParquetReaderFactory(ParquetReaderType::DUCKDB); dwrf::registerDwrfReaderFactory(); - // Register Velox functions. - functions::sparksql::registerFunctions(""); - functions::prestosql::registerAllScalarFunctions(); + // Register Velox functions + registerAllFunctions(); aggregate::registerSumAggregate("sum"); aggregate::registerAverageAggregate("avg"); aggregate::registerCountAggregate("count"); diff --git a/cpp/velox/jni/jni_wrapper.cc b/cpp/velox/jni/jni_wrapper.cc index bb17248bfd18..8a49de609890 100644 --- a/cpp/velox/jni/jni_wrapper.cc +++ b/cpp/velox/jni/jni_wrapper.cc @@ -21,6 +21,7 @@ #include #include "compute/DwrfDatasource.h" +#include "compute/RegistrationAllFunctions.h" #include "compute/VeloxPlanConverter.h" #include "jni/jni_errors.h" #include "memory/velox_memory_pool.h" @@ -127,6 +128,7 @@ Java_io_glutenproject_vectorized_ExpressionEvaluatorJniWrapper_nativeDoValidate( std::unique_ptr execCtx_ = std::make_unique(pool.get(), queryCtx_.get()); + velox::compute::registerAllFunctions(); auto planValidator = std::make_shared< facebook::velox::substrait::SubstraitToVeloxPlanValidator>( pool.get(), execCtx_.get());