Skip to content

Commit

Permalink
Add simdjson based get_json_object Spark function (5179)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored and marin-ma committed Jan 5, 2024
1 parent 5623f44 commit c5501a5
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 5 deletions.
6 changes: 4 additions & 2 deletions velox/docs/functions/spark/json.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ JSON Functions

.. spark:function:: get_json_object(json, path) -> varchar
Extracts a json object from path::
Extracts a json object from ``path``. Returns NULL if it finds json string
is malformed. ::

SELECT get_json_object('{"a":"b"}', '$.a'); -- b
SELECT get_json_object('{"a":"b"}', '$.a'); -- 'b'
SELECT get_json_object('{"a":{"b":"c"}}', '$.a'); -- '{"b":"c"}'
3 changes: 2 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ target_link_libraries(
velox_functions_spark_specialforms
velox_is_null_functions
velox_functions_util
Folly::folly)
Folly::folly
simdjson)

set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE
high_memory_pool)
Expand Down
4 changes: 2 additions & 2 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/prestosql/DateTimeFunctions.h"
#include "velox/functions/prestosql/JsonFunctions.h"
#include "velox/functions/prestosql/StringFunctions.h"
#include "velox/functions/sparksql/ArrayMinMaxFunction.h"
#include "velox/functions/sparksql/ArraySort.h"
Expand All @@ -36,6 +35,7 @@
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
#include "velox/functions/sparksql/RegisterCompare.h"
#include "velox/functions/sparksql/SIMDJsonFunctions.h"
#include "velox/functions/sparksql/Size.h"
#include "velox/functions/sparksql/String.h"
#include "velox/functions/sparksql/StringToMap.h"
Expand Down Expand Up @@ -128,7 +128,7 @@ void registerFunctions(const std::string& prefix) {
// Register size functions
registerSize(prefix + "size");

registerFunction<JsonExtractScalarFunction, Varchar, Varchar, Varchar>(
registerFunction<SIMDGetJsonObjectFunction, Varchar, Varchar, Varchar>(
{prefix + "get_json_object"});

// Register string functions.
Expand Down
182 changes: 182 additions & 0 deletions velox/functions/sparksql/SIMDJsonFunctions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/prestosql/SIMDJsonFunctions.h"

using namespace simdjson;

namespace facebook::velox::functions::sparksql {

template <typename T>
struct SIMDGetJsonObjectFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);
std::optional<std::string> formattedJsonPath_;

// ASCII input always produces ASCII result.
static constexpr bool is_default_ascii_behavior = true;

// Makes a conversion from spark's json path, e.g., converts
// "$.a.b" to "/a/b".
FOLLY_ALWAYS_INLINE std::string getFormattedJsonPath(
const arg_type<Varchar>& jsonPath) {
char formattedJsonPath[jsonPath.size() + 1];
int j = 0;
for (int i = 0; i < jsonPath.size(); i++) {
if (jsonPath.data()[i] == '$' || jsonPath.data()[i] == ']' ||
jsonPath.data()[i] == '\'') {
continue;
} else if (jsonPath.data()[i] == '[' || jsonPath.data()[i] == '.') {
formattedJsonPath[j] = '/';
j++;
} else {
formattedJsonPath[j] = jsonPath.data()[i];
j++;
}
}
formattedJsonPath[j] = '\0';
return std::string(formattedJsonPath, j + 1);
}

FOLLY_ALWAYS_INLINE void initialize(
const core::QueryConfig& config,
const arg_type<Varchar>* /*json*/,
const arg_type<Varchar>* jsonPath) {
if (jsonPath != nullptr) {
formattedJsonPath_ = getFormattedJsonPath(*jsonPath);
}
}

FOLLY_ALWAYS_INLINE simdjson::error_code extractStringResult(
simdjson_result<ondemand::value> rawResult,
out_type<Varchar>& result) {
simdjson::error_code error;
std::stringstream ss;
switch (rawResult.type()) {
// For number and bool types, we need to explicitly get the value
// for specific types instead of using `ss << rawResult`. Thus, we
// can make simdjson's internal parsing position moved and then we
// can check the validity of ending character.
case ondemand::json_type::number: {
switch (rawResult.get_number_type()) {
case ondemand::number_type::unsigned_integer: {
uint64_t numberResult;
error = rawResult.get_uint64().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
case ondemand::number_type::signed_integer: {
int64_t numberResult;
error = rawResult.get_int64().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
case ondemand::number_type::floating_point_number: {
double numberResult;
error = rawResult.get_double().get(numberResult);
if (!error) {
ss << numberResult;
result.append(ss.str());
}
return error;
}
default:
VELOX_UNREACHABLE();
}
}
case ondemand::json_type::boolean: {
bool boolResult;
error = rawResult.get_bool().get(boolResult);
if (!error) {
result.append(boolResult ? "true" : "false");
}
return error;
}
case ondemand::json_type::string: {
std::string_view stringResult;
error = rawResult.get_string().get(stringResult);
result.append(stringResult);
return error;
}
case ondemand::json_type::object: {
// For nested case, e.g., for "{"my": {"hello": 10}}", "$.my" will
// return an object type.
ss << rawResult;
result.append(ss.str());
return SUCCESS;
}
case ondemand::json_type::array: {
ss << rawResult;
result.append(ss.str());
return SUCCESS;
}
default: {
return UNSUPPORTED_ARCHITECTURE;
}
}
}

// This is a simple validation by checking whether the obtained result is
// followed by valid char. Because ondemand parsing we are using ignores json
// format validation for characters following the current parsing position.
bool isValidEndingCharacter(const char* currentPos) {
char endingChar = *currentPos;
if (endingChar == ',' || endingChar == '}' || endingChar == ']') {
return true;
}
if (endingChar == ' ' || endingChar == '\r' || endingChar == '\n' ||
endingChar == '\t') {
// These chars can be prior to a valid ending char.
return isValidEndingCharacter(currentPos++);
}
return false;
}

FOLLY_ALWAYS_INLINE bool call(
out_type<Varchar>& result,
const arg_type<Varchar>& json,
const arg_type<Varchar>& jsonPath) {
ParserContext ctx(json.data(), json.size());
try {
ctx.parseDocument();
simdjson_result<ondemand::value> rawResult =
formattedJsonPath_.has_value()
? ctx.jsonDoc.at_pointer(formattedJsonPath_.value().data())
: ctx.jsonDoc.at_pointer(getFormattedJsonPath(jsonPath).data());
// Field not found.
if (rawResult.error() == NO_SUCH_FIELD) {
return false;
}
auto error = extractStringResult(rawResult, result);
if (error) {
return false;
}
} catch (simdjson_error& e) {
return false;
}

const char* currentPos;
ctx.jsonDoc.current_location().get(currentPos);
return isValidEndingCharacter(currentPos);
}
};

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ add_executable(
ElementAtTest.cpp
HashTest.cpp
InTest.cpp
JsonFunctionsTest.cpp
LeastGreatestTest.cpp
MapTest.cpp
MightContainTest.cpp
Expand Down
105 changes: 105 additions & 0 deletions velox/functions/sparksql/tests/JsonFunctionsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/prestosql/types/JsonType.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
#include "velox/type/Type.h"

#include <stdint.h>

namespace facebook::velox::functions::sparksql::test {
namespace {

class JsonFunctionTest : public SparkFunctionBaseTest {
protected:
std::optional<std::string> getJsonObject(
const std::optional<std::string>& json,
const std::optional<std::string>& jsonPath) {
return evaluateOnce<std::string>("get_json_object(c0, c1)", json, jsonPath);
}
};

TEST_F(JsonFunctionTest, getJsonObject) {
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"hello": 3.5})", "$.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"hello": 292222730})", "$.hello"), "292222730");
EXPECT_EQ(getJsonObject(R"({"hello": -292222730})", "$.hello"), "-292222730");
EXPECT_EQ(getJsonObject(R"({"my": {"hello": 3.5}})", "$.my.hello"), "3.5");
EXPECT_EQ(getJsonObject(R"({"my": {"hello": true}})", "$.my.hello"), "true");
EXPECT_EQ(getJsonObject(R"({"hello": ""})", "$.hello"), "");
EXPECT_EQ(
getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.age"),
"5");
EXPECT_EQ(
getJsonObject(R"({"name": "Alice", "age": 5, "id": "001"})", "$.id"),
"001");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])",
"$[0]['my']['param']['age']"),
"5");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}}, {"other": "v1"}])",
"$[0].my.param.age"),
"5");

// Json object as result.
EXPECT_EQ(
getJsonObject(
R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})",
"$.my.param"),
R"({"name": "Alice", "age": "5", "id": "001"})");
EXPECT_EQ(
getJsonObject(
R"({"my": {"param": {"name": "Alice", "age": "5", "id": "001"}}})",
"$['my']['param']"),
R"({"name": "Alice", "age": "5", "id": "001"})");

// Array as result.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other"),
R"(["v1", "v2"])");
// Array element as result.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other[0]"),
"v1");
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"}}}, {"other": ["v1", "v2"]}])",
"$[1].other[1]"),
"v2");

// Field not found.
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$.hi"), std::nullopt);
// Illegal json.
EXPECT_EQ(getJsonObject(R"({"hello"-3.5})", "$.hello"), std::nullopt);
// Illegal json path.
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$hello"), std::nullopt);
EXPECT_EQ(getJsonObject(R"({"hello": "3.5"})", "$."), std::nullopt);
// Invalid ending character.
EXPECT_EQ(
getJsonObject(
R"([{"my": {"param": {"name": "Alice"quoted""}}}, {"other": ["v1", "v2"]}])",
"$[0].my.param.name"),
std::nullopt);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit c5501a5

Please sign in to comment.