Skip to content

Commit

Permalink
Fix panic in multiple distinct aggregates by fixing ScalarValue::new_…
Browse files Browse the repository at this point in the history
…list
  • Loading branch information
alamb committed Oct 30, 2023
1 parent bb1d7f9 commit 28ef629
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 202 deletions.
11 changes: 11 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ opt-level = 3
overflow-checks = false
panic = 'unwind'
rpath = false

[patch.crates-io]
arrow = { path= "/Users/alamb/Software/arrow-rs2/arrow" }
arrow-array = { path= "/Users/alamb/Software/arrow-rs2/arrow-array" }
arrow-buffer = { path= "/Users/alamb/Software/arrow-rs2/arrow-buffer" }
arrow-schema = { path= "/Users/alamb/Software/arrow-rs2/arrow-schema" }
arrow-select = { path= "/Users/alamb/Software/arrow-rs2/arrow-select" }
arrow-string = { path= "/Users/alamb/Software/arrow-rs2/arrow-string" }
arrow-ord = { path= "/Users/alamb/Software/arrow-rs2/arrow-ord" }
arrow-flight = { path= "/Users/alamb/Software/arrow-rs2/arrow-flight" }
parquet = { path= "/Users/alamb/Software/arrow-rs2/parquet" }
204 changes: 10 additions & 194 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,117 +600,6 @@ macro_rules! typed_cast {
}};
}

macro_rules! build_timestamp_list {
($TIME_UNIT:expr, $TIME_ZONE:expr, $VALUES:expr, $SIZE:expr) => {{
match $VALUES {
// the return on the macro is necessary, to short-circuit and return ArrayRef
None => {
return new_null_array(
&DataType::List(Arc::new(Field::new(
"item",
DataType::Timestamp($TIME_UNIT, $TIME_ZONE),
true,
))),
$SIZE,
)
}
Some(values) => match $TIME_UNIT {
TimeUnit::Second => {
build_values_list_tz!(
TimestampSecondBuilder,
TimestampSecond,
values,
$SIZE,
$TIME_ZONE
)
}
TimeUnit::Millisecond => build_values_list_tz!(
TimestampMillisecondBuilder,
TimestampMillisecond,
values,
$SIZE,
$TIME_ZONE
),
TimeUnit::Microsecond => build_values_list_tz!(
TimestampMicrosecondBuilder,
TimestampMicrosecond,
values,
$SIZE,
$TIME_ZONE
),
TimeUnit::Nanosecond => build_values_list_tz!(
TimestampNanosecondBuilder,
TimestampNanosecond,
values,
$SIZE,
$TIME_ZONE
),
},
}
}};
}

macro_rules! new_builder {
(StringBuilder, $len:expr) => {
StringBuilder::new()
};
(LargeStringBuilder, $len:expr) => {
LargeStringBuilder::new()
};
($el:ident, $len:expr) => {{
<$el>::with_capacity($len)
}};
}

macro_rules! build_values_list {
($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr) => {{
let builder = new_builder!($VALUE_BUILDER_TY, $VALUES.len());
let mut builder = ListBuilder::new(builder);

for _ in 0..$SIZE {
for scalar_value in $VALUES {
match scalar_value {
ScalarValue::$SCALAR_TY(Some(v)) => {
builder.values().append_value(v.clone());
}
ScalarValue::$SCALAR_TY(None) => {
builder.values().append_null();
}
_ => panic!("Incompatible ScalarValue for list"),
};
}
builder.append(true);
}

builder.finish()
}};
}

macro_rules! build_values_list_tz {
($VALUE_BUILDER_TY:ident, $SCALAR_TY:ident, $VALUES:expr, $SIZE:expr, $TIME_ZONE:expr) => {{
let mut builder = ListBuilder::new(
$VALUE_BUILDER_TY::with_capacity($VALUES.len()).with_timezone_opt($TIME_ZONE),
);

for _ in 0..$SIZE {
for scalar_value in $VALUES {
match scalar_value {
ScalarValue::$SCALAR_TY(Some(v), _) => {
builder.values().append_value(v.clone());
}
ScalarValue::$SCALAR_TY(None, _) => {
builder.values().append_null();
}
_ => panic!("Incompatible ScalarValue for list"),
};
}
builder.append(true);
}

builder.finish()
}};
}

