Skip to content

Commit

Permalink
Add support for Utf8View to crypto functions #13406 (#13407)
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 authored Nov 15, 2024
1 parent d840e98 commit 7e69580
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 19 deletions.
57 changes: 44 additions & 13 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@

//! "crypto" DataFusion functions

use arrow::array::StringArray;
use arrow::array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait};
use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType;
use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
use datafusion_common::cast::as_binary_array;

use arrow::compute::StringArrayType;
use datafusion_common::plan_err;
use datafusion_common::{
cast::{as_generic_binary_array, as_generic_string_array},
exec_err, internal_err, DataFusionError, Result, ScalarValue,
cast::as_generic_binary_array, exec_err, internal_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::ColumnarValue;
use md5::Md5;
Expand Down Expand Up @@ -121,9 +122,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
}
let digest_algorithm = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<DigestAlgorithm>()
}
ScalarValue::Utf8View(Some(method))
| ScalarValue::Utf8(Some(method))
| ScalarValue::LargeUtf8(Some(method)) => method.parse::<DigestAlgorithm>(),
other => exec_err!("Unsupported data type {other:?} for function digest"),
},
ColumnarValue::Array(_) => {
Expand All @@ -132,6 +133,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
}?;
digest_process(&args[0], digest_algorithm)
}

impl FromStr for DigestAlgorithm {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<DigestAlgorithm> {
Expand Down Expand Up @@ -166,12 +168,14 @@ impl FromStr for DigestAlgorithm {
})
}
}

impl fmt::Display for DigestAlgorithm {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", format!("{self:?}").to_lowercase())
}
}
// /// computes md5 hash digest of the given input

/// computes md5 hash digest of the given input
pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return exec_err!(
Expand All @@ -180,7 +184,9 @@ pub fn md5(args: &[ColumnarValue]) -> Result<ColumnarValue> {
DigestAlgorithm::Md5
);
}

let value = digest_process(&args[0], DigestAlgorithm::Md5)?;

// md5 requires special handling because of its unique utf8 return type
Ok(match value {
ColumnarValue::Array(array) => {
Expand Down Expand Up @@ -214,7 +220,8 @@ pub fn utf8_or_binary_to_binary_type(
name: &str,
) -> Result<DataType> {
Ok(match arg_type {
DataType::LargeUtf8
DataType::Utf8View
| DataType::LargeUtf8
| DataType::Utf8
| DataType::Binary
| DataType::LargeBinary => DataType::Binary,
Expand Down Expand Up @@ -296,8 +303,30 @@ impl DigestAlgorithm {
where
T: OffsetSizeTrait,
{
let input_value = as_generic_string_array::<T>(value)?;
let array: ArrayRef = match self {
let array = match value.data_type() {
DataType::Utf8 | DataType::LargeUtf8 => {
let v = value.as_string::<T>();
self.digest_utf8_array_impl::<&GenericStringArray<T>>(v)
}
DataType::Utf8View => {
let v = value.as_string_view();
self.digest_utf8_array_impl::<&StringViewArray>(v)
}
other => {
return exec_err!("unsupported type for digest_utf_array: {other:?}")
}
};
Ok(ColumnarValue::Array(array))
}

pub fn digest_utf8_array_impl<'a, StringArrType>(
self,
input_value: StringArrType,
) -> ArrayRef
where
StringArrType: StringArrayType<'a>,
{
match self {
Self::Md5 => digest_to_array!(Md5, input_value),
Self::Sha224 => digest_to_array!(Sha224, input_value),
Self::Sha256 => digest_to_array!(Sha256, input_value),
Expand All @@ -318,8 +347,7 @@ impl DigestAlgorithm {
.collect();
Arc::new(binary_array)
}
};
Ok(ColumnarValue::Array(array))
}
}
}
pub fn digest_process(
Expand All @@ -328,6 +356,7 @@ pub fn digest_process(
) -> Result<ColumnarValue> {
match value {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8View => digest_algorithm.digest_utf8_array::<i32>(a.as_ref()),
DataType::Utf8 => digest_algorithm.digest_utf8_array::<i32>(a.as_ref()),
DataType::LargeUtf8 => digest_algorithm.digest_utf8_array::<i64>(a.as_ref()),
DataType::Binary => digest_algorithm.digest_binary_array::<i32>(a.as_ref()),
Expand All @@ -339,7 +368,9 @@ pub fn digest_process(
),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
ScalarValue::Utf8View(a)
| ScalarValue::Utf8(a)
| ScalarValue::LargeUtf8(a) => {
Ok(digest_algorithm
.digest_scalar(a.as_ref().map(|s: &String| s.as_bytes())))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions/src/crypto/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl DigestFunc {
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8, Utf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![Binary, Utf8]),
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/crypto/md5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Md5Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand All @@ -65,7 +65,7 @@ impl ScalarUDFImpl for Md5Func {
use DataType::*;
Ok(match &arg_types[0] {
LargeUtf8 | LargeBinary => LargeUtf8,
Utf8 | Binary => Utf8,
Utf8View | Utf8 | Binary => Utf8,
Null => Null,
Dictionary(_, t) => match **t {
LargeUtf8 | LargeBinary => LargeUtf8,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha224.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl SHA224Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA256Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha384.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA384Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/crypto/sha512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl SHA512Func {
Self {
signature: Signature::uniform(
1,
vec![Utf8, LargeUtf8, Binary, LargeBinary],
vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary],
Volatility::Immutable,
),
}
Expand Down
5 changes: 5 additions & 0 deletions datafusion/sqllogictest/test_files/expr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2225,6 +2225,11 @@ SELECT digest('','blake3');
----
af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262

# vverify utf8view
query ?
SELECT sha224(arrow_cast('tom', 'Utf8View'));
----
0bf6cb62649c42a9ae3876ab6f6d92ad36cb5414e495f8873292be4d

query T
SELECT substring('alphabet', 1)
Expand Down
60 changes: 60 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,66 @@ logical_plan
01)Projection: nullif(test.column1_utf8view, test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for md5
query TT
EXPLAIN SELECT
md5(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: md5(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha224
query TT
EXPLAIN SELECT
sha224(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha224(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha256
query TT
EXPLAIN SELECT
sha256(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha256(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha384
query TT
EXPLAIN SELECT
sha384(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha384(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for sha512
query TT
EXPLAIN SELECT
sha512(column1_utf8view) as c
FROM test;
----
logical_plan
01)Projection: sha512(test.column1_utf8view) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for digest
query TT
EXPLAIN SELECT
digest(column1_utf8view, 'md5') as c
FROM test;
----
logical_plan
01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c
02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for binary operators
# `~` operator (regex match)
query TT
Expand Down

0 comments on commit 7e69580

Please sign in to comment.