Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: modularize tests #78

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4,097 changes: 1 addition & 4,096 deletions src/lib.rs

Large diffs are not rendered by default.

315 changes: 315 additions & 0 deletions src/pgrx_tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
use std::fs::File;
use std::marker::PhantomData;
use std::{collections::HashMap, fmt::Debug};

use crate::type_compat::map::Map;

use arrow::array::RecordBatch;
use arrow_schema::SchemaRef;
use parquet::arrow::ArrowWriter;
use pgrx::{
datum::{Time, TimeWithTimeZone},
FromDatum, IntoDatum, Spi,
};

pub(crate) enum CopyOptionValue {
StringOption(String),
IntOption(i64),
}

pub(crate) fn comma_separated_copy_options(options: &HashMap<String, CopyOptionValue>) -> String {
let mut comma_sepated_options = String::new();

for (option_idx, (key, value)) in options.iter().enumerate() {
match value {
CopyOptionValue::StringOption(value) => {
comma_sepated_options.push_str(&format!("{} '{}'", key, value));
}
CopyOptionValue::IntOption(value) => {
comma_sepated_options.push_str(&format!("{} {}", key, value));
}
}

if option_idx < options.len() - 1 {
comma_sepated_options.push_str(", ");
}
}

comma_sepated_options
}

pub(crate) const LOCAL_TEST_FILE_PATH: &str = "/tmp/pg_parquet_test.parquet";

pub(crate) struct TestTable<T: IntoDatum + FromDatum> {
uri: String,
order_by_col: String,
copy_to_options: HashMap<String, CopyOptionValue>,
copy_from_options: HashMap<String, CopyOptionValue>,
_data: PhantomData<T>,
}

impl<T: IntoDatum + FromDatum> TestTable<T> {
pub(crate) fn new(typename: String) -> Self {
Spi::run("DROP TABLE IF EXISTS test_expected, test_result;").unwrap();

let create_table_command = format!("CREATE TABLE test_expected (a {});", &typename);
Spi::run(create_table_command.as_str()).unwrap();

let create_table_command = format!("CREATE TABLE test_result (a {});", &typename);
Spi::run(create_table_command.as_str()).unwrap();

let mut copy_to_options = HashMap::new();
copy_to_options.insert(
"format".to_string(),
CopyOptionValue::StringOption("parquet".to_string()),
);

let mut copy_from_options = HashMap::new();
copy_from_options.insert(
"format".to_string(),
CopyOptionValue::StringOption("parquet".to_string()),
);

let uri = LOCAL_TEST_FILE_PATH.to_string();

let order_by_col = "a".to_string();

Self {
uri,
order_by_col,
copy_to_options,
copy_from_options,
_data: PhantomData,
}
}

pub(crate) fn with_order_by_col(mut self, order_by_col: String) -> Self {
self.order_by_col = order_by_col;
self
}

pub(crate) fn with_copy_to_options(
mut self,
copy_to_options: HashMap<String, CopyOptionValue>,
) -> Self {
self.copy_to_options = copy_to_options;
self
}

pub(crate) fn with_copy_from_options(
mut self,
copy_from_options: HashMap<String, CopyOptionValue>,
) -> Self {
self.copy_from_options = copy_from_options;
self
}

pub(crate) fn with_uri(mut self, uri: String) -> Self {
self.uri = uri;
self
}

pub(crate) fn insert(&self, insert_command: &str) {
Spi::run(insert_command).unwrap();
}

fn select_all(&self, table_name: &str) -> Vec<(Option<T>,)> {
let select_command = format!(
"SELECT a FROM {} ORDER BY {};",
table_name, self.order_by_col
);

Spi::connect(|client| {
let mut results = Vec::new();
let tup_table = client.select(&select_command, None, None).unwrap();

for row in tup_table {
let val = row["a"].value::<T>();
results.push((val.expect("could not select"),));
}

results
})
}

pub(crate) fn copy_to_parquet(&self, table_name: &str) {
let mut copy_to_query = format!("COPY (SELECT a FROM {}) TO '{}'", table_name, self.uri);

if !self.copy_to_options.is_empty() {
copy_to_query.push_str(" WITH (");

let options_str = comma_separated_copy_options(&self.copy_to_options);
copy_to_query.push_str(&options_str);

copy_to_query.push(')');
}

copy_to_query.push(';');

Spi::run(copy_to_query.as_str()).unwrap();
}

pub(crate) fn copy_from_parquet(&self, table_name: &str) {
let mut copy_from_query = format!("COPY {} FROM '{}'", table_name, self.uri);

if !self.copy_from_options.is_empty() {
copy_from_query.push_str(" WITH (");

let options_str = comma_separated_copy_options(&self.copy_from_options);
copy_from_query.push_str(&options_str);

copy_from_query.push(')');
}

copy_from_query.push(';');

Spi::run(copy_from_query.as_str()).unwrap();
}

pub(crate) fn select_expected_and_result_rows(&self) -> TestResult<T> {
self.copy_to_parquet("test_expected");
self.copy_from_parquet("test_result");

let expected = self.select_all("test_expected");
let result = self.select_all("test_result");

TestResult { expected, result }
}

pub(crate) fn assert_expected_and_result_rows(&self)
where
T: Debug + PartialEq,
{
let test_result = self.select_expected_and_result_rows();
test_result.assert();
}
}

