Skip to content

Commit

Permalink
refactor: modularize tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aykut-bozkurt committed Nov 17, 2024
1 parent c8af064 commit c69fd25
Show file tree
Hide file tree
Showing 10 changed files with 4,134 additions and 4,067 deletions.
4,069 changes: 2 additions & 4,067 deletions src/lib.rs

Large diffs are not rendered by default.

307 changes: 307 additions & 0 deletions src/pgrx_tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
use std::fs::File;
use std::marker::PhantomData;
use std::{collections::HashMap, fmt::Debug};

use crate::type_compat::map::Map;

use arrow::array::RecordBatch;
use arrow_schema::SchemaRef;
use parquet::arrow::ArrowWriter;
use pgrx::{
datum::{Time, TimeWithTimeZone},
FromDatum, IntoDatum, Spi,
};

pub(crate) enum CopyOptionValue {
StringOption(String),
IntOption(i64),
}

pub(crate) fn comma_separated_copy_options(options: &HashMap<String, CopyOptionValue>) -> String {
let mut comma_sepated_options = String::new();

for (option_idx, (key, value)) in options.iter().enumerate() {
match value {
CopyOptionValue::StringOption(value) => {
comma_sepated_options.push_str(&format!("{} '{}'", key, value));
}
CopyOptionValue::IntOption(value) => {
comma_sepated_options.push_str(&format!("{} {}", key, value));
}
}

if option_idx < options.len() - 1 {
comma_sepated_options.push_str(", ");
}
}

comma_sepated_options
}

pub(crate) struct TestTable<T: IntoDatum + FromDatum> {
uri: String,
order_by_col: String,
copy_to_options: HashMap<String, CopyOptionValue>,
copy_from_options: HashMap<String, CopyOptionValue>,
_data: PhantomData<T>,
}

impl<T: IntoDatum + FromDatum> TestTable<T> {
pub(crate) fn new(typename: String) -> Self {
Spi::run("DROP TABLE IF EXISTS test_expected, test_result;").unwrap();

let create_table_command = format!("CREATE TABLE test_expected (a {});", &typename);
Spi::run(create_table_command.as_str()).unwrap();

let create_table_command = format!("CREATE TABLE test_result (a {});", &typename);
Spi::run(create_table_command.as_str()).unwrap();

let mut copy_to_options = HashMap::new();
copy_to_options.insert(
"format".to_string(),
CopyOptionValue::StringOption("parquet".to_string()),
);

let mut copy_from_options = HashMap::new();
copy_from_options.insert(
"format".to_string(),
CopyOptionValue::StringOption("parquet".to_string()),
);

let uri = "/tmp/test.parquet".to_string();

let order_by_col = "a".to_string();

Self {
uri,
order_by_col,
copy_to_options,
copy_from_options,
_data: PhantomData,
}
}

pub(crate) fn with_order_by_col(mut self, order_by_col: String) -> Self {
self.order_by_col = order_by_col;
self
}

pub(crate) fn with_copy_to_options(
mut self,
copy_to_options: HashMap<String, CopyOptionValue>,
) -> Self {
self.copy_to_options = copy_to_options;
self
}

pub(crate) fn with_copy_from_options(
mut self,
copy_from_options: HashMap<String, CopyOptionValue>,
) -> Self {
self.copy_from_options = copy_from_options;
self
}

pub(crate) fn with_uri(mut self, uri: String) -> Self {
self.uri = uri;
self
}

pub(crate) fn insert(&self, insert_command: &str) {
Spi::run(insert_command).unwrap();
}

pub(crate) fn select_all(&self, table_name: &str) -> Vec<(Option<T>,)> {
let select_command = format!(
"SELECT a FROM {} ORDER BY {};",
table_name, self.order_by_col
);

Spi::connect(|client| {
let mut results = Vec::new();
let tup_table = client.select(&select_command, None, None).unwrap();

for row in tup_table {
let val = row["a"].value::<T>();
results.push((val.expect("could not select"),));
}

results
})
}

pub(crate) fn copy_to_parquet(&self) {
let mut copy_to_query = format!("COPY (SELECT a FROM test_expected) TO '{}'", self.uri);

if !self.copy_to_options.is_empty() {
copy_to_query.push_str(" WITH (");

let options_str = comma_separated_copy_options(&self.copy_to_options);
copy_to_query.push_str(&options_str);

copy_to_query.push(')');
}

copy_to_query.push(';');

Spi::run(copy_to_query.as_str()).unwrap();
}

pub(crate) fn copy_from_parquet(&self) {
let mut copy_from_query = format!("COPY test_result FROM '{}'", self.uri);

if !self.copy_from_options.is_empty() {
copy_from_query.push_str(" WITH (");

let options_str = comma_separated_copy_options(&self.copy_from_options);
copy_from_query.push_str(&options_str);

copy_from_query.push(')');
}

copy_from_query.push(';');

Spi::run(copy_from_query.as_str()).unwrap();
}
}

