Skip to content

Commit

Permalink
Add random queries into aggregate fuzz tester
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Oct 14, 2024
1 parent 708bf74 commit 4c5d621
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 19 deletions.
41 changes: 28 additions & 13 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ use rand::{Rng, SeedableRng};
use tokio::task::JoinSet;

use crate::fuzz_cases::aggregation_fuzzer::{
AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig,
AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder,
};

// ========================================================================
Expand All @@ -58,33 +58,48 @@ use crate::fuzz_cases::aggregation_fuzzer::{
// Dimensions to test:
// Aggregation functions:
// Argument types of the aggregation functions
// Group by columns:
// Group by columns (0..n)

/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `no group by`
/// Fuzz test for `core` aggregates (sum/sum distinct/max/min/count/avg)
/// and integral arguments
///
///
#[tokio::test(flavor = "multi_thread")]
async fn test_basic_prim_aggr_no_group() {
async fn core_aggregates() {
let builder = AggregationFuzzerBuilder::default();

// Define data generator config
let columns = vec![ColumnDescr::new("a", DataType::Int32)];
let columns = vec![
ColumnDescr::new("a", DataType::Int32),
ColumnDescr::new("b", DataType::Int64),
ColumnDescr::new("c", DataType::Utf8),
];

let data_gen_config = DatasetGeneratorConfig {
columns,
rows_num_range: (512, 1024),
sort_keys_set: Vec::new(),
};

let query_builder = QueryBuilder::new()
.with_table_name("fuzz_table")
.with_aggregate_function("sum")
.with_distinct_aggregate_function("sum")
.with_aggregate_function("max")
.with_aggregate_function("min")
.with_aggregate_function("count")
.with_distinct_aggregate_function("count")
.with_aggregate_function("avg")
.with_aggregate_argument("a") // integral arguments only
.with_aggregate_argument("b")
.with_group_by_column("a")
.with_group_by_column("b")
.with_group_by_column("c"); // group by string as well

// Build fuzzer
let fuzzer = builder
.data_gen_config(data_gen_config)
.data_gen_rounds(16)
.add_sql("SELECT sum(a) FROM fuzz_table")
.add_sql("SELECT sum(distinct a) FROM fuzz_table")
.add_sql("SELECT max(a) FROM fuzz_table")
.add_sql("SELECT min(a) FROM fuzz_table")
.add_sql("SELECT count(a) FROM fuzz_table")
.add_sql("SELECT count(distinct a) FROM fuzz_table")
.add_sql("SELECT avg(a) FROM fuzz_table")
.add_sql_from_builder(query_builder, 10) // 10 random queries
.table_name("fuzz_table")
.build();

Expand Down
134 changes: 128 additions & 6 deletions datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashSet;
use std::sync::Arc;

use arrow::util::pretty::pretty_format_batches;
use arrow_array::RecordBatch;
use rand::{thread_rng, Rng};
use tokio::task::JoinSet;

use crate::fuzz_cases::aggregation_fuzzer::{
check_equality_of_batches,
context_generator::{SessionContextGenerator, SessionContextWithParams},
data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig},
run_sql,
};
use arrow::util::pretty::pretty_format_batches;
use arrow_array::RecordBatch;
use rand::{thread_rng, Rng};
use tokio::task::JoinSet;

/// Rounds to call `generate` of [`SessionContextGenerator`]
/// in [`AggregationFuzzer`], `ctx_gen_rounds` random [`SessionContext`]
Expand Down Expand Up @@ -60,6 +60,14 @@ impl AggregationFuzzerBuilder {
}
}

/// Appends `n` queries to the candidate SQL list, each one obtained from
/// its own call to `query_builder.generate_query()`
pub fn add_sql_from_builder(self, query_builder: QueryBuilder, n: usize) -> Self {
    (0..n).fold(self, |builder, _| {
        let sql = query_builder.generate_query();
        builder.add_sql(&sql)
    })
}

pub fn add_sql(mut self, sql: &str) -> Self {
self.candidate_sqls.push(Arc::from(sql));
self
Expand Down Expand Up @@ -98,7 +106,7 @@ impl AggregationFuzzerBuilder {
}
}

impl Default for AggregationFuzzerBuilder {
impl std::default::Default for AggregationFuzzerBuilder {
/// Returns the same value as `Self::new()`
fn default() -> Self {
Self::new()
}
Expand Down Expand Up @@ -279,3 +287,117 @@ impl AggregationFuzzTestTask {
}
}
}

/// Random aggregate query builder
///
/// Produces SQL strings of the form
/// `SELECT <aggregates> FROM <table_name> [GROUP BY <columns>]` where the
/// aggregate functions, their arguments and the group by columns are picked
/// at random from the candidates registered through the `with_*` methods.
#[derive(Debug, Default)]
pub struct QueryBuilder {
/// Name of the table the generated queries select from
table_name: String,
/// Aggregate functions to be used in the query
/// (function_name, is_distinct)
aggregate_functions: Vec<(String, bool)>,
/// Columns to be used in group by
group_by_columns: Vec<String>,
/// Columns to be used as arguments in the aggregate functions
arguments: Vec<String>,
}
impl QueryBuilder {
    /// Creates a builder with no table name, aggregate functions,
    /// arguments or group by columns registered
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the table name for the query
    pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
        self.table_name = table_name.into();
        self
    }

    /// Add a new (non distinct) aggregate function to the query candidates
    pub fn with_aggregate_function(
        mut self,
        aggregate_function: impl Into<String>,
    ) -> Self {
        self.aggregate_functions
            .push((aggregate_function.into(), false));
        self
    }

    /// Add a new distinct aggregate function to the query candidates,
    /// rendered as `function(DISTINCT argument)`
    pub fn with_distinct_aggregate_function(
        mut self,
        aggregate_function: impl Into<String>,
    ) -> Self {
        self.aggregate_functions
            .push((aggregate_function.into(), true));
        self
    }

    /// Add a column to be used in the group bys
    pub fn with_group_by_column(mut self, group_by: impl Into<String>) -> Self {
        self.group_by_columns.push(group_by.into());
        self
    }

    /// Add a column to be used as an argument in the aggregate functions
    pub fn with_aggregate_argument(mut self, argument: impl Into<String>) -> Self {
        self.arguments.push(argument.into());
        self
    }

    /// Generates one random query of the form
    /// `SELECT <aggregates> FROM <table_name> [GROUP BY <columns>]`
    ///
    /// The `GROUP BY` clause is omitted when no group by column was
    /// registered.
    ///
    /// # Panics
    ///
    /// Panics if no aggregate function or no aggregate argument has been
    /// registered, as no select list can be generated in that case.
    pub fn generate_query(&self) -> String {
        let group_by = self.random_group_by();
        let mut query = String::from("SELECT ");
        query.push_str(&self.random_aggregate_functions().join(", "));
        query.push_str(" FROM ");
        query.push_str(&self.table_name);
        if !group_by.is_empty() {
            query.push_str(" GROUP BY ");
            query.push_str(&group_by.join(", "));
        }
        query
    }

    /// Generate between 1 and 4 aggregate function calls (potentially
    /// repeating) to use in the query, aliased `col1`, `col2`, ...
    ///
    /// # Panics
    ///
    /// Panics if no aggregate function or no aggregate argument has been
    /// registered.
    fn random_aggregate_functions(&self) -> Vec<String> {
        // Fail with an actionable message rather than an "empty range"
        // panic raised from inside `gen_range`
        assert!(
            !self.aggregate_functions.is_empty(),
            "QueryBuilder requires at least one aggregate function"
        );

        let mut rng = thread_rng();
        // the upper bound is exclusive: 1, 2, 3 or 4 functions
        let num_aggregate_functions: usize = rng.gen_range(1..5);

        (1..=num_aggregate_functions)
            .map(|alias_idx| {
                let idx = rng.gen_range(0..self.aggregate_functions.len());
                let (function_name, is_distinct) = &self.aggregate_functions[idx];
                let argument = self.random_argument();
                let distinct = if *is_distinct { "DISTINCT " } else { "" };
                format!("{function_name}({distinct}{argument}) as col{alias_idx}")
            })
            .collect()
    }

    /// Pick a random aggregate function argument
    ///
    /// # Panics
    ///
    /// Panics if no aggregate argument has been registered.
    fn random_argument(&self) -> String {
        assert!(
            !self.arguments.is_empty(),
            "QueryBuilder requires at least one aggregate argument"
        );

        let mut rng = thread_rng();
        let idx = rng.gen_range(0..self.arguments.len());
        self.arguments[idx].clone()
    }

    /// Pick a random number of fields to group by (non repeating)
    ///
    /// With two or more registered columns, between 1 and `len - 1` of them
    /// are picked. With fewer than two registered columns there is nothing
    /// to choose from, so all of them (zero or one) are returned.
    fn random_group_by(&self) -> Vec<String> {
        let num_columns = self.group_by_columns.len();
        if num_columns < 2 {
            // `1..num_columns` would be an empty range, which `gen_range`
            // rejects by panicking
            return self.group_by_columns.clone();
        }

        let mut rng = thread_rng();
        let num_group_by = rng.gen_range(1..num_columns);

        // Rejection sampling: redraw until `num_group_by` distinct indexes
        // have been seen. Terminates because `num_group_by < num_columns`.
        let mut already_used = HashSet::new();
        let mut group_by = Vec::with_capacity(num_group_by);
        while group_by.len() < num_group_by {
            let idx = rng.gen_range(0..num_columns);
            if already_used.insert(idx) {
                group_by.push(self.group_by_columns[idx].clone());
            }
        }
        group_by
    }
}

0 comments on commit 4c5d621

Please sign in to comment.