Skip to content

Commit

Permalink
add substring_index() function in sparksql
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf committed Aug 8, 2023
1 parent 4780a2e commit d788f8e
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 1 deletion.
8 changes: 8 additions & 0 deletions velox/docs/functions/spark/string.rst
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ Unless specified otherwise, all functions return NULL if at least one of the arg
SELECT substring('Spark SQL', -10, 3); -- "Sp"
SELECT substring('Spark SQL', -20, 3); -- ""

.. spark:function:: substring_index(string, delim, count) -> varchar
Returns the substring from ``string`` before ``count`` occurrences of the delimiter ``delim``.
If ``count`` is positive, everything to the left of the final delimiter
(counting from the left) is returned. If ``count`` is negative, everything to the right
of the final delimiter (counting from the right) is returned. If ``count`` is zero,
an empty string is returned. If ``delim`` occurs fewer than ``abs(count)`` times,
``string`` is returned unchanged. The search for ``delim`` is case-sensitive.

.. spark:function:: translate(string, match, replace) -> varchar
Returns a new translated string. It translates the character in ``string`` by a
Expand Down
1 change: 1 addition & 0 deletions velox/expression/tests/SparkExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ int main(int argc, char** argv) {
"replace",
"might_contain",
"unix_timestamp"};

return FuzzerRunner::run(
FLAGS_only, FLAGS_seed, skipFunctions, FLAGS_special_forms);
}
2 changes: 1 addition & 1 deletion velox/functions/lib/string/StringCore.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ static int64_t findNthInstanceByteIndexFromStart(

// Find next occurrence
return findNthInstanceByteIndexFromStart(
string, subString, instance - 1, byteIndex + subString.size());
string, subString, instance - 1, byteIndex + 1);
}

/// Returns the start byte index of the Nth instance of subString in
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ void registerFunctions(const std::string& prefix) {
prefix + "instr", instrSignatures(), makeInstr);
exec::registerStatefulVectorFunction(
prefix + "length", lengthSignatures(), makeLength);
registerFunction<SubstringIndexFunction, Varchar, Varchar, Varchar, int32_t>(
{prefix + "substring_index"});

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
74 changes: 74 additions & 0 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,80 @@ struct EndsWithFunction {
}
};

/// substring_index(str, delim, count) -> varchar
///
/// Returns the part of 'str' located before (count > 0) or after (count < 0)
/// the |count|-th occurrence of 'delim':
///  - count > 0: occurrences are counted from the left and everything to the
///    left of the selected delimiter is returned.
///  - count < 0: occurrences are counted from the right and everything to the
///    right of the selected delimiter is returned.
///  - count == 0 or empty 'delim': an empty string is returned.
///  - 'delim' occurs fewer than |count| times: 'str' is returned unchanged.
/// Matching of 'delim' is case-sensitive. The result is a view into 'str'; no
/// bytes are copied.
template <typename T>
struct SubstringIndexFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // Results refer to strings in the first argument.
  static constexpr int32_t reuse_strings_from_arg = 0;

  // ASCII input always produces ASCII result.
  static constexpr bool is_default_ascii_behavior = true;

  /// Entry point for input that may contain multi-byte characters.
  FOLLY_ALWAYS_INLINE void call(
      out_type<Varchar>& result,
      const arg_type<Varchar>& str,
      const arg_type<Varchar>& delim,
      const int32_t& count) {
    doCall<false>(result, str, delim, count);
  }

  /// Entry point used when the input is known to be ASCII-only.
  FOLLY_ALWAYS_INLINE void callAscii(
      out_type<Varchar>& result,
      const arg_type<Varchar>& str,
      const arg_type<Varchar>& delim,
      const int32_t& count) {
    doCall<true>(result, str, delim, count);
  }

  /// Shared implementation; 'isAscii' is forwarded to the string helpers so
  /// they can pick their ASCII or multi-byte code path.
  template <bool isAscii>
  FOLLY_ALWAYS_INLINE void doCall(
      out_type<Varchar>& result,
      const arg_type<Varchar>& str,
      const arg_type<Varchar>& delim,
      const int32_t& count) {
    // Spark returns an empty string for a zero count and for an empty
    // delimiter, whatever the sign of 'count'. Handle the empty delimiter here
    // instead of relying on how stringPosition treats an empty needle.
    if (count == 0 || delim.size() == 0) {
      result.setEmpty();
      return;
    }

    // Widen before negating: '-count' overflows int32_t (undefined behavior)
    // when count == INT32_MIN.
    const int64_t wideCount = count;
    const int64_t occurrence = wideCount > 0 ? wideCount : -wideCount;

    // 1-based position of the selected delimiter, or 0 when 'delim' occurs
    // fewer than 'occurrence' times.
    const int64_t index = count > 0
        ? stringImpl::stringPosition<isAscii, true>(str, delim, occurrence)
        : stringImpl::stringPosition<isAscii, false>(str, delim, occurrence);

    // Default to the whole string; narrowed below when the delimiter is found.
    int64_t start = 1;
    int64_t length = stringImpl::length<isAscii>(str);

    if (index != 0) {
      if (count > 0) {
        // Keep everything before the delimiter.
        length = index - 1;
      } else {
        // Keep everything after the delimiter.
        const int64_t delimLength = stringImpl::length<isAscii>(delim);
        start = index + delimLength;
        length = length - index - delimLength + 1;
      }
    }

    // Translate the (start, length) pair into a byte range of 'str'.
    auto byteRange =
        stringCore::getByteRange<isAscii>(str.data(), start, length);

    // Generating output string
    result.setNoCopy(StringView(
        str.data() + byteRange.first, byteRange.second - byteRange.first));
  }
};

