From a68fe1601f047c7f0918b910417bb391cc736e53 Mon Sep 17 00:00:00 2001 From: houseme Date: Mon, 27 Apr 2026 20:39:02 +0800 Subject: [PATCH] fix(ilm): harden signer failures and guard remote tier delete storms (#2706) --- Cargo.lock | 1 + .../src/bucket/lifecycle/tier_sweeper.rs | 194 +++++- crates/ecstore/src/client/bucket_cache.rs | 18 +- crates/ecstore/src/client/mod.rs | 1 + crates/ecstore/src/client/signer_error.rs | 74 +++ crates/ecstore/src/client/transition_api.rs | 74 ++- crates/signer/Cargo.toml | 1 + crates/signer/src/lib.rs | 4 + crates/signer/src/request_signature_v4.rs | 580 ++++++++++++++---- crates/signer/src/utils.rs | 41 +- 10 files changed, 851 insertions(+), 137 deletions(-) create mode 100644 crates/ecstore/src/client/signer_error.rs diff --git a/Cargo.lock b/Cargo.lock index 3d716c03c..c2c97ea75 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8703,6 +8703,7 @@ dependencies = [ "rustfs-utils", "s3s", "serde_urlencoded", + "thiserror 2.0.18", "time", "tracing", ] diff --git a/crates/ecstore/src/bucket/lifecycle/tier_sweeper.rs b/crates/ecstore/src/bucket/lifecycle/tier_sweeper.rs index a5f037ba4..26f07c9ab 100644 --- a/crates/ecstore/src/bucket/lifecycle/tier_sweeper.rs +++ b/crates/ecstore/src/bucket/lifecycle/tier_sweeper.rs @@ -20,15 +20,144 @@ use crate::bucket::lifecycle::bucket_lifecycle_ops::{ExpiryOp, GLOBAL_ExpiryState, TransitionedObject}; use crate::bucket::lifecycle::lifecycle::{self, ObjectOpts}; +use crate::client::signer_error::error_chain_contains_signer_header_marker; use crate::global::GLOBAL_TierConfigMgr; +use rustfs_utils::get_env_usize; use sha2::{Digest, Sha256}; use std::any::Any; +use std::collections::VecDeque; use std::io::Write; +use std::sync::LazyLock; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::{Duration, Instant}; +use tokio::sync::{Mutex, Semaphore}; +use tracing::warn; use uuid::Uuid; use xxhash_rust::xxh64; static XXHASH_SEED: u64 = 0; +const ENV_REMOTE_DELETE_MAX_CONCURRENCY: &str = "RUSTFS_REMOTE_DELETE_MAX_CONCURRENCY"; +const ENV_REMOTE_DELETE_BREAKER_THRESHOLD: &str = "RUSTFS_REMOTE_DELETE_BREAKER_THRESHOLD"; +const ENV_REMOTE_DELETE_BREAKER_WINDOW_SECS: &str = "RUSTFS_REMOTE_DELETE_BREAKER_WINDOW_SECS"; +const DEFAULT_REMOTE_DELETE_BREAKER_THRESHOLD: usize = 50; +const DEFAULT_REMOTE_DELETE_BREAKER_WINDOW_SECS: usize = 30; +const METRIC_DELETE_REMOTE_FAILED_TOTAL: &str = "rustfs_delete_remote_failed_total"; +const METRIC_DELETE_REMOTE_BREAKER_TOTAL: &str = "rustfs_delete_remote_breaker_total"; +const METRIC_DELETE_REMOTE_INFLIGHT: &str = "rustfs_delete_remote_inflight"; + +static REMOTE_DELETE_INFLIGHT: AtomicUsize = AtomicUsize::new(0); + +static REMOTE_DELETE_LIMITER: LazyLock = LazyLock::new(|| { + let default_limit = std::cmp::min(num_cpus::get(), 16).max(1); + let concurrency = get_env_usize(ENV_REMOTE_DELETE_MAX_CONCURRENCY, default_limit).max(1); + Semaphore::new(concurrency) +}); + +static REMOTE_DELETE_BREAKER: LazyLock> = LazyLock::new(|| { + Mutex::new(RemoteDeleteBreaker::new( + get_env_usize(ENV_REMOTE_DELETE_BREAKER_THRESHOLD, DEFAULT_REMOTE_DELETE_BREAKER_THRESHOLD).max(1), + Duration::from_secs( + get_env_usize(ENV_REMOTE_DELETE_BREAKER_WINDOW_SECS, DEFAULT_REMOTE_DELETE_BREAKER_WINDOW_SECS) as u64, + ), + )) +}); + +#[derive(Debug)] +struct RemoteDeleteBreaker { + threshold: usize, + window: Duration, + failures: VecDeque, +} + +impl RemoteDeleteBreaker { + fn new(threshold: usize, window: Duration) -> Self { + Self { + threshold: threshold.max(1), + window: window.max(Duration::from_secs(1)), + failures: VecDeque::new(), + } + } + + fn should_short_circuit(&mut self, now: Instant) -> bool { + self.prune(now); + self.failures.len() >= self.threshold + } + + fn record_signer_failure(&mut self, now: Instant) -> bool { + self.prune(now); + let was_open = self.failures.len() >= self.threshold; + self.failures.push_back(now); + !was_open && self.failures.len() >= self.threshold + } + + fn prune(&mut self, now: Instant) { + while let Some(ts) = self.failures.front().copied() { + if now.duration_since(ts) > self.window { + self.failures.pop_front(); + } else { + break; + } + } + } +} + +struct RemoteDeleteInflightGuard; + +impl RemoteDeleteInflightGuard { + fn new() -> Self { + let inflight = REMOTE_DELETE_INFLIGHT.fetch_add(1, Ordering::Relaxed) + 1; + metrics::gauge!(METRIC_DELETE_REMOTE_INFLIGHT).set(inflight as f64); + Self + } +} + +impl Drop for RemoteDeleteInflightGuard { + fn drop(&mut self) { + let inflight = REMOTE_DELETE_INFLIGHT.fetch_sub(1, Ordering::Relaxed) - 1; + metrics::gauge!(METRIC_DELETE_REMOTE_INFLIGHT).set(inflight as f64); + } +} + +fn is_signer_header_error(err: &std::io::Error) -> bool { + if err.kind() != std::io::ErrorKind::InvalidInput { + return false; + } + + if let Some(source) = err.get_ref() { + if error_chain_contains_signer_header_marker(source) { + return true; + } + } + + let message = err.to_string().to_ascii_lowercase(); + message.contains("invalid utf-8 header value") + || message.contains("invalidheadervalue") + || (message.contains("sign v4") && message.contains("header value")) +} + +async fn remote_delete_breaker_is_open(now: Instant) -> bool { + let mut breaker = REMOTE_DELETE_BREAKER.lock().await; + breaker.should_short_circuit(now) +} + +async fn record_remote_delete_failure(err: &std::io::Error, now: Instant) { + metrics::counter!(METRIC_DELETE_REMOTE_FAILED_TOTAL).increment(1); + + if !is_signer_header_error(err) { + return; + } + + let mut breaker = REMOTE_DELETE_BREAKER.lock().await; + if breaker.record_signer_failure(now) { + warn!( + threshold = breaker.threshold, + window_secs = breaker.window.as_secs(), + "remote tier delete breaker opened by signer/header failures" + ); + } +} + #[derive(Default)] #[allow(dead_code)] struct ObjSweeper { @@ -148,12 +277,31 @@ impl ExpiryOp for Jentry { } pub async fn delete_object_from_remote_tier(obj_name: &str, rv_id: &str, tier_name: &str) -> Result<(), std::io::Error> { + if remote_delete_breaker_is_open(Instant::now()).await { + metrics::counter!(METRIC_DELETE_REMOTE_BREAKER_TOTAL).increment(1); + return Err(std::io::Error::other("remote tier delete breaker is open due to signer/header failures")); + } + + let _permit = REMOTE_DELETE_LIMITER + .acquire() + .await + .map_err(|_| std::io::Error::other("remote tier delete limiter is closed"))?; + let _inflight = RemoteDeleteInflightGuard::new(); + let mut config_mgr = GLOBAL_TierConfigMgr.write().await; let w = match config_mgr.get_driver(tier_name).await { Ok(w) => w, - Err(e) => return Err(std::io::Error::other(e)), + Err(e) => { + let err = std::io::Error::other(e); + record_remote_delete_failure(&err, Instant::now()).await; + return Err(err); + } }; - w.remove(obj_name, rv_id).await + let result = w.remove(obj_name, rv_id).await; + if let Err(err) = &result { + record_remote_delete_failure(err, Instant::now()).await; + } + result } pub fn transitioned_delete_journal_entry( @@ -189,4 +337,44 @@ pub fn transitioned_force_delete_journal_entry(transitioned: &TransitionedObject } #[cfg(test)] -mod test {} +mod test { + use crate::client::signer_error::invalid_utf8_header_error; + + use super::{RemoteDeleteBreaker, is_signer_header_error}; + use std::io::{Error, ErrorKind}; + use std::time::{Duration, Instant}; + + #[test] + fn signer_header_error_detection_matches_utf8_failures() { + let err = Error::new( + ErrorKind::InvalidInput, + "failed to sign v4 request: invalid UTF-8 header value for `x-amz-meta-invalid`", + ); + assert!(is_signer_header_error(&err)); + } + + #[test] + fn signer_header_error_detection_rejects_unrelated_errors() { + let err = Error::other("dial tcp: i/o timeout"); + assert!(!is_signer_header_error(&err)); + } + + #[test] + fn signer_header_error_detection_matches_structured_marker() { + let err = invalid_utf8_header_error("failed to sign v4 request", "x-amz-meta-invalid"); + assert!(is_signer_header_error(&err)); + } + + #[test] + fn breaker_opens_at_threshold_and_recovers_after_window() { + let mut breaker = RemoteDeleteBreaker::new(3, Duration::from_secs(30)); + let start = Instant::now(); + + assert!(!breaker.should_short_circuit(start)); + assert!(!breaker.record_signer_failure(start)); + assert!(!breaker.record_signer_failure(start + Duration::from_secs(1))); + assert!(breaker.record_signer_failure(start + Duration::from_secs(2))); + assert!(breaker.should_short_circuit(start + Duration::from_secs(3))); + assert!(!breaker.should_short_circuit(start + Duration::from_secs(40))); + } +} diff --git a/crates/ecstore/src/client/bucket_cache.rs b/crates/ecstore/src/client/bucket_cache.rs index bb207095f..739c0f72d 100644 --- a/crates/ecstore/src/client/bucket_cache.rs +++ b/crates/ecstore/src/client/bucket_cache.rs @@ -22,6 +22,7 @@ use super::constants::UNSIGNED_PAYLOAD; use super::credentials::SignatureType; use crate::client::{ api_error_response::http_resp_to_error_response, + signer_error, transition_api::{CreateBucketConfiguration, LocationConstraint, TransitionClient}, }; use http::Request; @@ -35,6 +36,10 @@ use rustfs_utils::hash::EMPTY_STRING_SHA256_HASH; use s3s::S3ErrorCode; use std::collections::HashMap; +fn signer_error_to_io_error(scope: &str, error: rustfs_signer::SignV4Error) -> std::io::Error { + signer_error::signer_error_to_io_error(scope, error) +} + #[derive(Debug, Clone)] pub struct BucketLocationCache { items: HashMap, @@ -179,10 +184,15 @@ impl TransitionClient { content_sha256 = UNSIGNED_PAYLOAD.to_string(); } - if let Ok(content_sha256_value) = content_sha256.parse() { - req.headers_mut().insert("X-Amz-Content-Sha256", content_sha256_value); - } - let req = rustfs_signer::sign_v4(req, 0, &access_key_id, &secret_access_key, &session_token, "us-east-1"); + let content_sha256_value = content_sha256.parse().map_err(|err| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("invalid X-Amz-Content-Sha256 header value: {err}"), + ) + })?; + req.headers_mut().insert("X-Amz-Content-Sha256", content_sha256_value); + let req = rustfs_signer::try_sign_v4(req, 0, &access_key_id, &secret_access_key, &session_token, "us-east-1") + .map_err(|err| signer_error_to_io_error("failed to sign bucket location request", err))?; Ok(req) } } diff --git a/crates/ecstore/src/client/mod.rs b/crates/ecstore/src/client/mod.rs index c3c9e2374..9fb9ed1e1 100644 --- a/crates/ecstore/src/client/mod.rs +++ b/crates/ecstore/src/client/mod.rs @@ -35,5 +35,6 @@ pub mod constants; pub mod credentials; pub mod object_api_utils; pub mod object_handlers_common; +pub mod signer_error; pub mod transition_api; pub mod utils; diff --git a/crates/ecstore/src/client/signer_error.rs b/crates/ecstore/src/client/signer_error.rs new file mode 100644 index 000000000..566b49803 --- /dev/null +++ b/crates/ecstore/src/client/signer_error.rs @@ -0,0 +1,74 @@ +// Copyright 2024 RustFS Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::error::Error as StdError; +use std::fmt::{Display, Formatter}; +use std::io::{Error, ErrorKind}; + +pub(crate) const SIGNER_HEADER_ERROR_MARKER: &str = "rustfs_signer_header_error"; + +#[derive(Debug)] +struct SignerHeaderError { + scope: String, + header_name: String, +} + +impl SignerHeaderError { + fn new(scope: &str, header_name: &str) -> Self { + Self { + scope: scope.to_string(), + header_name: header_name.to_string(), + } + } +} + +impl Display for SignerHeaderError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: invalid UTF-8 header value for `{}` [{}]", + self.scope, self.header_name, SIGNER_HEADER_ERROR_MARKER + ) + } +} + +impl StdError for SignerHeaderError {} + +pub(crate) fn invalid_utf8_header_error(scope: &str, header_name: &str) -> Error { + Error::new(ErrorKind::InvalidInput, SignerHeaderError::new(scope, header_name)) +} + +pub(crate) fn signer_error_to_io_error(scope: &str, error: rustfs_signer::SignV4Error) -> Error { + match error { + rustfs_signer::SignV4Error::InvalidHeaderValue { name } => invalid_utf8_header_error(scope, &name), + other => Error::other(format!("{scope}: {other}")), + } +} + +pub(crate) fn error_chain_contains_signer_header_marker(err: &(dyn StdError + 'static)) -> bool { + let mut current = Some(err); + while let Some(source) = current { + if source.downcast_ref::().is_some() { + return true; + } + + if source.to_string().contains(SIGNER_HEADER_ERROR_MARKER) { + return true; + } + + current = source.source(); + } + + false +} diff --git a/crates/ecstore/src/client/transition_api.rs b/crates/ecstore/src/client/transition_api.rs index 0854fca05..e2a2ac386 100644 --- a/crates/ecstore/src/client/transition_api.rs +++ b/crates/ecstore/src/client/transition_api.rs @@ -31,6 +31,7 @@ use crate::client::{ }, constants::{UNSIGNED_PAYLOAD, UNSIGNED_PAYLOAD_TRAILER}, credentials::{CredContext, Credentials, SignatureType, Static}, + signer_error, }; use crate::{client::checksum::ChecksumMode, store_api::GetObjectReader}; use futures::{Future, StreamExt}; @@ -85,6 +86,21 @@ const C_UNKNOWN: i32 = -1; const C_OFFLINE: i32 = 0; const C_ONLINE: i32 = 1; +fn invalid_utf8_header_error(scope: &str, header_name: &str) -> std::io::Error { + signer_error::invalid_utf8_header_error(scope, header_name) +} + +fn validate_header_values(headers: &HeaderMap, scope: &str) -> Result<(), std::io::Error> { + for (name, value) in headers { + value.to_str().map_err(|_| invalid_utf8_header_error(scope, name.as_str()))?; + } + Ok(()) +} + +fn signer_error_to_io_error(scope: &str, error: rustfs_signer::SignV4Error) -> std::io::Error { + signer_error::signer_error_to_io_error(scope, error) +} + //pub type ReaderImpl = Box; pub enum ReaderImpl { Body(Bytes), @@ -560,8 +576,9 @@ impl TransitionClient { "extra signed headers for presign with signature v2 is not supported.", ))); } - let headers = req.headers_mut(); if let Some(extra_headers) = metadata.extra_pre_sign_header.as_ref() { + validate_header_values(extra_headers, "presign extra header")?; + let headers = req.headers_mut(); for (k, v) in extra_headers { headers.insert(k, v.clone()); } @@ -570,7 +587,7 @@ impl TransitionClient { if signer_type == SignatureType::SignatureV2 { req = rustfs_signer::pre_sign_v2(req, &access_key_id, &secret_access_key, metadata.expires, is_virtual_host); } else if signer_type == SignatureType::SignatureV4 { - req = rustfs_signer::pre_sign_v4( + req = rustfs_signer::try_pre_sign_v4( req, &access_key_id, &secret_access_key, @@ -578,12 +595,14 @@ impl TransitionClient { &location, metadata.expires, OffsetDateTime::now_utc(), - ); + ) + .map_err(|err| signer_error_to_io_error("failed to presign v4 request", err))?; } return Ok(req); } self.set_user_agent(&mut req); + validate_header_values(&metadata.custom_header, "request custom header")?; for (k, v) in metadata.custom_header.clone() { if let Some(key) = k { @@ -593,15 +612,15 @@ impl TransitionClient { //req.content_length = metadata.content_length; if metadata.content_length <= -1 { - if let Ok(chunked_value) = HeaderValue::from_str(&vec!["chunked"].join(",")) { - req.headers_mut().insert(http::header::TRANSFER_ENCODING, chunked_value); - } + req.headers_mut() + .insert(http::header::TRANSFER_ENCODING, HeaderValue::from_static("chunked")); } - if metadata.content_md5_base64.len() > 0 { - if let Ok(md5_value) = HeaderValue::from_str(&metadata.content_md5_base64) { - req.headers_mut().insert("Content-Md5", md5_value); - } + if !metadata.content_md5_base64.is_empty() { + let md5_value = HeaderValue::from_str(&metadata.content_md5_base64).map_err(|err| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, format!("invalid Content-Md5 header value: {err}")) + })?; + req.headers_mut().insert("Content-Md5", md5_value); } if signer_type == SignatureType::SignatureAnonymous { @@ -634,14 +653,15 @@ impl TransitionClient { .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?; req.headers_mut().insert(header_name, header_value); - req = rustfs_signer::sign_v4_trailer( + req = rustfs_signer::try_sign_v4_trailer( req, &access_key_id, &secret_access_key, &session_token, &location, metadata.trailer.clone(), - ); + ) + .map_err(|err| signer_error_to_io_error("failed to sign v4 request", err))?; } if metadata.content_length > 0 { @@ -1354,7 +1374,10 @@ pub struct CreateBucketConfiguration { #[cfg(test)] mod tests { - use super::{build_tls_config, load_root_store_from_tls_path, with_rustls_init_guard}; + use super::{ + build_tls_config, load_root_store_from_tls_path, signer_error_to_io_error, validate_header_values, with_rustls_init_guard, + }; + use http::{HeaderMap, HeaderValue}; #[test] fn rustls_guard_converts_panics_to_io_errors() { @@ -1404,4 +1427,29 @@ mod tests { }); assert!(outcome.is_ok(), "provider install guard must not panic when a provider is already set"); } + + #[test] + fn validate_header_values_returns_header_name_for_non_utf8_values() { + let mut headers = HeaderMap::new(); + headers.insert( + "x-amz-meta-invalid", + HeaderValue::from_bytes(&[0xFF]).expect("invalid utf8 bytes should be accepted by HeaderValue"), + ); + + let err = + validate_header_values(&headers, "request custom header").expect_err("invalid header value should fail validation"); + assert!(err.to_string().contains("x-amz-meta-invalid")); + } + + #[test] + fn signer_error_mapping_preserves_header_name() { + let err = signer_error_to_io_error( + "failed to sign v4 request", + rustfs_signer::SignV4Error::InvalidHeaderValue { + name: "x-amz-meta-invalid".to_string(), + }, + ); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput); + assert!(err.to_string().contains("x-amz-meta-invalid")); + } } diff --git a/crates/signer/Cargo.toml b/crates/signer/Cargo.toml index 49291ebe3..8cc8249b0 100644 --- a/crates/signer/Cargo.toml +++ b/crates/signer/Cargo.toml @@ -35,6 +35,7 @@ serde_urlencoded.workspace = true rustfs-utils = { workspace = true, features = ["full"] } s3s.workspace = true base64-simd.workspace = true +thiserror.workspace = true [lints] workspace = true diff --git a/crates/signer/src/lib.rs b/crates/signer/src/lib.rs index c13f33db9..582294d07 100644 --- a/crates/signer/src/lib.rs +++ b/crates/signer/src/lib.rs @@ -22,6 +22,10 @@ pub mod utils; pub use request_signature_streaming::streaming_sign_v4; pub use request_signature_v2::pre_sign_v2; pub use request_signature_v2::sign_v2; +pub use request_signature_v4::SignV4Error; pub use request_signature_v4::pre_sign_v4; pub use request_signature_v4::sign_v4; pub use request_signature_v4::sign_v4_trailer; +pub use request_signature_v4::try_pre_sign_v4; +pub use request_signature_v4::try_sign_v4; +pub use request_signature_v4::try_sign_v4_trailer; diff --git a/crates/signer/src/request_signature_v4.rs b/crates/signer/src/request_signature_v4.rs index 0536706c7..a4b4741de 100644 --- a/crates/signer/src/request_signature_v4.rs +++ b/crates/signer/src/request_signature_v4.rs @@ -20,11 +20,11 @@ use std::collections::HashMap; use std::fmt::Write; use std::sync::LazyLock; use time::{OffsetDateTime, macros::format_description}; -use tracing::debug; +use tracing::{debug, warn}; use super::constants::UNSIGNED_PAYLOAD; use super::request_signature_streaming_unsigned_trailer::streaming_unsigned_v4; -use super::utils::{get_host_addr, sign_v4_trim_all}; +use super::utils::{HostAddrError, sign_v4_trim_all, try_get_host_addr}; use rustfs_utils::crypto::{hex, hex_sha256, hmac_sha256}; use s3s::Body; @@ -32,6 +32,36 @@ pub const SIGN_V4_ALGORITHM: &str = "AWS4-HMAC-SHA256"; pub const SERVICE_TYPE_S3: &str = "s3"; pub const SERVICE_TYPE_STS: &str = "sts"; +#[derive(Debug, thiserror::Error)] +pub enum SignV4Error { + #[error("invalid UTF-8 header value for `{name}`")] + InvalidHeaderValue { name: String }, + #[error("failed to format signing timestamp: {reason}")] + TimeFormat { reason: String }, + #[error("failed to build signing timestamp: {reason}")] + TimeComponent { reason: String }, + #[error("failed to encode query parameters: {reason}")] + QueryEncode { reason: String }, + #[error("failed to parse uri: {reason}")] + InvalidUri { reason: String }, + #[error("failed to build uri from parts: {reason}")] + InvalidUriParts { reason: String }, + #[error("failed to convert canonical headers to UTF-8: {reason}")] + CanonicalUtf8 { reason: String }, + #[error("failed to parse header value for `{name}`: {reason}")] + HeaderValueParse { name: String, reason: String }, +} + +pub type SignResult = std::result::Result; + +#[derive(Debug)] +struct SignFailure { + request: request::Request, + error: SignV4Error, +} + +type SignOutcome = std::result::Result, Box>; + #[allow(non_upper_case_globals)] // FIXME static v4_ignored_headers: LazyLock> = LazyLock::new(|| { let mut m = >::new(); @@ -41,11 +71,28 @@ static v4_ignored_headers: LazyLock> = LazyLock::new(|| { m }); +fn fail(request: request::Request, error: SignV4Error) -> SignOutcome { + Err(Box::new(SignFailure { request, error })) +} + +fn format_yyyymmdd(t: OffsetDateTime) -> String { + let mut value = String::with_capacity(8); + // Build YYYYMMDD directly from date components to avoid formatter fallbacks. + let _ = write!(value, "{:04}{:02}{:02}", t.year(), u8::from(t.month()), t.day()); + value +} + +fn format_amz_datetime(t: OffsetDateTime) -> SignResult { + let format = format_description!("[year][month][day]T[hour][minute][second]Z"); + t.format(&format) + .map_err(|err| SignV4Error::TimeFormat { reason: err.to_string() }) +} + pub fn get_signing_key(secret: &str, loc: &str, t: OffsetDateTime, service_type: &str) -> [u8; 32] { let mut s = "AWS4".to_string(); s.push_str(secret); - let format = format_description!("[year][month][day]"); - let date = hmac_sha256(s.into_bytes(), t.format(&format).unwrap().into_bytes()); + let date_value = format_yyyymmdd(t); + let date = hmac_sha256(s.into_bytes(), date_value.into_bytes()); let location = hmac_sha256(date, loc); let service = hmac_sha256(location, service_type); @@ -57,9 +104,8 @@ pub fn get_signature(signing_key: [u8; 32], string_to_sign: &str) -> String { } pub fn get_scope(location: &str, t: OffsetDateTime, service_type: &str) -> String { - let format = format_description!("[year][month][day]"); let mut ans = String::from(""); - ans.push_str(&t.format(&format).unwrap()); + ans.push_str(format_yyyymmdd(t).as_str()); ans.push('/'); ans.push_str(location); ans.push('/'); @@ -76,19 +122,21 @@ fn get_credential(access_key_id: &str, location: &str, t: OffsetDateTime, servic s } -fn get_hashed_payload(req: &request::Request) -> String { +fn try_get_hashed_payload(req: &request::Request) -> SignResult { let headers = req.headers(); let mut hashed_payload = ""; if let Some(payload) = headers.get("X-Amz-Content-Sha256") { - hashed_payload = payload.to_str().unwrap(); + hashed_payload = payload.to_str().map_err(|_| SignV4Error::InvalidHeaderValue { + name: "x-amz-content-sha256".to_string(), + })?; } if hashed_payload.is_empty() { hashed_payload = UNSIGNED_PAYLOAD; } - hashed_payload.to_string() + Ok(hashed_payload.to_string()) } -fn get_canonical_headers(req: &request::Request, ignored_headers: &HashMap) -> String { +fn try_get_canonical_headers(req: &request::Request, ignored_headers: &HashMap) -> SignResult { let mut headers = Vec::::new(); let mut vals = HashMap::>::new(); for k in req.headers().keys() { @@ -100,8 +148,14 @@ fn get_canonical_headers(req: &request::Request, ignored_headers: &HashMap .headers() .get_all(k) .iter() - .map(|e| e.to_str().unwrap().to_string()) - .collect(); + .map(|e| { + e.to_str() + .map(|v| v.to_string()) + .map_err(|_| SignV4Error::InvalidHeaderValue { + name: k.as_str().to_lowercase(), + }) + }) + .collect::>>()?; vals.insert(k.as_str().to_lowercase(), vv); } if !header_exists("host", &headers) { @@ -119,11 +173,22 @@ fn get_canonical_headers(req: &request::Request, ignored_headers: &HashMap let k: &str = &k; match k { "host" => { - let _ = buf.write_str(&get_host_addr(req)); + let host_addr = try_get_host_addr(req).map_err(|err| match err { + HostAddrError::InvalidHostHeader => SignV4Error::InvalidHeaderValue { + name: "host".to_string(), + }, + HostAddrError::MissingUriHost => SignV4Error::InvalidUri { + reason: "request uri has no host".to_string(), + }, + })?; + let _ = buf.write_str(&host_addr); let _ = buf.write_char('\n'); } _ => { - for (idx, v) in vals[k].iter().enumerate() { + let Some(values) = vals.get(k) else { + continue; + }; + for (idx, v) in values.iter().enumerate() { if idx > 0 { let _ = buf.write_char(','); } @@ -133,7 +198,7 @@ fn get_canonical_headers(req: &request::Request, ignored_headers: &HashMap } } } - String::from_utf8(buf.to_vec()).unwrap() + String::from_utf8(buf.to_vec()).map_err(|err| SignV4Error::CanonicalUtf8 { reason: err.to_string() }) } fn header_exists(key: &str, headers: &[String]) -> bool { @@ -162,7 +227,11 @@ fn get_signed_headers(req: &request::Request, ignored_headers: &HashMap, ignored_headers: &HashMap, hashed_payload: &str) -> String { +fn try_get_canonical_request( + req: &request::Request, + ignored_headers: &HashMap, + hashed_payload: &str, +) -> SignResult { let mut canonical_query_string = "".to_string(); if let Some(q) = req.uri().query() { // Parse query string into key-value pairs @@ -192,26 +261,30 @@ fn get_canonical_request(req: &request::Request, ignored_headers: &HashMap req.method().to_string(), req.uri().path().to_string(), canonical_query_string, - get_canonical_headers(req, ignored_headers), + try_get_canonical_headers(req, ignored_headers)?, get_signed_headers(req, ignored_headers), hashed_payload.to_string(), ]; - canonical_request.join("\n") + Ok(canonical_request.join("\n")) } -fn get_string_to_sign_v4(t: OffsetDateTime, location: &str, canonical_request: &str, service_type: &str) -> String { +fn try_get_string_to_sign_v4( + t: OffsetDateTime, + location: &str, + canonical_request: &str, + service_type: &str, +) -> SignResult { let mut string_to_sign = SIGN_V4_ALGORITHM.to_string(); string_to_sign.push('\n'); - let format = format_description!("[year][month][day]T[hour][minute][second]Z"); - string_to_sign.push_str(&t.format(&format).unwrap()); + string_to_sign.push_str(format_amz_datetime(t)?.as_str()); string_to_sign.push('\n'); string_to_sign.push_str(&get_scope(location, t, service_type)); string_to_sign.push('\n'); string_to_sign.push_str(&hex_sha256(canonical_request.as_bytes(), |s| s.to_string())); - string_to_sign + Ok(string_to_sign) } -pub fn pre_sign_v4( +fn pre_sign_v4_inner( req: request::Request, access_key_id: &str, secret_access_key: &str, @@ -219,9 +292,9 @@ pub fn pre_sign_v4( location: &str, expires: i64, t: OffsetDateTime, -) -> request::Request { +) -> SignOutcome { if access_key_id.is_empty() || secret_access_key.is_empty() { - return req; + return Ok(req); } let credential = get_credential(access_key_id, location, t, SERVICE_TYPE_S3); @@ -233,8 +306,11 @@ pub fn pre_sign_v4( query = result.unwrap_or_default(); } query.push(("X-Amz-Algorithm".to_string(), SIGN_V4_ALGORITHM.to_string())); - let format = format_description!("[year][month][day]T[hour][minute][second]Z"); - query.push(("X-Amz-Date".to_string(), t.format(&format).unwrap())); + let amz_date = match format_amz_datetime(t) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + query.push(("X-Amz-Date".to_string(), amz_date)); query.push(("X-Amz-Expires".to_string(), format!("{expires:010}"))); query.push(("X-Amz-SignedHeaders".to_string(), signed_headers)); query.push(("X-Amz-Credential".to_string(), credential)); @@ -244,16 +320,38 @@ pub fn pre_sign_v4( let uri = req.uri().clone(); let mut parts = req.uri().clone().into_parts(); - parts.path_and_query = Some( - format!("{}?{}", uri.path(), serde_urlencoded::to_string(&query).unwrap()) - .parse() - .unwrap(), - ); + let query_str = match serde_urlencoded::to_string(&query) { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::QueryEncode { reason: err.to_string() }); + } + }; + parts.path_and_query = Some(match format!("{}?{}", uri.path(), query_str).parse() { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::InvalidUri { reason: err.to_string() }); + } + }); let mut req = req; - *req.uri_mut() = Uri::from_parts(parts).unwrap(); + *req.uri_mut() = match Uri::from_parts(parts) { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::InvalidUriParts { reason: err.to_string() }); + } + }; - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &get_hashed_payload(&req)); - let string_to_sign = get_string_to_sign_v4(t, location, &canonical_request, SERVICE_TYPE_S3); + let hashed_payload = match try_get_hashed_payload(&req) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + let canonical_request = match try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + let string_to_sign = match try_get_string_to_sign_v4(t, location, &canonical_request, SERVICE_TYPE_S3) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; //println!("canonical_request: \n{}\n", canonical_request); //println!("string_to_sign: \n{}\n", string_to_sign); let signing_key = get_signing_key(secret_access_key, location, t, SERVICE_TYPE_S3); @@ -261,20 +359,57 @@ pub fn pre_sign_v4( let uri = req.uri().clone(); let mut parts = req.uri().clone().into_parts(); - parts.path_and_query = Some( - format!( - "{}?{}&X-Amz-Signature={}", - uri.path(), - serde_urlencoded::to_string(&query).unwrap(), - signature - ) - .parse() - .unwrap(), - ); + let query_str = match serde_urlencoded::to_string(&query) { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::QueryEncode { reason: err.to_string() }); + } + }; + parts.path_and_query = Some(match format!("{}?{}&X-Amz-Signature={}", uri.path(), query_str, signature).parse() { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::InvalidUri { reason: err.to_string() }); + } + }); - *req.uri_mut() = Uri::from_parts(parts).unwrap(); + *req.uri_mut() = match Uri::from_parts(parts) { + Ok(value) => value, + Err(err) => { + return fail(req, SignV4Error::InvalidUriParts { reason: err.to_string() }); + } + }; - req + Ok(req) +} + +pub fn try_pre_sign_v4( + req: request::Request, + access_key_id: &str, + secret_access_key: &str, + session_token: &str, + location: &str, + expires: i64, + t: OffsetDateTime, +) -> SignResult> { + pre_sign_v4_inner(req, access_key_id, secret_access_key, session_token, location, expires, t).map_err(|f| f.error) +} + +pub fn pre_sign_v4( + req: request::Request, + access_key_id: &str, + secret_access_key: &str, + session_token: &str, + location: &str, + expires: i64, + t: OffsetDateTime, +) -> request::Request { + match pre_sign_v4_inner(req, access_key_id, secret_access_key, session_token, location, expires, t) { + Ok(request) => request, + Err(failure) => { + warn!(error = %failure.error, "failed to presign v4 request"); + failure.request + } + } } fn _post_pre_sign_signature_v4(policy_base64: &str, t: OffsetDateTime, secret_access_key: &str, location: &str) -> String { @@ -289,7 +424,13 @@ fn _sign_v4_sts( secret_access_key: &str, location: &str, ) -> request::Request { - sign_v4_inner(req, 0, access_key_id, secret_access_key, "", location, SERVICE_TYPE_STS, HeaderMap::new()) + match sign_v4_inner(req, 0, access_key_id, secret_access_key, "", location, SERVICE_TYPE_STS, HeaderMap::new()) { + Ok(request) => request, + Err(failure) => { + warn!(error = %failure.error, "failed to sign v4 sts request"); + failure.request + } + } } #[allow(clippy::too_many_arguments)] @@ -302,38 +443,119 @@ fn sign_v4_inner( location: &str, service_type: &str, trailer: HeaderMap, -) -> request::Request { +) -> SignOutcome { if access_key_id.is_empty() || secret_access_key.is_empty() { - return req; + return Ok(req); } let t = OffsetDateTime::now_utc(); - let t2 = t.replace_time(time::Time::from_hms(0, 0, 0).unwrap()); + let t2 = match time::Time::from_hms(0, 0, 0) { + Ok(midnight) => t.replace_time(midnight), + Err(err) => { + return fail(req, SignV4Error::TimeComponent { reason: err.to_string() }); + } + }; - let headers = req.headers_mut(); - let format = format_description!("[year][month][day]T[hour][minute][second]Z"); - headers.insert("X-Amz-Date", t.format(&format).unwrap().parse().unwrap()); + let amz_date = match format_amz_datetime(t) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + let amz_date_value = match amz_date.parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "X-Amz-Date".to_string(), + reason: err.to_string(), + }, + ); + } + }; + req.headers_mut().insert("X-Amz-Date", amz_date_value); if !session_token.is_empty() { - headers.insert("X-Amz-Security-Token", session_token.parse().unwrap()); + let token_value = match session_token.parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "X-Amz-Security-Token".to_string(), + reason: err.to_string(), + }, + ); + } + }; + req.headers_mut().insert("X-Amz-Security-Token", token_value); } if !trailer.is_empty() { + let mut trailer_values = Vec::new(); for (k, _) in &trailer { - headers.append("X-Amz-Trailer", k.as_str().to_lowercase().parse().unwrap()); + let parsed = match k.as_str().to_lowercase().parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "X-Amz-Trailer".to_string(), + reason: err.to_string(), + }, + ); + } + }; + trailer_values.push(parsed); + } + let content_encoding = match "aws-chunked".parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "Content-Encoding".to_string(), + reason: err.to_string(), + }, + ); + } + }; + let decoded_len = match format!("{content_len:010}").parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "x-amz-decoded-content-length".to_string(), + reason: err.to_string(), + }, + ); + } + }; + let headers = req.headers_mut(); + for value in trailer_values { + headers.append("X-Amz-Trailer", value); } - headers.insert("Content-Encoding", "aws-chunked".parse().unwrap()); - headers.insert("x-amz-decoded-content-length", format!("{content_len:010}").parse().unwrap()); + headers.insert("Content-Encoding", content_encoding); + headers.insert("x-amz-decoded-content-length", decoded_len); } if service_type == SERVICE_TYPE_STS { - headers.remove("X-Amz-Content-Sha256"); + req.headers_mut().remove("X-Amz-Content-Sha256"); } - let hashed_payload = get_hashed_payload(&req); - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &hashed_payload); - let string_to_sign = get_string_to_sign_v4(t, location, &canonical_request, service_type); + let hashed_payload = match try_get_hashed_payload(&req) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + let canonical_request = match try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; + let string_to_sign = match try_get_string_to_sign_v4(t, location, &canonical_request, service_type) { + Ok(value) => value, + Err(err) => return fail(req, err), + }; let signing_key = get_signing_key(secret_access_key, location, t, service_type); let credential = get_credential(access_key_id, location, t2, service_type); let signed_headers = get_signed_headers(&req, &v4_ignored_headers); @@ -343,42 +565,28 @@ fn sign_v4_inner( let headers = req.headers_mut(); let auth = format!("{SIGN_V4_ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={signature}"); - headers.insert("Authorization", auth.parse().unwrap()); + let auth_value = match auth.parse::() { + Ok(value) => value, + Err(err) => { + return fail( + req, + SignV4Error::HeaderValueParse { + name: "Authorization".to_string(), + reason: err.to_string(), + }, + ); + } + }; + headers.insert("Authorization", auth_value); if !trailer.is_empty() { //req.Trailer = trailer; for (_, v) in &trailer { headers.append(http::header::TRAILER, v.clone()); } - return streaming_unsigned_v4(req, session_token, content_len, t); + return Ok(streaming_unsigned_v4(req, session_token, content_len, t)); } - req -} - -fn _unsigned_trailer(mut req: request::Request, content_len: i64, trailer: HeaderMap) { - if !trailer.is_empty() { - return; - } - let t = OffsetDateTime::now_utc(); - let t = t.replace_time(time::Time::from_hms(0, 0, 0).unwrap()); - - let headers = req.headers_mut(); - let format = format_description!("[year][month][day]T[hour][minute][second]Z"); - headers.insert("X-Amz-Date", t.format(&format).unwrap().parse().unwrap()); - - for (k, _) in &trailer { - headers.append("X-Amz-Trailer", k.as_str().to_lowercase().parse().unwrap()); - } - - headers.insert("Content-Encoding", "aws-chunked".parse().unwrap()); - headers.insert("x-amz-decoded-content-length", format!("{content_len:010}").parse().unwrap()); - - if !trailer.is_empty() { - for (_, v) in &trailer { - headers.append(http::header::TRAILER, v.clone()); - } - } - streaming_unsigned_v4(req, "", content_len, t); + Ok(req) } pub fn sign_v4( @@ -389,6 +597,32 @@ pub fn sign_v4( session_token: &str, location: &str, ) -> request::Request { + match sign_v4_inner( + req, + content_len, + access_key_id, + secret_access_key, + session_token, + location, + SERVICE_TYPE_S3, + HeaderMap::new(), + ) { + Ok(request) => request, + Err(failure) => { + warn!(error = %failure.error, "failed to sign v4 request"); + failure.request + } + } +} + +pub fn try_sign_v4( + req: request::Request, + content_len: i64, + access_key_id: &str, + secret_access_key: &str, + session_token: &str, + location: &str, +) -> SignResult> { sign_v4_inner( req, content_len, @@ -399,6 +633,7 @@ pub fn sign_v4( SERVICE_TYPE_S3, HeaderMap::new(), ) + .map_err(|failure| failure.error) } pub fn sign_v4_trailer( @@ -409,6 +644,32 @@ pub fn sign_v4_trailer( location: &str, trailer: HeaderMap, ) -> request::Request { + match sign_v4_inner( + req, + 0, + access_key_id, + secret_access_key, + session_token, + location, + SERVICE_TYPE_S3, + trailer, + ) { + Ok(request) => request, + Err(failure) => { + warn!(error = %failure.error, "failed to sign v4 trailer request"); + failure.request + } + } +} + +pub fn try_sign_v4_trailer( + req: request::Request, + access_key_id: &str, + secret_access_key: &str, + session_token: &str, + location: &str, + trailer: HeaderMap, +) -> SignResult> { sign_v4_inner( req, 0, @@ -419,11 +680,13 @@ pub fn sign_v4_trailer( SERVICE_TYPE_S3, trailer, ) + .map_err(|failure| failure.error) } #[cfg(test)] #[allow(unused_variables, unused_mut)] mod tests { + use http::HeaderValue; use http::request; use time::macros::datetime; @@ -468,7 +731,9 @@ mod tests { ); *req.uri_mut() = Uri::from_parts(parts).unwrap(); - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &get_hashed_payload(&req)); + let hashed_payload = try_get_hashed_payload(&req).expect("example request should have valid payload header"); + let canonical_request = + try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload).expect("example request should canonicalize"); assert_eq!( canonical_request, concat!( @@ -486,7 +751,8 @@ mod tests { ) ); - let string_to_sign = get_string_to_sign_v4(t, region, &canonical_request, service); + let string_to_sign = try_get_string_to_sign_v4(t, region, &canonical_request, service) + .expect("example request should build string-to-sign"); assert_eq!( string_to_sign, concat!( @@ -542,7 +808,9 @@ mod tests { //println!("parts.path_and_query: {:?}", parts.path_and_query); *req.uri_mut() = Uri::from_parts(parts).unwrap(); - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &get_hashed_payload(&req)); + let hashed_payload = try_get_hashed_payload(&req).expect("example request should have valid payload header"); + let canonical_request = + try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload).expect("example request should canonicalize"); println!("canonical_request: \n{canonical_request}\n"); assert_eq!( canonical_request, @@ -561,7 +829,8 @@ mod tests { ) ); - let string_to_sign = get_string_to_sign_v4(t, region, &canonical_request, service); + let string_to_sign = try_get_string_to_sign_v4(t, region, &canonical_request, service) + .expect("example request should build string-to-sign"); println!("string_to_sign: \n{string_to_sign}\n"); assert_eq!( string_to_sign, @@ -607,7 +876,9 @@ mod tests { headers.insert("x-amz-date", timestamp.parse().unwrap()); println!("{:?}", req.uri().query()); - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &get_hashed_payload(&req)); + let hashed_payload = try_get_hashed_payload(&req).expect("example request should have valid payload header"); + let canonical_request = + try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload).expect("example request should canonicalize"); println!("canonical_request: \n{canonical_request}\n"); assert_eq!( canonical_request, @@ -626,7 +897,8 @@ mod tests { ) ); - let string_to_sign = get_string_to_sign_v4(t, region, &canonical_request, service); + let string_to_sign = try_get_string_to_sign_v4(t, region, &canonical_request, service) + .expect("example request should build string-to-sign"); println!("string_to_sign: \n{string_to_sign}\n"); assert_eq!( string_to_sign, @@ -672,7 +944,9 @@ mod tests { headers.insert("x-amz-date", timestamp.parse().unwrap()); println!("{:?}", req.uri().query()); - let canonical_request = get_canonical_request(&req, &v4_ignored_headers, &get_hashed_payload(&req)); + let hashed_payload = try_get_hashed_payload(&req).expect("example request should have valid payload header"); + let canonical_request = + try_get_canonical_request(&req, &v4_ignored_headers, &hashed_payload).expect("example request should canonicalize"); println!("canonical_request: \n{canonical_request}\n"); assert_eq!( canonical_request, @@ -691,7 +965,8 @@ mod tests { ) ); - let string_to_sign = get_string_to_sign_v4(t, region, &canonical_request, service); + let string_to_sign = try_get_string_to_sign_v4(t, region, &canonical_request, service) + .expect("example request should build string-to-sign"); println!("string_to_sign: \n{string_to_sign}\n"); assert_eq!( string_to_sign, @@ -739,11 +1014,19 @@ mod tests { canonical_request.push('\n'); canonical_request.push_str(req.uri().query().unwrap()); canonical_request.push('\n'); - canonical_request.push_str(&get_canonical_headers(&req, &v4_ignored_headers)); + canonical_request.push_str( + try_get_canonical_headers(&req, &v4_ignored_headers) + .expect("presigned request should canonicalize headers") + .as_str(), + ); canonical_request.push('\n'); canonical_request.push_str(&get_signed_headers(&req, &v4_ignored_headers)); canonical_request.push('\n'); - canonical_request.push_str(&get_hashed_payload(&req)); + canonical_request.push_str( + try_get_hashed_payload(&req) + .expect("presigned request should include payload hash") + .as_str(), + ); //println!("canonical_request: \n{}\n", canonical_request); assert_eq!( canonical_request, @@ -787,11 +1070,19 @@ mod tests { canonical_request.push('\n'); canonical_request.push_str(req.uri().query().unwrap()); canonical_request.push('\n'); - canonical_request.push_str(&get_canonical_headers(&req, &v4_ignored_headers)); + canonical_request.push_str( + try_get_canonical_headers(&req, &v4_ignored_headers) + .expect("presigned request should canonicalize headers") + .as_str(), + ); canonical_request.push('\n'); canonical_request.push_str(&get_signed_headers(&req, &v4_ignored_headers)); canonical_request.push('\n'); - canonical_request.push_str(&get_hashed_payload(&req)); + canonical_request.push_str( + try_get_hashed_payload(&req) + .expect("presigned request should include payload hash") + .as_str(), + ); //println!("canonical_request: \n{}\n", canonical_request); assert_eq!( canonical_request, @@ -806,4 +1097,87 @@ mod tests { ) ); } + + fn build_request_with_invalid_header_value(uri: &str) -> request::Request { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri(uri) + .body(Body::empty()) + .unwrap(); + let headers = req.headers_mut(); + headers.insert("host", HeaderValue::from_static("examplebucket.s3.amazonaws.com")); + headers.insert("x-amz-content-sha256", HeaderValue::from_static(UNSIGNED_PAYLOAD)); + headers.insert("x-amz-meta-invalid", HeaderValue::from_bytes(&[0xFF]).unwrap()); + req + } + + #[test] + fn try_sign_v4_returns_error_for_non_utf8_header_value() { + let req = build_request_with_invalid_header_value("http://examplebucket.s3.amazonaws.com/object"); + let err = try_sign_v4(req, 0, "rustfsadmin", "rustfsadmin", "", "us-east-1").unwrap_err(); + assert!(matches!( + err, + SignV4Error::InvalidHeaderValue { name } if name == "x-amz-meta-invalid" + )); + } + + #[test] + fn try_sign_v4_returns_invalid_uri_error_when_uri_has_no_host() { + let mut req = request::Request::builder() + .method(http::Method::GET) + .uri("/object") + .body(Body::empty()) + .unwrap(); + let headers = req.headers_mut(); + headers.insert("host", HeaderValue::from_static("examplebucket.s3.amazonaws.com")); + headers.insert("x-amz-content-sha256", HeaderValue::from_static(UNSIGNED_PAYLOAD)); + + let err = try_sign_v4(req, 0, "rustfsadmin", "rustfsadmin", "", "us-east-1").unwrap_err(); + assert!(matches!( + err, + SignV4Error::InvalidUri { reason } if reason.contains("no host") + )); + } + + #[test] + fn legacy_sign_apis_do_not_panic_on_non_utf8_header_value() { + let signed = sign_v4( + build_request_with_invalid_header_value("http://examplebucket.s3.amazonaws.com/object"), + 0, + "rustfsadmin", + "rustfsadmin", + "", + "us-east-1", + ); + assert!(signed.headers().get(http::header::AUTHORIZATION).is_none()); + + let presigned = pre_sign_v4( + build_request_with_invalid_header_value("http://examplebucket.s3.amazonaws.com/object"), + "rustfsadmin", + "rustfsadmin", + "", + "us-east-1", + 60, + datetime!(2026-04-27 00:00:00 UTC), + ); + let query = presigned.uri().query().unwrap_or_default(); + assert!(!query.contains("X-Amz-Signature=")); + } + + #[test] + fn sign_v4_sts_returns_original_request_on_non_utf8_header_value() { + let signed = _sign_v4_sts( + build_request_with_invalid_header_value("http://examplebucket.s3.amazonaws.com/object"), + "rustfsadmin", + "rustfsadmin", + "us-east-1", + ); + assert!(signed.headers().get(http::header::AUTHORIZATION).is_none()); + } + + #[test] + fn format_yyyymmdd_is_zero_padded() { + let t = datetime!(0001-01-02 03:04:05 UTC); + assert_eq!(format_yyyymmdd(t), "00010102"); + } } diff --git a/crates/signer/src/utils.rs b/crates/signer/src/utils.rs index 8f31f793e..900fd003e 100644 --- a/crates/signer/src/utils.rs +++ b/crates/signer/src/utils.rs @@ -16,24 +16,37 @@ use http::request; use s3s::Body; -pub fn get_host_addr(req: &request::Request) -> String { +#[derive(Debug, thiserror::Error)] +pub enum HostAddrError { + #[error("invalid UTF-8 header value for `host`")] + InvalidHostHeader, + #[error("request uri has no host")] + MissingUriHost, +} + +pub fn try_get_host_addr(req: &request::Request) -> Result { let host = req.headers().get("host"); let uri = req.uri(); - let req_host; - if let Some(port) = uri.port() { - req_host = format!("{}:{}", uri.host().unwrap(), port); + let uri_host = uri.host().ok_or(HostAddrError::MissingUriHost)?; + + let req_host = if let Some(port) = uri.port() { + format!("{uri_host}:{port}") } else { - req_host = uri.host().unwrap().to_string(); + uri_host.to_string() + }; + + if let Some(host) = host { + let host = host.to_str().map_err(|_| HostAddrError::InvalidHostHeader)?; + if req_host != host { + return Ok(host.to_string()); + } } - if let Some(host) = host - && req_host != *host.to_str().unwrap() - { - return (*host.to_str().unwrap()).to_string(); - } - /*if req.uri_ref().unwrap().host().is_some() { - return req.uri_ref().unwrap().host().unwrap(); - }*/ - req_host + + Ok(req_host) +} + +pub fn get_host_addr(req: &request::Request) -> String { + try_get_host_addr(req).unwrap() } pub fn sign_v4_trim_all(input: &str) -> String {