pub(crate) fn timetz_to_utc_time(timetz: TimeWithTimeZone) -> Option<Time> {
Some(timetz.to_utc())
}

pub(crate) fn timetz_array_to_utc_time_array(
timetz_array: Vec<Option<TimeWithTimeZone>>,
) -> Option<Vec<Option<Time>>> {
Some(
timetz_array
.into_iter()
.map(|timetz| timetz.map(|timetz| timetz.to_utc()))
.collect(),
)
}

pub(crate) fn assert_int_text_map(expected: Option<Map>, actual: Option<Map>) {
if expected.is_none() {
assert!(actual.is_none());

Check warning on line 185 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L183-L185

Added lines #L183 - L185 were not covered by tests
} else {
assert!(actual.is_some());

Check warning on line 187 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L187

Added line #L187 was not covered by tests

let expected = expected.unwrap().entries;
let actual = actual.unwrap().entries;

Check warning on line 190 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L189-L190

Added lines #L189 - L190 were not covered by tests

for (expected, actual) in expected.iter().zip(actual.iter()) {
if expected.is_none() {
assert!(actual.is_none());

Check warning on line 194 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L192-L194

Added lines #L192 - L194 were not covered by tests
} else {
assert!(actual.is_some());

Check warning on line 196 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L196

Added line #L196 was not covered by tests

let expected = expected.unwrap();
let actual = actual.unwrap();

let expected_key: Option<i32> = expected.get_by_name("key").unwrap();
let actual_key: Option<i32> = actual.get_by_name("key").unwrap();

assert_eq!(expected_key, actual_key);

Check warning on line 204 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L198-L204

Added lines #L198 - L204 were not covered by tests

let expected_val: Option<String> = expected.get_by_name("val").unwrap();
let actual_val: Option<String> = actual.get_by_name("val").unwrap();

assert_eq!(expected_val, actual_val);

Check warning on line 209 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L206-L209

Added lines #L206 - L209 were not covered by tests
}
}
}
}

Check warning on line 213 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L213

Added line #L213 was not covered by tests

pub(crate) struct TestResult<T> {
pub(crate) expected: Vec<(Option<T>,)>,
pub(crate) result: Vec<(Option<T>,)>,
}

pub(crate) fn test_common<T: IntoDatum + FromDatum>(test_table: TestTable<T>) -> TestResult<T> {
test_table.copy_to_parquet();
test_table.copy_from_parquet();

let expected = test_table.select_all("test_expected");
let result = test_table.select_all("test_result");

TestResult { expected, result }
}

pub(crate) fn test_assert<T>(expected_result: Vec<(Option<T>,)>, result: Vec<(Option<T>,)>)
where
T: Debug + PartialEq,
{
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
assert_eq!(expected, actual);
}
}

pub(crate) fn test_assert_float(expected_result: Vec<Option<f32>>, result: Vec<Option<f32>>) {
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
if expected.is_none() {
assert!(actual.is_none());
}

if expected.is_some() {
assert!(actual.is_some());

let expected = expected.unwrap();
let actual = actual.unwrap();

if expected.is_nan() {
assert!(actual.is_nan());
} else if expected.is_infinite() {
assert!(actual.is_infinite());
assert!(expected.is_sign_positive() == actual.is_sign_positive());
} else {
assert_eq!(expected, actual);
}
}
}
}

pub(crate) fn test_assert_double(expected_result: Vec<Option<f64>>, result: Vec<Option<f64>>) {
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
if expected.is_none() {
assert!(actual.is_none());
}

if expected.is_some() {
assert!(actual.is_some());

let expected = expected.unwrap();
let actual = actual.unwrap();

if expected.is_nan() {
assert!(actual.is_nan());
} else if expected.is_infinite() {
assert!(actual.is_infinite());
assert!(expected.is_sign_positive() == actual.is_sign_positive());
} else {
assert_eq!(expected, actual);
}
}
}
}

pub(crate) fn test_helper<T: IntoDatum + FromDatum + Debug + PartialEq>(test_table: TestTable<T>) {
let test_result = test_common(test_table);
test_assert(test_result.expected, test_result.result);
}

pub(crate) fn extension_exists(extension_name: &str) -> bool {
let query = format!(
"select count(*) = 1 from pg_available_extensions where name = '{}'",
extension_name
);

Spi::get_one(&query).unwrap().unwrap()
}

pub(crate) fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) {
let file = File::create("/tmp/test.parquet").unwrap();
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();

writer.write(&record_batch).unwrap();
writer.close().unwrap();
}
Loading

0 comments on commit c69fd25

Please sign in to comment.