pub(crate) struct TestResult<T> {
pub(crate) expected: Vec<(Option<T>,)>,
pub(crate) result: Vec<(Option<T>,)>,
}

impl<T> TestResult<T>
where
T: Debug + PartialEq,
{
// almost all types are comparable by common equality
pub(crate) fn assert(&self) {
for (expected, actual) in self.expected.iter().zip(self.result.iter()) {
assert_eq!(expected, actual);
}
}
}

pub(crate) fn assert_int_text_map(expected: Option<Map>, actual: Option<Map>) {
if expected.is_none() {
assert!(actual.is_none());

Check warning on line 207 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L205-L207

Added lines #L205 - L207 were not covered by tests
} else {
assert!(actual.is_some());

Check warning on line 209 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L209

Added line #L209 was not covered by tests

let expected = expected.unwrap().entries;
let actual = actual.unwrap().entries;

Check warning on line 212 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L211-L212

Added lines #L211 - L212 were not covered by tests

for (expected, actual) in expected.iter().zip(actual.iter()) {
if expected.is_none() {
assert!(actual.is_none());

Check warning on line 216 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L214-L216

Added lines #L214 - L216 were not covered by tests
} else {
assert!(actual.is_some());

Check warning on line 218 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L218

Added line #L218 was not covered by tests

let expected = expected.unwrap();
let actual = actual.unwrap();

let expected_key: Option<i32> = expected.get_by_name("key").unwrap();
let actual_key: Option<i32> = actual.get_by_name("key").unwrap();

assert_eq!(expected_key, actual_key);

Check warning on line 226 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L220-L226

Added lines #L220 - L226 were not covered by tests

let expected_val: Option<String> = expected.get_by_name("val").unwrap();
let actual_val: Option<String> = actual.get_by_name("val").unwrap();

assert_eq!(expected_val, actual_val);

Check warning on line 231 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L228-L231

Added lines #L228 - L231 were not covered by tests
}
}
}
}

Check warning on line 235 in src/pgrx_tests/common.rs

View check run for this annotation

Codecov / codecov/patch

src/pgrx_tests/common.rs#L235

Added line #L235 was not covered by tests

pub(crate) fn assert_float(expected_result: Vec<Option<f32>>, result: Vec<Option<f32>>) {
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
if expected.is_none() {
assert!(actual.is_none());
}

if expected.is_some() {
assert!(actual.is_some());

let expected = expected.unwrap();
let actual = actual.unwrap();

if expected.is_nan() {
assert!(actual.is_nan());
} else if expected.is_infinite() {
assert!(actual.is_infinite());
assert!(expected.is_sign_positive() == actual.is_sign_positive());
} else {
assert_eq!(expected, actual);
}
}
}
}

pub(crate) fn assert_double(expected_result: Vec<Option<f64>>, result: Vec<Option<f64>>) {
for (expected, actual) in expected_result.into_iter().zip(result.into_iter()) {
if expected.is_none() {
assert!(actual.is_none());
}

if expected.is_some() {
assert!(actual.is_some());

let expected = expected.unwrap();
let actual = actual.unwrap();

if expected.is_nan() {
assert!(actual.is_nan());
} else if expected.is_infinite() {
assert!(actual.is_infinite());
assert!(expected.is_sign_positive() == actual.is_sign_positive());
} else {
assert_eq!(expected, actual);
}
}
}
}

pub(crate) fn timetz_to_utc_time(timetz: TimeWithTimeZone) -> Option<Time> {
Some(timetz.to_utc())
}

pub(crate) fn timetz_array_to_utc_time_array(
timetz_array: Vec<Option<TimeWithTimeZone>>,
) -> Option<Vec<Option<Time>>> {
Some(
timetz_array
.into_iter()
.map(|timetz| timetz.map(|timetz| timetz.to_utc()))
.collect(),
)
}

pub(crate) fn extension_exists(extension_name: &str) -> bool {
let query = format!(
"select count(*) = 1 from pg_available_extensions where name = '{}'",
extension_name
);

Spi::get_one(&query).unwrap().unwrap()
}

pub(crate) fn write_record_batch_to_parquet(schema: SchemaRef, record_batch: RecordBatch) {
let file = File::create(LOCAL_TEST_FILE_PATH).unwrap();
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();

writer.write(&record_batch).unwrap();
writer.close().unwrap();
}
Loading