Commit

move other utils
Signed-off-by: jayzhan211 <[email protected]>
jayzhan211 committed May 19, 2024
1 parent bce1f1b commit 0338865
Showing 2 changed files with 86 additions and 88 deletions.
81 changes: 80 additions & 1 deletion datafusion/physical-expr-common/src/aggregate/utils.rs
@@ -18,9 +18,16 @@
use std::{any::Any, sync::Arc};

use arrow::{
    array::{ArrayRef, ArrowNativeTypeOp, AsArray},
    compute::SortOptions,
    datatypes::{DataType, Field},
    datatypes::{
        DataType, Decimal128Type, Field, TimeUnit, TimestampMicrosecondType,
        TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
        ToByteSlice,
    },
};
use datafusion_common::Result;
use datafusion_expr::Accumulator;

use crate::sort_expr::PhysicalSortExpr;

@@ -43,6 +50,60 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
    }
}

/// Convert scalar values from an accumulator into arrays.
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    accum
        .state()?
        .iter()
        .map(|s| s.to_array_of_size(1))
        .collect()
}

/// Adjust array type metadata if needed
///
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
/// default precision and scale, this function adjusts the output to
/// match `data_type`, if necessary
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
    let array = match data_type {
        DataType::Decimal128(p, s) => Arc::new(
            array
                .as_primitive::<Decimal128Type>()
                .clone()
                .with_precision_and_scale(*p, *s)?,
        ) as ArrayRef,
        DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampNanosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMicrosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMillisecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
            array
                .as_primitive::<TimestampSecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        // no adjustment needed for other arrays
        _ => array,
    };
    Ok(array)
}

/// Construct corresponding fields for lexicographical ordering requirement expression
pub fn ordering_fields(
    ordering_req: &[PhysicalSortExpr],
@@ -67,3 +128,21 @@ pub fn ordering_fields(
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
    ordering_req.iter().map(|item| item.options).collect()
}

/// A wrapper around a type to provide hash for floats
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_byte_slice().hash(state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
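
To make the moved helpers concrete, here is a minimal usage sketch that is not part of the commit: it assumes the crate path datafusion_physical_expr_common::aggregate::utils introduced by this change, and the function name and literal values are purely illustrative.

use std::collections::HashSet;
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Decimal128Array};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_physical_expr_common::aggregate::utils::{adjust_output_array, Hashable};

fn example() -> Result<()> {
    // A Decimal128Array built from plain i128 values carries the default
    // precision/scale; adjust_output_array rewrites the metadata so the
    // output matches the requested Decimal128(10, 2) type.
    let raw: ArrayRef = Arc::new(Decimal128Array::from(vec![12345_i128, 67890_i128]));
    let adjusted = adjust_output_array(&DataType::Decimal128(10, 2), raw)?;
    assert_eq!(adjusted.data_type(), &DataType::Decimal128(10, 2));

    // Hashable wraps a float native value so it can serve as a set/map key;
    // bare f64 implements neither Eq nor Hash.
    let mut seen: HashSet<Hashable<f64>> = HashSet::new();
    seen.insert(Hashable(1.5));
    seen.insert(Hashable(1.5));
    assert_eq!(seen.len(), 1);

    Ok(())
}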
93 changes: 6 additions & 87 deletions datafusion/physical-expr/src/aggregate/utils.rs
@@ -17,35 +17,18 @@

//! Utilities used in aggregates

use std::sync::Arc;

// For backwards compatibility
pub use datafusion_physical_expr_common::aggregate::utils::{
    down_cast_any_ref, get_sort_options, ordering_fields,
    adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays,
    get_sort_options, ordering_fields, Hashable,
};

use arrow::array::{ArrayRef, ArrowNativeTypeOp};
use arrow_array::cast::AsArray;
use arrow_array::types::{
    Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType,
    TimestampNanosecondType, TimestampSecondType,
};
use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_schema::DataType;
use arrow::array::ArrowNativeTypeOp;
use arrow_array::types::DecimalType;
use arrow_buffer::ArrowNativeType;
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;

/// Convert scalar values from an accumulator into arrays.
pub fn get_accum_scalar_values_as_arrays(
    accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
    accum
        .state()?
        .iter()
        .map(|s| s.to_array_of_size(1))
        .collect()
}

// TODO: Move to functions-aggregate crate
/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for Decimal128/Decimal256 can
@@ -125,67 +108,3 @@ impl<T: DecimalType> DecimalAverager<T> {
        }
    }
}

/// Adjust array type metadata if needed
///
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
/// default precision and scale, this function adjusts the output to
/// match `data_type`, if necessary
pub fn adjust_output_array(
    data_type: &DataType,
    array: ArrayRef,
) -> Result<ArrayRef, DataFusionError> {
    let array = match data_type {
        DataType::Decimal128(p, s) => Arc::new(
            array
                .as_primitive::<Decimal128Type>()
                .clone()
                .with_precision_and_scale(*p, *s)?,
        ) as ArrayRef,
        DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampNanosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMicrosecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => Arc::new(
            array
                .as_primitive::<TimestampMillisecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        DataType::Timestamp(arrow_schema::TimeUnit::Second, tz) => Arc::new(
            array
                .as_primitive::<TimestampSecondType>()
                .clone()
                .with_timezone_opt(tz.clone()),
        ),
        // no adjustment needed for other arrays
        _ => array,
    };
    Ok(array)
}

/// A wrapper around a type to provide hash for floats
#[derive(Copy, Clone, Debug)]
pub(crate) struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.to_byte_slice().hash(state)
    }
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.is_eq(other.0)
    }
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
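
And a brief sketch of what the "// For backwards compatibility" re-exports in this file preserve; the import below is hypothetical downstream code, and it assumes the old module path remains publicly reachable after this commit.

// Code that still imports the helpers from their old home in
// datafusion-physical-expr keeps compiling, because the moved items are
// re-exported from the original module and resolve to the same definitions
// now living in datafusion-physical-expr-common.
use datafusion_physical_expr::aggregate::utils::{
    adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options,
    ordering_fields, Hashable,
};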
