Skip to content

Commit

Permalink
large checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed May 1, 2024
1 parent 662f7e2 commit ef72392
Show file tree
Hide file tree
Showing 24 changed files with 823 additions and 309 deletions.
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter(
let agg_expr = udaf::create_aggregate_expr(
fun,
&args,
&sort_exprs,
// &sort_exprs,
&ordering_reqs,
physical_input_schema,
name,
Expand Down
30 changes: 30 additions & 0 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,36 @@ impl Signature {
/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question.
pub type FuncMonotonicity = Vec<Option<bool>>;

/// Builds the message reported when a function is invoked with argument
/// types that none of its signatures accept.
///
/// The text names the attempted call (`func_name` plus the supplied
/// `input_expr_types`) and then lists every candidate signature of
/// `func_signature`, one tab-indented entry per line.
///
/// For example, a query like `select round(3.14, 1.1);` would make this
/// function return (callers may prepend their own error prefix):
/// ```text
/// No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts.
///     Candidate functions:
///     round(Float64, Int64)
///     round(Float32, Int64)
///     round(Float64)
///     round(Float32)
/// ```
pub fn generate_signature_error_msg(
    func_name: &str,
    func_signature: Signature,
    input_expr_types: &[DataType],
) -> String {
    // One `\tname(args)` entry for each signature the function accepts.
    let signature_reprs = func_signature.type_signature.to_string_repr();
    let mut candidate_lines: Vec<String> = Vec::with_capacity(signature_reprs.len());
    for args_str in &signature_reprs {
        candidate_lines.push(format!("\t{func_name}({args_str})"));
    }
    let candidate_signatures = candidate_lines.join("\n");

    // The argument types the caller actually supplied, comma separated.
    let supplied_types = TypeSignature::join_types(input_expr_types, ", ");

    format!(
        "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
        func_name, supplied_types, candidate_signatures
    )
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
197 changes: 197 additions & 0 deletions datafusion/expr-common/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::{compute::can_cast_types, datatypes::{DataType, TimeUnit}};
use datafusion_common::utils::list_ndims;

use crate::{signature::{TIMEZONE_WILDCARD, FIXED_SIZE_LIST_WILDCARD}, type_coercion::binary::comparison_binary_numeric_coercion};

/// Reports whether a value of type `type_from` can be coerced into a value
/// of type `type_into`.
///
/// Identical types always qualify. Otherwise the decision is delegated to
/// `coerced_from`, and coercion is accepted only when that helper produces
/// a type equal to `type_into` itself.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
    // `||` short-circuits, so `coerced_from` is only consulted when the
    // two types differ.
    type_into == type_from
        || coerced_from(type_into, type_from).as_ref() == Some(type_into)
}

/// Determines the type obtained by coercing `type_from` into `type_into`.
///
/// Returns `Some(data_type)` when one of the rules below applies and `None`
/// otherwise. For most rules the result is simply a clone of `type_into`;
/// the wildcard `FixedSizeList` rule and the wildcard-timezone `Timestamp`
/// rule instead build a concrete type using details taken from `type_from`.
///
/// The match arms are tried from top to bottom, so an earlier rule takes
/// precedence over any later rule that would also match.
fn coerced_from<'a>(
type_into: &'a DataType,
type_from: &'a DataType,
) -> Option<DataType> {
use self::DataType::*;

// match Dictionary first
match (type_into, type_from) {
// coerced dictionary first
// A dictionary source is accepted when its value type can itself be
// coerced into the target (checked recursively).
(_, Dictionary(_, value_type))
if coerced_from(type_into, value_type).is_some() =>
{
Some(type_into.clone())
}
// A dictionary target is accepted when the source can be coerced into
// the dictionary's value type (checked recursively).
(Dictionary(_, value_type), _)
if coerced_from(value_type, type_from).is_some() =>
{
Some(type_into.clone())
}
// coerced into type_into
// Signed integer targets accept Null, signed integers of equal or smaller
// width, and unsigned integers of strictly smaller width.
(Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()),
(Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
Some(type_into.clone())
}
(Int32, _)
if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) =>
{
Some(type_into.clone())
}
(Int64, _)
if matches!(
type_from,
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
) =>
{
Some(type_into.clone())
}
// Unsigned integer targets accept Null and unsigned integers of equal or
// smaller width; signed sources are not listed.
(UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
(UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => {
Some(type_into.clone())
}
(UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
Some(type_into.clone())
}
(UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
Some(type_into.clone())
}
// Float32 accepts Null, every integer type, and Float32 itself.
(Float32, _)
if matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
) =>
{
Some(type_into.clone())
}
// Float64 accepts everything Float32 does, plus Float64 and Decimal128.
(Float64, _)
if matches!(
type_from,
Null | Int8
| Int16
| Int32
| Int64
| UInt8
| UInt16
| UInt32
| UInt64
| Float32
| Float64
| Decimal128(_, _)
) =>
{
Some(type_into.clone())
}
// A nanosecond timestamp without time zone accepts Null, timestamps of any
// unit without time zone, Date32, and both string types.
(Timestamp(TimeUnit::Nanosecond, None), _)
if matches!(
type_from,
Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8
) =>
{
Some(type_into.clone())
}
// Interval targets (any unit) accept the two string types.
(Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
Some(type_into.clone())
}
// Any type can be coerced into strings
// NOTE(review): this arm is unconditional, which appears to be at odds with
// the remark near the end of this match that numeric types are not coerced
// into Utf8 for function arguments — confirm which behavior is intended.
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
// A Null target is accepted whenever `can_cast_types` reports that
// `type_from` can be cast to it.
(Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

// A FixedSizeList source is accepted for a List target; element types and
// dimensions are not inspected by this arm.
(List(_), _) if matches!(type_from, FixedSizeList(_, _)) => {
Some(type_into.clone())
}

// Only accept list and largelist with the same number of dimensions unless the type is Null.
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
(List(_) | LargeList(_), _)
if datafusion_common::utils::base_type(type_from).eq(&Null)
|| list_ndims(type_from) == list_ndims(type_into) =>
{
Some(type_into.clone())
}
// should be able to coerce wildcard fixed size list to non wildcard fixed size list
(FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from {
FixedSizeList(f_from, size_from) => {
// Coerce the element types recursively; the resulting list always takes
// its size from the source list.
match coerced_from(f_into.data_type(), f_from.data_type()) {
// Coercion produced an element type different from the target's:
// rebuild the target field with the coerced element type.
Some(data_type) if &data_type != f_into.data_type() => {
let new_field =
Arc::new(f_into.as_ref().clone().with_data_type(data_type));
Some(FixedSizeList(new_field, *size_from))
}
// Coercion produced the target's own element type: reuse its field.
Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)),
_ => None,
}
}
// Sources that are not FixedSizeList are rejected by this rule.
_ => None,
},

// Target timestamp whose time zone is the wildcard placeholder: the
// concrete time zone is decided from the source type.
(Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
match type_from {
// The source already carries a time zone: keep it, using the target's unit.
Timestamp(_, Some(from_tz)) => {
Some(Timestamp(unit.clone(), Some(from_tz.clone())))
}
Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
// In the absence of any other information assume the time zone is "+00" (UTC).
Some(Timestamp(unit.clone(), Some("+00".into())))
}
_ => None,
}
}
// Target timestamp with a concrete (non-wildcard) time zone: accepts Null,
// any timestamp (with or without time zone), Date32, and both string types.
(Timestamp(_, Some(_)), _)
if matches!(
type_from,
Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8
) =>
{
Some(type_into.clone())
}
// More coerce rules.
// Note that not all rules in `comparison_coercion` can be reused here.
// For example, all numeric types can be coerced into Utf8 for comparison,
// but not for function arguments.
// The fallback result is kept only when it equals `type_into`; any other
// coerced type is treated as "no coercion" (`None`).
_ => comparison_binary_numeric_coercion(type_into, type_from).and_then(
|coerced_type| {
if *type_into == coerced_type {
Some(coerced_type)
} else {
None
}
},
),
}
}
1 change: 1 addition & 0 deletions datafusion/expr-common/src/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
// under the License.

pub mod binary;
pub mod functions;
9 changes: 6 additions & 3 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ mod udaf;
mod udf;
mod udwf;

pub mod aggregate_function;
pub mod conditional_expressions;
pub mod execution_props;
pub mod expr;
Expand All @@ -52,10 +51,14 @@ pub mod var_provider;
pub mod window_frame;
pub mod window_state;

pub use aggregate_function::AggregateFunction;
/// Re-exports `AggregateFunction` from
/// `datafusion_functions_aggregate_common::builtin` under the
/// `aggregate_function` module path.
pub mod aggregate_function {
pub use datafusion_functions_aggregate_common::builtin::AggregateFunction;
}

pub use datafusion_functions_aggregate_common::builtin::AggregateFunction;
pub use built_in_window_function::BuiltInWindowFunction;
pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::accumulator::Accumulator;
pub use datafusion_expr_common::columnar_value::ColumnarValue;
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
pub use datafusion_expr_common::operator::Operator;
pub use datafusion_expr_common::signature::{
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@
//! i64. However, i64 -> i32 is never performed as there are i64
//! values which can not be represented by i32 values.

pub mod aggregates;
pub mod functions;
pub mod other;

/// Re-exports every public item of
/// `datafusion_functions_aggregate_common::type_coercion` under the
/// `aggregates` module path.
pub mod aggregates {
pub use datafusion_functions_aggregate_common::type_coercion::*;
}

/// Re-exports every public item of
/// `datafusion_expr_common::type_coercion::binary` under the `binary`
/// module path.
pub mod binary {
pub use datafusion_expr_common::type_coercion::binary::*;
}
Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ path = "src/lib.rs"
arrow = { workspace = true }
datafusion-common = { workspace = true }
datafusion-expr-common = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
strum = { version = "0.26.1", features = ["derive"] }
strum_macros = "0.26.0"
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
// specific language governing permissions and limitations
// under the License.

//! Aggregate function module contains all built-in aggregate functions definitions

use std::fmt::{Display, Formatter};
use std::str::FromStr;
use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};

use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};

use datafusion_expr_common::signature::{generate_signature_error_msg, Signature, TypeSignature, Volatility};
use strum_macros::EnumIter;

use crate::type_coercion::*;

/// Enum of all built-in aggregate functions
// Contributor's guide for adding new aggregate functions
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
Expand Down Expand Up @@ -154,8 +154,8 @@ impl AggregateFunction {
}
}

impl fmt::Display for AggregateFunction {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
impl Display for AggregateFunction {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
write!(f, "{}", self.name())
}
}
Expand Down Expand Up @@ -232,7 +232,7 @@ impl AggregateFunction {
.map_err(|_| {
plan_datafusion_err!(
"{}",
utils::generate_signature_error_msg(
generate_signature_error_msg(
&format!("{self}"),
self.signature(),
input_expr_types,
Expand Down Expand Up @@ -307,7 +307,7 @@ pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
let fun = AggregateFunction::Avg;
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
let coerced_data_types = coerce_types(
&fun,
input_expr_types,
&fun.signature(),
Expand Down Expand Up @@ -431,4 +431,4 @@ mod tests {
assert_eq!(func_from_str, func_original);
}
}
}
}
Loading

0 comments on commit ef72392

Please sign in to comment.