/// ltrim(trimStr, srcStr) -> varchar
/// Remove leading specified characters from srcStr. The specified character
/// is any character contained in trimStr.
Expand Down
32 changes: 32 additions & 0 deletions velox/functions/sparksql/tests/StringTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ class StringTest : public SparkFunctionBaseTest {
return evaluateOnce<std::string>("left(c0, c1)", str, length);
}

// Runs the Spark expression substring_index(c0, c1, c2) on a single row built
// from the given arguments and returns its result. 'str' and 'delim' are
// optional so that callers can pass std::nullopt for them.
std::optional<std::string> substring_index(
    const std::optional<std::string>& str,
    const std::optional<std::string>& delim,
    int32_t count) {
  const std::string expression = "substring_index(c0, c1, c2)";
  return evaluateOnce<std::string, std::string, std::string, int32_t>(
      expression, str, delim, count);
}

std::optional<std::string> overlay(
std::optional<std::string> input,
std::optional<std::string> replace,
Expand Down Expand Up @@ -370,6 +378,30 @@ TEST_F(StringTest, endsWith) {
EXPECT_EQ(endsWith(std::nullopt, "abc"), std::nullopt);
}

// Covers substring_index for positive, zero and negative counts, counts that
// exceed the number of delimiters, empty inputs, multi-character, non-ASCII
// and overlapping delimiters.
TEST_F(StringTest, substring_index) {
  EXPECT_EQ(substring_index("www.apache.org", ".", 3), "www.apache.org");
  EXPECT_EQ(substring_index("www.apache.org", ".", 2), "www.apache");
  EXPECT_EQ(substring_index("www.apache.org", ".", 1), "www");
  EXPECT_EQ(substring_index("www.apache.org", ".", 0), "");
  EXPECT_EQ(substring_index("www.apache.org", ".", -1), "org");
  EXPECT_EQ(substring_index("www.apache.org", ".", -2), "apache.org");
  EXPECT_EQ(substring_index("www.apache.org", ".", -3), "www.apache.org");
  // Str is empty string.
  EXPECT_EQ(substring_index("", ".", 1), "");
  // Empty string delim.
  EXPECT_EQ(substring_index("www.apache.org", "", 1), "");
  // Delim does not exist in str.
  EXPECT_EQ(substring_index("www.apache.org", "#", 2), "www.apache.org");
  // Delim is 2 chars.
  EXPECT_EQ(substring_index("www||apache||org", "||", 2), "www||apache");
  EXPECT_EQ(substring_index("www||apache||org", "||", -2), "apache||org");
  // Non ascii chars. The delimiter must be the multi-byte character that
  // follows the expected prefix; an empty delimiter would yield "".
  EXPECT_EQ(substring_index("大千世界大千世界", "千", 2), "大千世界大");
  // Overlapped delim.
  EXPECT_EQ(substring_index("||||||", "|||", 3), "||");
  EXPECT_EQ(substring_index("||||||", "|||", -4), "|||");
}

TEST_F(StringTest, trim) {
EXPECT_EQ(trim(""), "");
EXPECT_EQ(trim(" data\t "), "data\t");
Expand Down

0 comments on commit d788f8e

Please sign in to comment.