diff --git a/README.md b/README.md index a5858e2f69..afd887655a 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ All runtimes support either `native-tls` or `rustls-tls`, there are features for | | | |----------|-----------------------------------------------------------------------------------------------------| -| `POST` | [presign_put](https://docs.rs/rust-s3/latest/s3/bucket/struct.Bucket.html#method.presign_post) | +| `POST` | [presign_post](https://docs.rs/rust-s3/latest/s3/bucket/struct.Bucket.html#method.presign_post) | | `PUT` | [presign_put](https://docs.rs/rust-s3/latest/s3/bucket/struct.Bucket.html#method.presign_put) | | `GET` | [presign_get](https://docs.rs/rust-s3/latest/s3/bucket/struct.Bucket.html#method.presign_get) | | `DELETE` | [presign_delete](https://docs.rs/rust-s3/latest/s3/bucket/struct.Bucket.html#method.presign_delete) | @@ -141,4 +141,3 @@ Each `GET` method has a `PUT` companion `sync` and `async` methods are generic o [dependencies] rust-s3 = "0.33" ``` - diff --git a/s3/Cargo.toml b/s3/Cargo.toml index 605228ca8d..71e8e54d85 100644 --- a/s3/Cargo.toml +++ b/s3/Cargo.toml @@ -68,6 +68,7 @@ maybe-async = { version = "0.2" } md5 = "0.7" percent-encoding = "2" serde = "1" +serde_json = "1" serde_derive = "1" quick-xml = { version = "0.28", features = ["serialize"] } sha2 = "0.10" diff --git a/s3/src/bucket.rs b/s3/src/bucket.rs index 61f4a02cfa..b9ec43e8cf 100644 --- a/s3/src/bucket.rs +++ b/s3/src/bucket.rs @@ -38,12 +38,14 @@ use tokio::io::AsyncRead; use futures::io::AsyncRead; use crate::error::S3Error; +use crate::post_policy::PresignedPost; use crate::request::Request; use crate::serde_types::{ BucketLocationResult, CompleteMultipartUploadData, CorsConfiguration, HeadObjectResult, InitiateMultipartUploadResponse, ListBucketResult, ListMultipartUploadsResult, Part, }; use crate::utils::{error_from_response_data, PutStreamResponse}; +use crate::PostPolicy; use http::header::HeaderName; use http::HeaderMap; @@ -165,36 +167,24 @@ impl Bucket { /// ```no_run /// use s3::bucket::Bucket; /// use s3::creds::Credentials; - /// use http::HeaderMap; - /// use http::header::HeaderName; + /// use s3::post_policy::*; + /// use std::borrow::Cow; /// /// let bucket_name = "rust-s3-test"; /// let region = "us-east-1".parse().unwrap(); /// let credentials = Credentials::default().unwrap(); /// let bucket = Bucket::new(bucket_name, region, credentials).unwrap(); /// - /// let post_policy = "eyAiZXhwaXJhdGlvbiI6ICIyMDE1LTEyLTMwVDEyOjAwOjAwLjAwMFoiLA0KICAiY29uZGl0aW9ucyI6IFsNCiAgICB7ImJ1Y2tldCI6ICJzaWd2NGV4YW1wbGVidWNrZXQifSwNCiAgICBbInN0YXJ0cy13aXRoIiwgIiRrZXkiLCAidXNlci91c2VyMS8iXSwNCiAgICB7ImFjbCI6ICJwdWJsaWMtcmVhZCJ9LA0KICAgIHsic3VjY2Vzc19hY3Rpb25fcmVkaXJlY3QiOiAiaHR0cDovL3NpZ3Y0ZXhhbXBsZWJ1Y2tldC5zMy5hbWF6b25hd3MuY29tL3N1Y2Nlc3NmdWxfdXBsb2FkLmh0bWwifSwNCiAgICBbInN0YXJ0cy13aXRoIiwgIiRDb250ZW50LVR5cGUiLCAiaW1hZ2UvIl0sDQogICAgeyJ4LWFtei1tZXRhLXV1aWQiOiAiMTQzNjUxMjM2NTEyNzQifSwNCiAgICB7IngtYW16LXNlcnZlci1zaWRlLWVuY3J5cHRpb24iOiAiQUVTMjU2In0sDQogICAgWyJzdGFydHMtd2l0aCIsICIkeC1hbXotbWV0YS10YWciLCAiIl0sDQoNCiAgICB7IngtYW16LWNyZWRlbnRpYWwiOiAiQUtJQUlPU0ZPRE5ON0VYQU1QTEUvMjAxNTEyMjkvdXMtZWFzdC0xL3MzL2F3czRfcmVxdWVzdCJ9LA0KICAgIHsieC1hbXotYWxnb3JpdGhtIjogIkFXUzQtSE1BQy1TSEEyNTYifSwNCiAgICB7IngtYW16LWRhdGUiOiAiMjAxNTEyMjlUMDAwMDAwWiIgfQ0KICBdDQp9"; + /// let post_policy = PostPolicy::new(86400).condition( + /// PostPolicyField::Key, + /// PostPolicyValue::StartsWith(Cow::from("user/user1/")) + /// ).unwrap(); /// - /// let url = bucket.presign_post("/test.file", 86400, post_policy.to_string()).unwrap(); - /// println!("Presigned url: {}", url); + /// let presigned_post = bucket.presign_post(post_policy).unwrap(); + /// println!("Presigned url: {}, fields: {:?}", presigned_post.url, presigned_post.fields); /// ``` - pub fn presign_post>( - &self, - path: S, - expiry_secs: u32, - // base64 encoded post policy document -> https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-post-example.html - post_policy: String, - ) -> Result { - validate_expiry(expiry_secs)?; - let request = RequestImpl::new( - self, - path.as_ref(), - Command::PresignPost { - expiry_secs, - post_policy, - }, - )?; - request.presigned() + pub fn presign_post(&self, post_policy: PostPolicy) -> Result { + post_policy.sign(self.clone()) } /// Get a presigned url for putting object to a given path @@ -2237,12 +2227,13 @@ impl Bucket { mod test { use crate::creds::Credentials; + use crate::post_policy::{PostPolicyField, PostPolicyValue}; use crate::region::Region; use crate::serde_types::CorsConfiguration; use crate::serde_types::CorsRule; - use crate::Bucket; use crate::BucketConfiguration; use crate::Tag; + use crate::{Bucket, PostPolicy}; use http::header::HeaderName; use http::HeaderMap; use std::env; @@ -2355,6 +2346,7 @@ mod test { .with_path_style() } + #[allow(dead_code)] fn test_digital_ocean_bucket() -> Bucket { Bucket::new("rust-s3", Region::DoFra1, test_digital_ocean_credentials()).unwrap() } @@ -2895,10 +2887,9 @@ mod test { } #[test] - #[ignore] fn test_presign_put() { let s3_path = "/test/test.file"; - let bucket = test_aws_bucket(); + let bucket = test_minio_bucket(); let mut custom_headers = HeaderMap::new(); custom_headers.insert( @@ -2915,20 +2906,39 @@ mod test { } #[test] - #[ignore] + fn test_presign_post() { + use std::borrow::Cow; + + let bucket = test_minio_bucket(); + + // Policy from sample + let policy = PostPolicy::new(86400) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("user/user1/")), + ) + .unwrap(); + + let data = bucket.presign_post(policy).unwrap(); + + assert_eq!(data.url, "http://localhost:9000/rust-s3"); + assert_eq!(data.fields.len(), 6); + assert_eq!(data.dynamic_fields.len(), 1); + } + + #[test] fn test_presign_get() { let s3_path = "/test/test.file"; - let bucket = test_aws_bucket(); + let bucket = test_minio_bucket(); let url = bucket.presign_get(s3_path, 86400, None).unwrap(); assert!(url.contains("/test/test.file?")) } #[test] - #[ignore] fn test_presign_delete() { let s3_path = "/test/test.file"; - let bucket = test_aws_bucket(); + let bucket = test_minio_bucket(); let url = bucket.presign_delete(s3_path, 86400).unwrap(); assert!(url.contains("/test/test.file?")) diff --git a/s3/src/command.rs b/s3/src/command.rs index 1dfdcfec28..022c6c92f5 100644 --- a/s3/src/command.rs +++ b/s3/src/command.rs @@ -102,10 +102,6 @@ pub enum Command<'a> { expiry_secs: u32, custom_headers: Option, }, - PresignPost { - expiry_secs: u32, - post_policy: String, - }, PresignDelete { expiry_secs: u32, }, @@ -161,7 +157,6 @@ impl<'a> Command<'a> { HttpMethod::Post } Command::HeadObject => HttpMethod::Head, - Command::PresignPost { .. } => HttpMethod::Post, } } diff --git a/s3/src/error.rs b/s3/src/error.rs index ff0d73d0d1..01175e03bf 100644 --- a/s3/src/error.rs +++ b/s3/src/error.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Error, Debug)] +#[non_exhaustive] pub enum S3Error { #[error("Utf8 decoding error: {0}")] Utf8(#[from] std::str::Utf8Error), @@ -53,4 +54,8 @@ pub enum S3Error { TimeFormatError(#[from] time::error::Format), #[error("fmt error: {0}")] FmtError(#[from] std::fmt::Error), + #[error("serde error: {0}")] + SerdeError(#[from] serde_json::Error), + #[error("post policy error: {0}")] + PostPolicyError(#[from] crate::post_policy::PostPolicyError), } diff --git a/s3/src/lib.rs b/s3/src/lib.rs index a8c42a6d69..86e7dabe31 100644 --- a/s3/src/lib.rs +++ b/s3/src/lib.rs @@ -10,12 +10,14 @@ pub use awsregion as region; pub use bucket::Bucket; pub use bucket::Tag; pub use bucket_ops::BucketConfiguration; +pub use post_policy::{PostPolicy, PostPolicyChecksum, PostPolicyField, PostPolicyValue}; pub use region::Region; pub mod bucket; pub mod bucket_ops; pub mod command; pub mod deserializer; +pub mod post_policy; pub mod serde_types; pub mod signing; diff --git a/s3/src/post_policy.rs b/s3/src/post_policy.rs new file mode 100644 index 0000000000..02c2a1c889 --- /dev/null +++ b/s3/src/post_policy.rs @@ -0,0 +1,738 @@ +use crate::error::S3Error; +use crate::utils::now_utc; +use crate::{signing, Bucket, LONG_DATETIME}; + +use awscreds::error::CredentialsError; +use awscreds::Rfc3339OffsetDateTime; +use serde::ser; +use serde::ser::{Serialize, SerializeMap, SerializeSeq, SerializeTuple, Serializer}; +use std::borrow::Cow; +use std::collections::HashMap; +use thiserror::Error; +use time::{Duration, OffsetDateTime}; + +#[derive(Clone, Debug)] +pub struct PostPolicy<'a> { + expiration: PostPolicyExpiration, + conditions: ConditionsSerializer<'a>, +} + +impl<'a> PostPolicy<'a> { + pub fn new(expiration: T) -> Self + where + T: Into, + { + Self { + expiration: expiration.into(), + conditions: ConditionsSerializer(Vec::new()), + } + } + + /// Build a finalized post policy with credentials + fn build(&self, now: &OffsetDateTime, bucket: &Bucket) -> Result { + let access_key = bucket.access_key()?.ok_or(S3Error::Credentials( + CredentialsError::ConfigMissingAccessKeyId, + ))?; + let credential = format!( + "{}/{}", + access_key, + signing::scope_string(now, &bucket.region)? + ); + + let mut post_policy = self + .clone() + .condition( + PostPolicyField::Bucket, + PostPolicyValue::Exact(Cow::from(bucket.name.clone())), + )? + .condition( + PostPolicyField::AmzAlgorithm, + PostPolicyValue::Exact(Cow::from("AWS4-HMAC-SHA256")), + )? + .condition( + PostPolicyField::AmzCredential, + PostPolicyValue::Exact(Cow::from(credential)), + )? + .condition( + PostPolicyField::AmzDate, + PostPolicyValue::Exact(Cow::from(now.format(LONG_DATETIME)?)), + )?; + + if let Some(security_token) = bucket.security_token()? { + post_policy = post_policy.condition( + PostPolicyField::AmzSecurityToken, + PostPolicyValue::Exact(Cow::from(security_token)), + )?; + } + Ok(post_policy.clone()) + } + + fn policy_string(&self) -> Result { + use base64::engine::general_purpose; + use base64::Engine; + + let data = serde_json::to_string(self)?; + + Ok(general_purpose::STANDARD.encode(data)) + } + + pub fn sign(&self, bucket: Bucket) -> Result { + use hmac::Mac; + + bucket.credentials_refresh()?; + let now = now_utc(); + + let policy = self.build(&now, &bucket)?; + let policy_string = policy.policy_string()?; + + let signing_key = signing::signing_key( + &now, + &bucket.secret_key()?.ok_or(S3Error::Credentials( + CredentialsError::ConfigMissingSecretKey, + ))?, + &bucket.region, + "s3", + )?; + + let mut hmac = signing::HmacSha256::new_from_slice(&signing_key)?; + hmac.update(policy_string.as_bytes()); + let signature = hex::encode(hmac.finalize().into_bytes()); + let mut fields: HashMap = HashMap::new(); + let mut dynamic_fields = HashMap::new(); + for field in policy.conditions.0.iter() { + let f: Cow = field.field.clone().into(); + match &field.value { + PostPolicyValue::Anything => { + dynamic_fields.insert(f.to_string(), "".to_string()); + } + PostPolicyValue::StartsWith(e) => { + dynamic_fields.insert(f.to_string(), e.clone().into_owned()); + } + PostPolicyValue::Range(b, e) => { + dynamic_fields.insert(f.to_string(), format!("{},{}", b, e)); + } + PostPolicyValue::Exact(e) => { + fields.insert(f.to_string(), e.clone().into_owned()); + } + } + } + fields.insert("x-amz-signature".to_string(), signature); + fields.insert("Policy".to_string(), policy_string); + let url = bucket.url(); + Ok(PresignedPost { + url, + fields, + dynamic_fields, + expiration: policy.expiration.into(), + }) + } + + /// Adds another condition to the policy by consuming this object + pub fn condition( + mut self, + field: PostPolicyField<'a>, + value: PostPolicyValue<'a>, + ) -> Result { + if matches!(field, PostPolicyField::ContentLengthRange) + != matches!(value, PostPolicyValue::Range(_, _)) + { + Err(PostPolicyError::MismatchedCondition)? + } + self.conditions.0.push(PostPolicyCondition { field, value }); + Ok(self) + } +} + +impl Serialize for PostPolicy<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("expiration", &self.expiration)?; + map.serialize_entry("conditions", &self.conditions)?; + map.end() + } +} + +#[derive(Clone, Debug)] +struct ConditionsSerializer<'a>(Vec>); + +impl Serialize for ConditionsSerializer<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut seq = serializer.serialize_seq(None)?; + for e in self.0.iter() { + if let PostPolicyField::AmzChecksumAlgorithm(checksum) = &e.field { + let checksum: Cow = (*checksum).into(); + seq.serialize_element(&PostPolicyCondition { + field: PostPolicyField::Custom(Cow::from("x-amz-checksum-algorithm")), + value: PostPolicyValue::Exact(Cow::from(checksum.to_uppercase())), + })?; + } + seq.serialize_element(&e)?; + } + seq.end() + } +} + +#[derive(Clone, Debug)] +struct PostPolicyCondition<'a> { + field: PostPolicyField<'a>, + value: PostPolicyValue<'a>, +} + +impl Serialize for PostPolicyCondition<'_> { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let f: Cow = self.field.clone().into(); + + match &self.value { + PostPolicyValue::Exact(e) => { + let mut map = serializer.serialize_map(Some(1))?; + map.serialize_entry(&f, e)?; + map.end() + } + PostPolicyValue::StartsWith(e) => { + let mut seq = serializer.serialize_tuple(3)?; + seq.serialize_element("starts-with")?; + let field = format!("${}", f); + seq.serialize_element(&field)?; + seq.serialize_element(e)?; + seq.end() + } + PostPolicyValue::Anything => { + let mut seq = serializer.serialize_tuple(3)?; + seq.serialize_element("starts-with")?; + let field = format!("${}", f); + seq.serialize_element(&field)?; + seq.serialize_element("")?; + seq.end() + } + PostPolicyValue::Range(b, e) => { + if matches!(self.field, PostPolicyField::ContentLengthRange) { + let mut seq = serializer.serialize_tuple(3)?; + seq.serialize_element("content-length-range")?; + seq.serialize_element(b)?; + seq.serialize_element(e)?; + seq.end() + } else { + Err(ser::Error::custom( + "Range is only valid for ContentLengthRange", + )) + } + } + } + } +} + +/// Policy fields to add to the conditions of the policy +#[derive(Clone, Debug)] +#[non_exhaustive] +pub enum PostPolicyField<'a> { + /// The destination path. Supports [`PostPolicyValue::StartsWith`] + Key, + /// The ACL policy. Supports [`PostPolicyValue::StartsWith`] + Acl, + /// Custom tag XML document + Tagging, + /// Successful redirect URL. Supports [`PostPolicyValue::StartsWith`] + SuccessActionRedirect, + /// Successful action status (e.g. 200, 201, or 204). + SuccessActionStatus, + + /// The cache control Supports [`PostPolicyValue::StartsWith`] + CacheControl, + /// The content length (must use the [`PostPolicyValue::Range`]) + ContentLengthRange, + /// The content type. Supports [`PostPolicyValue::StartsWith`] + ContentType, + /// Content Disposition. Supports [`PostPolicyValue::StartsWith`] + ContentDisposition, + /// The content encoding. Supports [`PostPolicyValue::StartsWith`] + ContentEncoding, + /// The Expires header to respond when fetching. Supports [`PostPolicyValue::StartsWith`] + Expires, + + /// The server-side encryption type + AmzServerSideEncryption, + /// The SSE key ID to use (if the algorithm specified requires it) + AmzServerSideEncryptionKeyId, + /// The SSE context to use (if the algorithm specified requires it) + AmzServerSideEncryptionContext, + /// The storage class to use + AmzStorageClass, + /// Specify a bucket relative or absolute UR redirect to redirect to when fetching this object + AmzWebsiteRedirectLocation, + /// Checksum algorithm, the value is the checksum + AmzChecksumAlgorithm(PostPolicyChecksum), + /// Any user-defined meta fields (AmzMeta("uuid".to_string) creates an x-amz-meta-uuid) + AmzMeta(Cow<'a, str>), + + /// The credential. Auto added by the presign_post + AmzCredential, + /// The signing algorithm. Auto added by the presign_post + AmzAlgorithm, + /// The signing date. Auto added by the presign_post + AmzDate, + /// The Security token (for Amazon DevPay) + AmzSecurityToken, + /// The Bucket. Auto added by the presign_post + Bucket, + + /// Custom field. Any other string not enumerated above + Custom(Cow<'a, str>), +} + +impl<'a> Into> for PostPolicyField<'a> { + fn into(self) -> Cow<'a, str> { + match self { + PostPolicyField::Key => Cow::from("key"), + PostPolicyField::Acl => Cow::from("acl"), + PostPolicyField::Tagging => Cow::from("tagging"), + PostPolicyField::SuccessActionRedirect => Cow::from("success_action_redirect"), + PostPolicyField::SuccessActionStatus => Cow::from("success_action_status"), + PostPolicyField::CacheControl => Cow::from("Cache-Control"), + PostPolicyField::ContentLengthRange => Cow::from("content-length-range"), + PostPolicyField::ContentType => Cow::from("Content-Type"), + PostPolicyField::ContentDisposition => Cow::from("Content-Disposition"), + PostPolicyField::ContentEncoding => Cow::from("Content-Encoding"), + PostPolicyField::Expires => Cow::from("Expires"), + + PostPolicyField::AmzServerSideEncryption => Cow::from("x-amz-server-side-encryption"), + PostPolicyField::AmzServerSideEncryptionKeyId => { + Cow::from("x-amz-server-side-encryption-aws-kms-key-id") + } + PostPolicyField::AmzServerSideEncryptionContext => { + Cow::from("x-amz-server-side-encryption-context") + } + PostPolicyField::AmzStorageClass => Cow::from("x-amz-storage-class"), + PostPolicyField::AmzWebsiteRedirectLocation => { + Cow::from("x-amz-website-redirect-location") + } + PostPolicyField::AmzChecksumAlgorithm(e) => { + let e: Cow = e.into(); + Cow::from(format!("x-amz-checksum-{}", e)) + } + PostPolicyField::AmzMeta(e) => Cow::from(format!("x-amz-meta-{}", e)), + PostPolicyField::AmzCredential => Cow::from("x-amz-credential"), + PostPolicyField::AmzAlgorithm => Cow::from("x-amz-algorithm"), + PostPolicyField::AmzDate => Cow::from("x-amz-date"), + PostPolicyField::AmzSecurityToken => Cow::from("x-amz-security-token"), + PostPolicyField::Bucket => Cow::from("bucket"), + PostPolicyField::Custom(e) => e, + } + } +} + +#[derive(Clone, Copy, Debug)] +pub enum PostPolicyChecksum { + CRC32, + CRC32c, + SHA1, + SHA256, +} + +impl<'a> Into> for PostPolicyChecksum { + fn into(self) -> Cow<'a, str> { + match self { + PostPolicyChecksum::CRC32 => Cow::from("crc32"), + PostPolicyChecksum::CRC32c => Cow::from("crc32c"), + PostPolicyChecksum::SHA1 => Cow::from("sha1"), + PostPolicyChecksum::SHA256 => Cow::from("sha256"), + } + } +} + +#[derive(Clone, Debug)] +pub enum PostPolicyValue<'a> { + /// Shortcut for StartsWith("".to_string()) + Anything, + /// A string starting with a value + StartsWith(Cow<'a, str>), + /// A range of integer values. Only valid for some fields + Range(u32, u32), + /// An exact string value + Exact(Cow<'a, str>), +} + +#[derive(Clone, Debug)] +pub enum PostPolicyExpiration { + /// Expires in X seconds from "now" + ExpiresIn(u32), + /// Expires at exactly this time + ExpiresAt(Rfc3339OffsetDateTime), +} + +impl From for PostPolicyExpiration { + fn from(value: u32) -> Self { + Self::ExpiresIn(value) + } +} + +impl From for PostPolicyExpiration { + fn from(value: Rfc3339OffsetDateTime) -> Self { + Self::ExpiresAt(value) + } +} + +impl From for Rfc3339OffsetDateTime { + fn from(value: PostPolicyExpiration) -> Self { + match value { + PostPolicyExpiration::ExpiresIn(d) => { + Rfc3339OffsetDateTime(now_utc().saturating_add(Duration::seconds(d as i64))) + } + PostPolicyExpiration::ExpiresAt(t) => t, + } + } +} + +impl Serialize for PostPolicyExpiration { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + Rfc3339OffsetDateTime::from(self.clone()).serialize(serializer) + } +} + +#[derive(Debug)] +pub struct PresignedPost { + pub url: String, + pub fields: HashMap, + pub dynamic_fields: HashMap, + pub expiration: Rfc3339OffsetDateTime, +} + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum PostPolicyError { + #[error("This value is not supported for this field")] + MismatchedCondition, +} + +#[cfg(test)] +mod test { + use super::*; + + use crate::creds::Credentials; + use crate::region::Region; + use crate::utils::with_timestamp; + + use serde_json::json; + + fn test_bucket() -> Bucket { + Bucket::new( + "rust-s3", + Region::UsEast1, + Credentials::new( + Some("AKIAIOSFODNN7EXAMPLE"), + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), + None, + None, + None, + ) + .unwrap(), + ) + .unwrap() + } + + fn test_bucket_with_security_token() -> Bucket { + Bucket::new( + "rust-s3", + Region::UsEast1, + Credentials::new( + Some("AKIAIOSFODNN7EXAMPLE"), + Some("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"), + Some("SomeSecurityToken"), + None, + None, + ) + .unwrap(), + ) + .unwrap() + } + + mod conditions { + use super::*; + + #[test] + fn starts_with_condition() { + let policy = PostPolicy::new(300) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("users/user1/")), + ) + .unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert!(data["expiration"].is_string()); + assert_eq!( + data["conditions"], + json!([["starts-with", "$key", "users/user1/"]]) + ); + } + + #[test] + fn exact_condition() { + let policy = PostPolicy::new(300) + .condition( + PostPolicyField::Acl, + PostPolicyValue::Exact(Cow::from("public-read")), + ) + .unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert!(data["expiration"].is_string()); + assert_eq!(data["conditions"], json!([{"acl":"public-read"}])); + } + + #[test] + fn anything_condition() { + let policy = PostPolicy::new(300) + .condition(PostPolicyField::Key, PostPolicyValue::Anything) + .unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert!(data["expiration"].is_string()); + assert_eq!(data["conditions"], json!([["starts-with", "$key", ""]])); + } + + #[test] + fn range_condition() { + let policy = PostPolicy::new(300) + .condition( + PostPolicyField::ContentLengthRange, + PostPolicyValue::Range(0, 3_000_000), + ) + .unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert!(data["expiration"].is_string()); + assert_eq!( + data["conditions"], + json!([["content-length-range", 0, 3_000_000]]) + ); + } + + #[test] + fn range_condition_for_non_content_length_range() -> Result<(), S3Error> { + let result = PostPolicy::new(86400) + .condition(PostPolicyField::ContentType, PostPolicyValue::Range(0, 100)); + + assert!(matches!( + result, + Err(S3Error::PostPolicyError( + PostPolicyError::MismatchedCondition + )) + )); + + Ok(()) + } + + #[test] + fn starts_with_condition_for_content_length_range() -> Result<(), S3Error> { + let result = PostPolicy::new(86400).condition( + PostPolicyField::ContentLengthRange, + PostPolicyValue::StartsWith(Cow::from("")), + ); + + assert!(matches!( + result, + Err(S3Error::PostPolicyError( + PostPolicyError::MismatchedCondition + )) + )); + + Ok(()) + } + + #[test] + fn exact_condition_for_content_length_range() -> Result<(), S3Error> { + let result = PostPolicy::new(86400).condition( + PostPolicyField::ContentLengthRange, + PostPolicyValue::Exact(Cow::from("test")), + ); + + assert!(matches!( + result, + Err(S3Error::PostPolicyError( + PostPolicyError::MismatchedCondition + )) + )); + + Ok(()) + } + + #[test] + fn anything_condition_for_content_length_range() -> Result<(), S3Error> { + let result = PostPolicy::new(86400).condition( + PostPolicyField::ContentLengthRange, + PostPolicyValue::Anything, + ); + + assert!(matches!( + result, + Err(S3Error::PostPolicyError( + PostPolicyError::MismatchedCondition + )) + )); + + Ok(()) + } + + #[test] + fn checksum_policy() { + let policy = PostPolicy::new(300) + .condition( + PostPolicyField::AmzChecksumAlgorithm(PostPolicyChecksum::SHA256), + PostPolicyValue::Exact(Cow::from("abcdef1234567890")), + ) + .unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert!(data["expiration"].is_string()); + assert_eq!( + data["conditions"], + json!([ + {"x-amz-checksum-algorithm": "SHA256"}, + {"x-amz-checksum-sha256": "abcdef1234567890"} + ]) + ); + } + } + + mod build { + use super::*; + + #[test] + fn adds_credentials() { + let policy = PostPolicy::new(86400) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("user/user1/")), + ) + .unwrap(); + + let bucket = test_bucket(); + + let _ts = with_timestamp(1_451_347_200); + let policy = policy.build(&now_utc(), &bucket).unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert_eq!( + data["conditions"], + json!([ + ["starts-with", "$key", "user/user1/"], + {"bucket": "rust-s3"}, + {"x-amz-algorithm": "AWS4-HMAC-SHA256"}, + {"x-amz-credential": "AKIAIOSFODNN7EXAMPLE/20151229/us-east-1/s3/aws4_request"}, + {"x-amz-date": "20151229T000000Z"}, + ]) + ); + } + + #[test] + fn with_security_token() { + let policy = PostPolicy::new(86400) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("user/user1/")), + ) + .unwrap(); + + let bucket = test_bucket_with_security_token(); + + let _ts = with_timestamp(1_451_347_200); + let policy = policy.build(&now_utc(), &bucket).unwrap(); + + let data = serde_json::to_value(&policy).unwrap(); + + assert_eq!( + data["conditions"], + json!([ + ["starts-with", "$key", "user/user1/"], + {"bucket": "rust-s3"}, + {"x-amz-algorithm": "AWS4-HMAC-SHA256"}, + {"x-amz-credential": "AKIAIOSFODNN7EXAMPLE/20151229/us-east-1/s3/aws4_request"}, + {"x-amz-date": "20151229T000000Z"}, + {"x-amz-security-token": "SomeSecurityToken"}, + ]) + ); + } + } + + mod policy_string { + use super::*; + + #[test] + fn returns_base64_encoded() { + let policy = PostPolicy::new(129600) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("user/user1/")), + ) + .unwrap(); + + let _ts = with_timestamp(1_451_347_200); + + let expected = "eyJleHBpcmF0aW9uIjoiMjAxNS0xMi0zMFQxMjowMDowMFoiLCJjb25kaXRpb25zIjpbWyJzdGFydHMtd2l0aCIsIiRrZXkiLCJ1c2VyL3VzZXIxLyJdXX0="; + + assert_eq!(policy.policy_string().unwrap(), expected); + } + } + + mod sign { + use super::*; + + #[test] + fn returns_full_details() { + let policy = PostPolicy::new(86400) + .condition( + PostPolicyField::Key, + PostPolicyValue::StartsWith(Cow::from("user/user1/")), + ) + .unwrap() + .condition( + PostPolicyField::ContentLengthRange, + PostPolicyValue::Range(0, 3_000_000), + ) + .unwrap(); + + let bucket = test_bucket(); + + let _ts = with_timestamp(1_451_347_200); + let post = policy.sign(bucket).unwrap(); + + assert_eq!(post.url, "https://rust-s3.s3.amazonaws.com"); + assert_eq!( + serde_json::to_value(&post.fields).unwrap(), + json!({ + "x-amz-credential": "AKIAIOSFODNN7EXAMPLE/20151229/us-east-1/s3/aws4_request", + "bucket": "rust-s3", + "Policy": "eyJleHBpcmF0aW9uIjoiMjAxNS0xMi0zMFQwMDowMDowMFoiLCJjb25kaXRpb25zIjpbWyJzdGFydHMtd2l0aCIsIiRrZXkiLCJ1c2VyL3VzZXIxLyJdLFsiY29udGVudC1sZW5ndGgtcmFuZ2UiLDAsMzAwMDAwMF0seyJidWNrZXQiOiJydXN0LXMzIn0seyJ4LWFtei1hbGdvcml0aG0iOiJBV1M0LUhNQUMtU0hBMjU2In0seyJ4LWFtei1jcmVkZW50aWFsIjoiQUtJQUlPU0ZPRE5ON0VYQU1QTEUvMjAxNTEyMjkvdXMtZWFzdC0xL3MzL2F3czRfcmVxdWVzdCJ9LHsieC1hbXotZGF0ZSI6IjIwMTUxMjI5VDAwMDAwMFoifV19", + "x-amz-date": "20151229T000000Z", + "x-amz-signature": "0ff9c50ab7e543a841e91e5c663fd32117c5243e56e7a69db88f94ee95c4706f", + "x-amz-algorithm": "AWS4-HMAC-SHA256" + }) + ); + assert_eq!( + serde_json::to_value(&post.dynamic_fields).unwrap(), + json!({ + "key": "user/user1/", + "content-length-range": "0,3000000", + }) + ); + } + } +} diff --git a/s3/src/request/async_std_backend.rs b/s3/src/request/async_std_backend.rs index 7ceb97725e..7bdf670d58 100644 --- a/s3/src/request/async_std_backend.rs +++ b/s3/src/request/async_std_backend.rs @@ -8,6 +8,7 @@ use std::collections::HashMap; use crate::bucket::Bucket; use crate::command::Command; use crate::error::S3Error; +use crate::utils::now_utc; use time::OffsetDateTime; use crate::command::HttpMethod; @@ -181,7 +182,7 @@ impl<'a> SurfRequest<'a> { bucket, path, command, - datetime: OffsetDateTime::now_utc(), + datetime: now_utc(), sync: false, }) } diff --git a/s3/src/request/blocking.rs b/s3/src/request/blocking.rs index d9c44d24e4..035f54c74f 100644 --- a/s3/src/request/blocking.rs +++ b/s3/src/request/blocking.rs @@ -9,6 +9,7 @@ use attohttpc::header::HeaderName; use crate::bucket::Bucket; use crate::command::Command; use crate::error::S3Error; +use crate::utils::now_utc; use bytes::Bytes; use std::collections::HashMap; use time::OffsetDateTime; @@ -138,7 +139,7 @@ impl<'a> AttoRequest<'a> { bucket, path, command, - datetime: OffsetDateTime::now_utc(), + datetime: now_utc(), sync: false, }) } diff --git a/s3/src/request/request_trait.rs b/s3/src/request/request_trait.rs index 64ae6215ad..334243bf08 100644 --- a/s3/src/request/request_trait.rs +++ b/s3/src/request/request_trait.rs @@ -178,14 +178,7 @@ pub trait Request { } fn string_to_sign(&self, request: &str) -> Result { - match self.command() { - Command::PresignPost { post_policy, .. } => Ok(post_policy), - _ => Ok(signing::string_to_sign( - &self.datetime(), - &self.bucket().region(), - request, - )?), - } + signing::string_to_sign(&self.datetime(), &self.bucket().region(), request) } fn host_header(&self) -> String { diff --git a/s3/src/request/tokio_backend.rs b/s3/src/request/tokio_backend.rs index d13ddd9d20..90c9b0c72e 100644 --- a/s3/src/request/tokio_backend.rs +++ b/s3/src/request/tokio_backend.rs @@ -15,6 +15,7 @@ use crate::bucket::Bucket; use crate::command::Command; use crate::command::HttpMethod; use crate::error::S3Error; +use crate::utils::now_utc; use tokio_stream::StreamExt; @@ -66,9 +67,7 @@ impl<'a> Request for HyperRequest<'a> { let mut http_connector = HttpConnector::new(); http_connector.set_connect_timeout(self.bucket.request_timeout); - // let https_connector = HttpsConnector::from((http_connector, tls_connector)); - - let https_connector = HttpsConnector::new(); + let https_connector = HttpsConnector::from((http_connector, tls_connector)); let client = Client::builder().build::<_, hyper::Body>(https_connector); @@ -194,7 +193,7 @@ impl<'a> HyperRequest<'a> { bucket, path, command, - datetime: OffsetDateTime::now_utc(), + datetime: now_utc(), sync: false, }) } diff --git a/s3/src/utils.rs b/s3/src/utils/mod.rs similarity index 99% rename from s3/src/utils.rs rename to s3/src/utils/mod.rs index 6ad5d580f6..a28f1a9b7b 100644 --- a/s3/src/utils.rs +++ b/s3/src/utils/mod.rs @@ -1,3 +1,7 @@ +mod time_utils; + +pub use time_utils::*; + use std::str::FromStr; use crate::error::S3Error; diff --git a/s3/src/utils/time_utils.rs b/s3/src/utils/time_utils.rs new file mode 100644 index 0000000000..d677eaccdd --- /dev/null +++ b/s3/src/utils/time_utils.rs @@ -0,0 +1,161 @@ +use std::time::{SystemTime, SystemTimeError, UNIX_EPOCH}; +use time::OffsetDateTime; + +fn real_time() -> Result { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|t| t.as_secs()) +} + +#[cfg(not(test))] +pub fn current_time() -> Result { + real_time() +} + +pub fn now_utc() -> OffsetDateTime { + OffsetDateTime::from_unix_timestamp(current_time().unwrap() as i64).unwrap() +} + +#[cfg(test)] +mod mocked_time { + use super::*; + + use std::cell::Cell; + + thread_local! { + static TIMESTAMP: Cell = Cell::new(0); + } + + pub fn current_time() -> Result { + TIMESTAMP.with(|ts| { + let time = ts.get(); + if time == 0 { + real_time() + } else { + Ok(time) + } + }) + } + + fn set_timestamp(timestamp: u64) -> u64 { + TIMESTAMP.with(|ts| { + let old = ts.get(); + ts.set(timestamp); + old + }) + } + + pub struct MockTimestamp { + old: u64, + } + + impl MockTimestamp { + // Get the real clock time + pub fn real_time(&self) -> u64 { + real_time().unwrap() + } + + // Get the old time before this call to with_timestamp + pub fn old_time(&self) -> u64 { + self.old + } + + // Sets the time to the exact unix timestamp + // 0 means use real_time + // Returns old time + pub fn set_time(&self, timestamp: u64) -> u64 { + set_timestamp(timestamp) + } + + // Add this many seconds to the current time + // Can be negative + // Returns new now, old time + pub fn add_time(&self, time_delta: i64) -> (u64, u64) { + let now = ((current_time().unwrap() as i64) + time_delta) as u64; + (now, set_timestamp(now)) + } + } + + impl Drop for MockTimestamp { + fn drop(&mut self) { + set_timestamp(self.old); + } + } + + pub fn with_timestamp(timestamp: u64) -> MockTimestamp { + MockTimestamp { + old: set_timestamp(timestamp), + } + } + + #[cfg(test)] + mod tests { + use super::*; + + const MOCKED_TIMESTAMP: u64 = 5_000_000; + + mod current_time { + use super::*; + + #[test] + fn when_set_timestamp_not_called() { + let now = real_time().unwrap(); + + assert!(current_time().unwrap() >= now); + } + + #[test] + fn when_set_timestamp_was_called() { + set_timestamp(MOCKED_TIMESTAMP); + + assert_eq!(current_time().unwrap(), MOCKED_TIMESTAMP); + + set_timestamp(0); + } + } + + mod with_timestamp { + use super::*; + + #[test] + fn when_resets_when_result_dropped() { + let now = real_time().unwrap(); + let ts = with_timestamp(MOCKED_TIMESTAMP); + + assert_eq!(current_time().unwrap(), MOCKED_TIMESTAMP); + + drop(ts); + + assert!(current_time().unwrap() >= now); + } + + #[test] + fn when_nested() { + let now = real_time().unwrap(); + { + let _ts = with_timestamp(MOCKED_TIMESTAMP); + + assert_eq!(current_time().unwrap(), MOCKED_TIMESTAMP); + + { + let _ts = with_timestamp(MOCKED_TIMESTAMP + 1_000); + + assert_eq!(current_time().unwrap(), MOCKED_TIMESTAMP + 1_000); + } + + { + let _ts = with_timestamp(0); + + assert!(current_time().unwrap() >= now); + } + + assert_eq!(current_time().unwrap(), MOCKED_TIMESTAMP); + } + + assert!(current_time().unwrap() >= now); + } + } + } +} +#[cfg(test)] +pub use mocked_time::*;