macro_rules! build_array_from_option {
($DATA_TYPE:ident, $ARRAY_TYPE:ident, $EXPR:expr, $SIZE:expr) => {{
match $EXPR {
Expand Down Expand Up @@ -1198,7 +1087,8 @@ impl ScalarValue {
}

/// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`]
/// corresponding to those values. For example,
/// corresponding to those values. For example, an iterator of
/// [`ScalarValue::Int32`] would be converted to an [`Int32Array`].
///
/// Returns an error if the iterator is empty or if the
/// [`ScalarValue`]s are not all the same type
Expand Down Expand Up @@ -1654,41 +1544,6 @@ impl ScalarValue {
Ok(array)
}

/// This function does not contains nulls but empty array instead.
fn iter_to_array_list_without_nulls(
values: &[ScalarValue],
data_type: &DataType,
) -> Result<GenericListArray<i32>> {
let mut elements: Vec<ArrayRef> = vec![];
let mut offsets = vec![];

if values.is_empty() {
offsets.push(0);
} else {
let arr = ScalarValue::iter_to_array(values.to_vec())?;
offsets.push(arr.len());
elements.push(arr);
}

// Concatenate element arrays to create single flat array
let flat_array = if elements.is_empty() {
new_empty_array(data_type)
} else {
let element_arrays: Vec<&dyn Array> =
elements.iter().map(|a| a.as_ref()).collect();
arrow::compute::concat(&element_arrays)?
};

let list_array = ListArray::new(
Arc::new(Field::new("item", flat_array.data_type().to_owned(), true)),
OffsetBuffer::<i32>::from_lengths(offsets),
flat_array,
None,
);

Ok(list_array)
}

/// This function build with nulls with nulls buffer.
fn iter_to_array_list(
scalars: impl IntoIterator<Item = ScalarValue>,
Expand Down Expand Up @@ -1776,7 +1631,8 @@ impl ScalarValue {
.unwrap()
}

/// Converts `Vec<ScalaValue>` to ListArray, simplified version of ScalarValue::to_array
/// Converts `Vec<ScalaValue>` where each element has type corresponding to
/// `data_type`, to a [`ListArray`].
///
/// Example
/// ```
Expand All @@ -1802,52 +1658,12 @@ impl ScalarValue {
/// assert_eq!(result, &expected);
/// ```
pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> ArrayRef {
Arc::new(match data_type {
DataType::Boolean => build_values_list!(BooleanBuilder, Boolean, values, 1),
DataType::Int8 => build_values_list!(Int8Builder, Int8, values, 1),
DataType::Int16 => build_values_list!(Int16Builder, Int16, values, 1),
DataType::Int32 => build_values_list!(Int32Builder, Int32, values, 1),
DataType::Int64 => build_values_list!(Int64Builder, Int64, values, 1),
DataType::UInt8 => build_values_list!(UInt8Builder, UInt8, values, 1),
DataType::UInt16 => build_values_list!(UInt16Builder, UInt16, values, 1),
DataType::UInt32 => build_values_list!(UInt32Builder, UInt32, values, 1),
DataType::UInt64 => build_values_list!(UInt64Builder, UInt64, values, 1),
DataType::Utf8 => build_values_list!(StringBuilder, Utf8, values, 1),
DataType::LargeUtf8 => {
build_values_list!(LargeStringBuilder, LargeUtf8, values, 1)
}
DataType::Float32 => build_values_list!(Float32Builder, Float32, values, 1),
DataType::Float64 => build_values_list!(Float64Builder, Float64, values, 1),
DataType::Timestamp(unit, tz) => {
let values = Some(values);
build_timestamp_list!(unit.clone(), tz.clone(), values, 1)
}
DataType::List(_) | DataType::Struct(_) => {
ScalarValue::iter_to_array_list_without_nulls(values, data_type).unwrap()
}
DataType::Decimal128(precision, scale) => {
let mut vals = vec![];
for value in values.iter() {
if let ScalarValue::Decimal128(v, _, _) = value {
vals.push(v.to_owned())
}
}

let arr = Decimal128Array::from(vals)
.with_precision_and_scale(*precision, *scale)
.unwrap();
wrap_into_list_array(Arc::new(arr))
}

DataType::Null => {
let arr = new_null_array(&DataType::Null, values.len());
wrap_into_list_array(arr)
}
_ => panic!(
"Unsupported data type {:?} for ScalarValue::list_to_array",
data_type
),
})
let values = if values.is_empty() {
new_empty_array(data_type)
} else {
Self::iter_to_array(values.iter().cloned()).unwrap()
};
Arc::new(wrap_into_list_array(values))
}

/// Converts a scalar value into an array of `size` rows.
Expand Down
35 changes: 27 additions & 8 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2020,14 +2020,6 @@ statement ok
drop table t;




statement error DataFusion error: Execution error: Table 't_source' doesn't exist\.
drop table t_source;

statement error DataFusion error: Execution error: Table 't' doesn't exist\.
drop table t;

query I
select median(a) from (select 1 as a where 1=0);
----
Expand Down Expand Up @@ -2199,6 +2191,26 @@ NULL 1 10.1 10.1 10.1 10.1 0 NULL
statement ok
set datafusion.sql_parser.dialect = 'Generic';

## Multiple distinct aggregates and dictionaries
statement ok
create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)'));

query I?
select * from dict_test;
----
1 foo
2 bar

query II
select count(distinct column1), count(distinct column2) from dict_test group by column1;
----
1 1
1 1

statement error DataFusion error: SQL error: ParserError\("Expected end of statement, found: test"\)
drop table dict_ test;


# Prepare the table with dictionary values for testing
statement ok
CREATE TABLE value(x bigint) AS VALUES (1), (2), (3), (1), (3), (4), (5), (2);
Expand Down Expand Up @@ -2282,6 +2294,13 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
4
5

statement ok
drop table value

statement ok
drop table value_dict


# bool aggregation
statement ok
CREATE TABLE value_bool(x boolean, g int) AS VALUES (NULL, 0), (false, 0), (true, 0), (false, 1), (true, 2), (NULL, 3);
Expand Down

0 comments on commit 28ef629

Please sign in to comment.