diff --git a/.devcontainer/.env b/.devcontainer/.env index 14f05d0..484b326 100644 --- a/.devcontainer/.env +++ b/.devcontainer/.env @@ -2,10 +2,11 @@ AWS_ACCESS_KEY_ID=minioadmin AWS_SECRET_ACCESS_KEY=minioadmin AWS_REGION=us-east-1 +AWS_ENDPOINT_URL=http://localhost:9000 +AWS_ALLOW_HTTP=true AWS_S3_TEST_BUCKET=testbucket MINIO_ROOT_USER=minioadmin MINIO_ROOT_PASSWORD=minioadmin # Others RUST_TEST_THREADS=1 -PG_PARQUET_TEST=true diff --git a/.devcontainer/create-test-buckets.sh b/.devcontainer/create-test-buckets.sh deleted file mode 100644 index 65dfef0..0000000 --- a/.devcontainer/create-test-buckets.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -aws --endpoint-url http://localhost:9000 s3 mb s3://$AWS_S3_TEST_BUCKET diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index e2c90a8..90f0a2d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -3,7 +3,6 @@ "dockerComposeFile": "docker-compose.yml", "service": "app", "workspaceFolder": "/workspace", - "postStartCommand": "bash .devcontainer/create-test-buckets.sh", "postAttachCommand": "sudo chown -R rust /workspace", "customizations": { "vscode": { diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 259cfc8..616e225 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -23,10 +23,12 @@ services: env_file: - .env network_mode: host - command: server /data + entrypoint: "./entrypoint.sh" restart: unless-stopped healthcheck: test: ["CMD", "curl", "http://localhost:9000"] interval: 6s timeout: 2s retries: 3 + volumes: + - ./minio-entrypoint.sh:/entrypoint.sh diff --git a/.devcontainer/minio-entrypoint.sh b/.devcontainer/minio-entrypoint.sh new file mode 100755 index 0000000..7831ba5 --- /dev/null +++ b/.devcontainer/minio-entrypoint.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +trap "echo 'Caught termination signal. Exiting...'; exit 0" SIGINT SIGTERM + +minio server /data & + +minio_pid=$! + +while ! curl $AWS_ENDPOINT_URL; do + echo "Waiting for $AWS_ENDPOINT_URL..." + sleep 1 +done + +# set access key and secret key +mc alias set local $AWS_ENDPOINT_URL $MINIO_ROOT_USER $MINIO_ROOT_PASSWORD + +# create bucket +mc mb local/$AWS_S3_TEST_BUCKET + +wait $minio_pid diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d4ce9d..96f8b00 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -107,15 +107,19 @@ jobs: - name: Start Minio for s3 emulator tests run: | - docker run -d --env-file .devcontainer/.env -p 9000:9000 minio/minio server /data - - while ! nc -z localhost 9000; do - echo "Waiting for localhost:9000..." + docker run -d \ + --env-file .devcontainer/.env \ + -p 9000:9000 \ + --entrypoint "./entrypoint.sh" \ + --volume ./.devcontainer/minio-entrypoint.sh:/entrypoint.sh \ + --name miniocontainer \ + minio/minio + + while ! nc -z $AWS_ENDPOINT_URL; do + echo "Waiting for $AWS_ENDPOINT_URL..." sleep 1 done - aws --endpoint-url http://localhost:9000 s3 mb s3://$AWS_S3_TEST_BUCKET - - name: Run tests run: | # Run tests with coverage tool diff --git a/Cargo.lock b/Cargo.lock index a6702d3..7bcb35d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -317,7 +317,7 @@ dependencies = [ "aws-sdk-sts", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -345,9 +345,9 @@ dependencies = [ [[package]] name = "aws-runtime" -version = "1.4.3" +version = "1.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" +checksum = "b5ac934720fbb46206292d2c75b57e67acfc56fe7dfd34fb9a02334af08409ea" dependencies = [ "aws-credential-types", "aws-sigv4", @@ -370,15 +370,15 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.50.0" +version = "1.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ada54e5f26ac246dc79727def52f7f8ed38915cb47781e2a72213957dc3a7d5" +checksum = "b68fde0d69c8bfdc1060ea7da21df3e39f6014da316783336deff0a9ec28f4bf" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", "aws-smithy-http", - "aws-smithy-json", + "aws-smithy-json 0.61.1", "aws-smithy-query", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -393,9 +393,9 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.5" +version = "1.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" +checksum = "7d3820e0c08d0737872ff3c7c1f21ebbb6693d832312d6152bf18ef50a5471c2" dependencies = [ "aws-credential-types", "aws-smithy-http", @@ -454,6 +454,15 @@ dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-json" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee4e69cc50921eb913c6b662f8d909131bb3e6ad6cb6090d3a39b66fc5c52095" +dependencies = [ + "aws-smithy-types", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -466,9 +475,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.3" +version = "1.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" +checksum = "9f20685047ca9d6f17b994a07f629c813f08b5bce65523e47124879e60103d45" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -517,6 +526,7 @@ dependencies = [ "base64-simd", "bytes", "bytes-utils", + "futures-core", "http 0.2.12", "http 1.1.0", "http-body 0.4.6", @@ -529,6 +539,8 @@ dependencies = [ "ryu", "serde", "time", + "tokio", + "tokio-util", ] [[package]] @@ -2235,6 +2247,7 @@ dependencies = [ "arrow-schema", "aws-config", "aws-credential-types", + "aws-sdk-sts", "futures", "object_store", "once_cell", diff --git a/Cargo.toml b/Cargo.toml index e59a625..dc8acc8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,8 +23,9 @@ pg_test = [] arrow = {version = "53", default-features = false} arrow-cast = {version = "53", default-features = false} arrow-schema = {version = "53", default-features = false} -aws-config = { version = "1.5", default-features = false, features = ["rustls"]} -aws-credential-types = {version = "1.2", default-features = false} +aws-config = { version = "1", default-features = false, features = ["rustls"]} +aws-credential-types = {version = "1", default-features = false} +aws-sdk-sts = "1" futures = "0.3" object_store = {version = "0.11", default-features = false, features = ["aws"]} once_cell = "1" diff --git a/README.md b/README.md index 353b01f..74fdae7 100644 --- a/README.md +++ b/README.md @@ -174,10 +174,14 @@ region = eu-central-1 Alternatively, you can use the following environment variables when starting postgres to configure the S3 client: - `AWS_ACCESS_KEY_ID`: the access key ID of the AWS account - `AWS_SECRET_ACCESS_KEY`: the secret access key of the AWS account +- `AWS_SESSION_TOKEN`: the session token for the AWS account - `AWS_REGION`: the default region of the AWS account -- `AWS_SHARED_CREDENTIALS_FILE`: an alternative location for the credentials file -- `AWS_CONFIG_FILE`: an alternative location for the config file -- `AWS_PROFILE`: the name of the profile from the credentials and config file (default profile name is `default`) +- `AWS_ENDPOINT_URL`: the endpoint +- `AWS_SHARED_CREDENTIALS_FILE`: an alternative location for the credentials file **(only via environment variables)** +- `AWS_CONFIG_FILE`: an alternative location for the config file **(only via environment variables)** +- `AWS_PROFILE`: the name of the profile from the credentials and config file (default profile name is `default`) **(only via environment variables)** +- `AWS_ALLOW_HTTP`: allows http endpoints **(only via environment variables)** + > [!NOTE] > To be able to write into a object store location, you need to grant `parquet_object_store_write` role to your current postgres user. diff --git a/src/arrow_parquet/uri_utils.rs b/src/arrow_parquet/uri_utils.rs index 3ff97af..2091458 100644 --- a/src/arrow_parquet/uri_utils.rs +++ b/src/arrow_parquet/uri_utils.rs @@ -1,11 +1,7 @@ use std::{sync::Arc, sync::LazyLock}; use arrow::datatypes::SchemaRef; -use aws_config::{ - environment::{EnvironmentVariableCredentialsProvider, EnvironmentVariableRegionProvider}, - meta::{credentials::CredentialsProviderChain, region::RegionProviderChain}, - profile::{ProfileFileCredentialsProvider, ProfileFileRegionProvider}, -}; +use aws_config::BehaviorVersion; use aws_credential_types::provider::ProvideCredentials; use object_store::{ aws::{AmazonS3, AmazonS3Builder}, @@ -92,48 +88,40 @@ fn object_store_with_location(uri: &Url, copy_from: bool) -> (Arc AmazonS3 { let mut aws_s3_builder = AmazonS3Builder::new().with_bucket_name(bucket_name); - let is_test_running = std::env::var("PG_PARQUET_TEST").is_ok(); - - if is_test_running { - // use minio for testing - aws_s3_builder = aws_s3_builder.with_endpoint("http://localhost:9000"); - aws_s3_builder = aws_s3_builder.with_allow_http(true); - } + // AWS_ALLOW_HTTP + if let Ok(aws_allow_http) = std::env::var("AWS_ALLOW_HTTP") { + aws_s3_builder = aws_s3_builder + .with_allow_http(aws_allow_http.parse().unwrap_or_else(|e| panic!("{}", e))); + }; - let aws_profile_name = std::env::var("AWS_PROFILE").unwrap_or("default".to_string()); + // first tries to load the profile files from the environment variables and then from the profile + let sdk_config = aws_config::defaults(BehaviorVersion::v2024_03_28()) + .load() + .await; - let region_provider = RegionProviderChain::first_try(EnvironmentVariableRegionProvider::new()) - .or_else( - ProfileFileRegionProvider::builder() - .profile_name(aws_profile_name.clone()) - .build(), - ); + if let Some(credential_provider) = sdk_config.credentials_provider() { + if let Ok(credentials) = credential_provider.provide_credentials().await { + // AWS_ACCESS_KEY_ID + aws_s3_builder = aws_s3_builder.with_access_key_id(credentials.access_key_id()); - let region = region_provider.region().await; + // AWS_SECRET_ACCESS_KEY + aws_s3_builder = aws_s3_builder.with_secret_access_key(credentials.secret_access_key()); - if let Some(region) = region { - aws_s3_builder = aws_s3_builder.with_region(region.to_string()); + if let Some(token) = credentials.session_token() { + // AWS_SESSION_TOKEN + aws_s3_builder = aws_s3_builder.with_token(token); + } + } } - let credential_provider = CredentialsProviderChain::first_try( - "Environment", - EnvironmentVariableCredentialsProvider::new(), - ) - .or_else( - "Profile", - ProfileFileCredentialsProvider::builder() - .profile_name(aws_profile_name) - .build(), - ); - - if let Ok(credentials) = credential_provider.provide_credentials().await { - aws_s3_builder = aws_s3_builder.with_access_key_id(credentials.access_key_id()); - - aws_s3_builder = aws_s3_builder.with_secret_access_key(credentials.secret_access_key()); + // AWS_ENDPOINT_URL + if let Some(aws_endpoint_url) = sdk_config.endpoint_url() { + aws_s3_builder = aws_s3_builder.with_endpoint(aws_endpoint_url); + } - if let Some(token) = credentials.session_token() { - aws_s3_builder = aws_s3_builder.with_token(token); - } + // AWS_REGION + if let Some(aws_region) = sdk_config.region() { + aws_s3_builder = aws_s3_builder.with_region(aws_region.as_ref()); } aws_s3_builder.build().unwrap_or_else(|e| panic!("{}", e)) diff --git a/src/lib.rs b/src/lib.rs index 817f224..100c80b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ use pgrx::{prelude::*, GucContext, GucFlags, GucRegistry}; mod arrow_parquet; mod parquet_copy_hook; mod parquet_udfs; +#[cfg(any(test, feature = "pg_test"))] mod pgrx_tests; mod pgrx_utils; mod type_compat; diff --git a/src/pgrx_tests/object_store.rs b/src/pgrx_tests/object_store.rs index 4272027..561aab6 100644 --- a/src/pgrx_tests/object_store.rs +++ b/src/pgrx_tests/object_store.rs @@ -2,6 +2,7 @@ mod tests { use std::io::Write; + use aws_config::BehaviorVersion; use pgrx::{pg_test, Spi}; use crate::pgrx_tests::common::TestTable; @@ -31,11 +32,13 @@ mod tests { std::env::remove_var("AWS_SECRET_ACCESS_KEY"); let region = std::env::var("AWS_REGION").unwrap(); std::env::remove_var("AWS_REGION"); + let endpoint = std::env::var("AWS_ENDPOINT_URL").unwrap(); + std::env::remove_var("AWS_ENDPOINT_URL"); // create a config file let aws_config_file_content = format!( - "[profile pg_parquet_test]\nregion = {}\naws_access_key_id = {}\naws_secret_access_key = {}\n", - region, access_key_id, secret_access_key + "[profile pg_parquet_test]\nregion = {}\naws_access_key_id = {}\naws_secret_access_key = {}\nendpoint_url = {}\n", + region, access_key_id, secret_access_key, endpoint ); std::env::set_var("AWS_PROFILE", "pg_parquet_test"); @@ -154,15 +157,78 @@ mod tests { } #[pg_test] - #[should_panic(expected = "404 Not Found")] - fn test_s3_object_store_read_invalid_uri() { - let s3_uri = "s3://randombucketwhichdoesnotexist/pg_parquet_test.parquet"; + fn test_s3_object_store_with_temporary_token() { + let tokio_rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap_or_else(|e| panic!("failed to create tokio runtime: {}", e)); - let create_table_command = "CREATE TABLE test_table (a int);"; - Spi::run(create_table_command).unwrap(); + let s3_uri = tokio_rt.block_on(async { + let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + let client = aws_sdk_sts::Client::new(&config); - let copy_from_command = format!("COPY test_table FROM '{}';", s3_uri); - Spi::run(copy_from_command.as_str()).unwrap(); + let assume_role_result = client + .assume_role() + .role_session_name("testsession") + .role_arn("arn:xxx:xxx:xxx:xxxx") + .send() + .await + .unwrap(); + + let assumed_creds = assume_role_result.credentials().unwrap(); + + std::env::set_var("AWS_ACCESS_KEY_ID", assumed_creds.access_key_id()); + std::env::set_var("AWS_SECRET_ACCESS_KEY", assumed_creds.secret_access_key()); + std::env::set_var("AWS_SESSION_TOKEN", assumed_creds.session_token()); + + let test_bucket_name: String = + std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found"); + + format!("s3://{}/pg_parquet_test.parquet", test_bucket_name) + }); + + let test_table = TestTable::::new("int4".into()).with_uri(s3_uri); + + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + test_table.assert_expected_and_result_rows(); + } + + #[pg_test] + #[should_panic(expected = "403 Forbidden")] + fn test_s3_object_store_with_missing_temporary_token_fail() { + let tokio_rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap_or_else(|e| panic!("failed to create tokio runtime: {}", e)); + + let s3_uri = tokio_rt.block_on(async { + let config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + let client = aws_sdk_sts::Client::new(&config); + + let assume_role_result = client + .assume_role() + .role_session_name("testsession") + .role_arn("arn:xxx:xxx:xxx:xxxx") + .send() + .await + .unwrap(); + + let assumed_creds = assume_role_result.credentials().unwrap(); + + // we do not set the session token on purpose + std::env::set_var("AWS_ACCESS_KEY_ID", assumed_creds.access_key_id()); + std::env::set_var("AWS_SECRET_ACCESS_KEY", assumed_creds.secret_access_key()); + + let test_bucket_name: String = + std::env::var("AWS_S3_TEST_BUCKET").expect("AWS_S3_TEST_BUCKET not found"); + + format!("s3://{}/pg_parquet_test.parquet", test_bucket_name) + }); + + let test_table = TestTable::::new("int4".into()).with_uri(s3_uri); + + test_table.insert("INSERT INTO test_expected (a) VALUES (1), (2), (null);"); + test_table.assert_expected_and_result_rows(); } #[pg_test]