Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced error handling #35

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ serde_derive = "1.0"
serde_json = "1.0"
structopt = "0.3"
tokio = "0.2.20"
lazy_static = "1.4.0"
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend using the once_cell crate, because it is going to become the standard way to handle lazy initialization of static variables. By "standard" I mean literally standard: this crate's API is already exposed as part of std, though it is still unstable: rust-lang/rust#74465


[dev-dependencies]
pretty_assertions = "0.6.1"
Expand Down
210 changes: 118 additions & 92 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::hashed_regex::HashedRegex;
use crate::{error_handling::ErrorHandling, hashed_regex::HashedRegex};
use anyhow::Error;
use http::header::{HeaderName, HeaderValue};
use log::Level;
use regex::{Captures, Regex, Replacer};
use reqwest::Client;
use serde_derive::{Deserialize, Serialize};
use std::{
Expand Down Expand Up @@ -31,13 +32,13 @@ pub struct Config {
/// The number of seconds a cached result is valid for.
#[serde(default = "default_cache_timeout")]
pub cache_timeout: u64,
/// The policy to use when warnings are encountered.
#[serde(default)]
pub warning_policy: WarningPolicy,
/// The map of regexes representing sets of web sites and
/// the list of HTTP headers that must be sent to matching sites.
#[serde(default)]
pub http_headers: HashMap<HashedRegex, Vec<HttpHeader>>,
/// How should non-valid links be handled?
#[serde(default)]
pub error_handling: ErrorHandling,
}

#[derive(Serialize, Deserialize, Debug, PartialEq, Clone)]
Expand All @@ -49,7 +50,7 @@ pub struct HttpHeader {

impl HttpHeader {
pub(crate) fn interpolate(&self) -> Result<HeaderValue, Error> {
interpolate_env(&self.value)
interpolate_env(&self.value, |var| std::env::var(var).ok())
}
}

Expand Down Expand Up @@ -82,10 +83,9 @@ impl Config {

pub(crate) fn interpolate_headers(
&self,
warning_policy: WarningPolicy,
error_handling: &ErrorHandling,
) -> Vec<(HashedRegex, Vec<(HeaderName, HeaderValue)>)> {
let mut all_headers = Vec::new();
let log_level = warning_policy.to_log_level();

for (pattern, headers) in &self.http_headers {
let mut interpolated = Vec::new();
Expand All @@ -102,12 +102,8 @@ impl Config {
//
// If it was important, the user would notice a "broken"
// link and read back through the logs.
log::log!(
log_level,
"Unable to interpolate \"{}\" because {}",
header,
e
);
error_handling
.on_header_interpolation_error(header, &e);
},
}
}
Expand All @@ -127,8 +123,8 @@ impl Default for Config {
exclude: Vec::new(),
user_agent: default_user_agent(),
http_headers: HashMap::new(),
warning_policy: WarningPolicy::Warn,
cache_timeout: Config::DEFAULT_CACHE_TIMEOUT.as_secs(),
error_handling: ErrorHandling::default(),
}
}
}
Expand Down Expand Up @@ -179,98 +175,93 @@ impl Into<String> for HttpHeader {
/// The default cache timeout in whole seconds, taken from
/// `Config::DEFAULT_CACHE_TIMEOUT`. This is the function named by the
/// `#[serde(default = "default_cache_timeout")]` attribute on
/// `Config::cache_timeout`.
fn default_cache_timeout() -> u64 { Config::DEFAULT_CACHE_TIMEOUT.as_secs() }
/// The default user agent, as an owned copy of `Config::DEFAULT_USER_AGENT`.
fn default_user_agent() -> String { Config::DEFAULT_USER_AGENT.to_string() }

fn interpolate_env(value: &str) -> Result<HeaderValue, Error> {
use std::{iter::Peekable, str::CharIndices};

fn is_ident(ch: char) -> bool { ch.is_ascii_alphanumeric() || ch == '_' }

fn ident_end(start: usize, iter: &mut Peekable<CharIndices>) -> usize {
let mut end = start;
while let Some(&(i, ch)) = iter.peek() {
if !is_ident(ch) {
return i;
}
end = i + ch.len_utf8();
iter.next();
}
// The pattern used to find `$variable` references inside a header value.
//
// It is written in verbose mode (`(?x)`), so the line breaks inside the
// pattern are ignored. It has two named groups:
//
// - `escape`: an optional backslash directly in front of the dollar sign
// - `variable`: the name following the dollar sign
//
// NOTE(review): `\w` already includes `_` and digits, so the leading
// character class also accepts a digit - confirm whether names like `$1abc`
// are meant to be treated as variables.
//
// The `unwrap()` can only fail if the pattern itself is invalid.
lazy_static::lazy_static! {
static ref INTERPOLATED_VARIABLE: Regex = Regex::new(r"(?x)
(?P<escape>\\)?
\$
(?P<variable>[\w_][\w_\d]*)
").unwrap();
}

end
fn interpolate_env<F>(value: &str, get_var: F) -> Result<HeaderValue, Error>
where
F: FnMut(&str) -> Option<String>,
{
let mut failed_replacements: Vec<String> = Vec::new();

let interpolated = INTERPOLATED_VARIABLE
.replace_all(value, replacer(&mut failed_replacements, get_var));

if failed_replacements.is_empty() {
interpolated.parse().map_err(Error::from)
} else {
Err(Error::from(InterpolationError {
variable_names: failed_replacements,
original_string: value.to_string(),
}))
}
}

let mut res = String::with_capacity(value.len());
let mut backslash = false;
let mut iter = value.char_indices().peekable();

while let Some((i, ch)) = iter.next() {
if backslash {
match ch {
'$' | '\\' => res.push(ch),
_ => {
res.push('\\');
res.push(ch);
},
}

backslash = false;
} else {
match ch {
'\\' => backslash = true,
'$' => {
iter.next();
let start = i + 1;
let end = ident_end(start, &mut iter);
let name = &value[start..end];

match std::env::var(name) {
Ok(env) => res.push_str(&env),
Err(e) => {
return Err(Error::msg(format!(
"Failed to retrieve `{}` env var: {}",
name, e
)))
},
}
/// Gets a `Replacer` which will try to replace a variable with the result
/// from the `get_var()` function, recording any errors that happen.
fn replacer<'a, V>(
failed_replacements: &'a mut Vec<String>,
mut get_var: V,
) -> impl Replacer + 'a
where
V: FnMut(&str) -> Option<String> + 'a,
{
move |caps: &Captures<'_>| {
if caps.name("escape").is_none() {
let variable = &caps["variable"];

match get_var(variable) {
Some(value) => return value,
None => {
failed_replacements.push(variable.to_string());
},

_ => res.push(ch),
}
}
}

// trailing backslash
if backslash {
res.push('\\');
// the dollar sign was escaped (e.g. "\$foo") or we couldn't get
// the environment variable
caps[0].to_string()
}

Ok(res.parse()?)
}

/// How should warnings be treated?
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum WarningPolicy {
/// Silently ignore them.
Ignore,
/// Warn the user, but don't fail the linkcheck.
Warn,
/// Treat warnings as errors.
Error,
/// The error returned by `interpolate_env()` when at least one variable
/// could not be resolved.
#[derive(Debug, Clone)]
struct InterpolationError {
/// The name of every variable whose lookup failed.
pub variable_names: Vec<String>,
/// The string the variables were being interpolated into.
pub original_string: String,
}

impl WarningPolicy {
pub(crate) fn to_log_level(self) -> Level {
match self {
WarningPolicy::Error => Level::Error,
WarningPolicy::Warn => Level::Warn,
WarningPolicy::Ignore => Level::Debug,
// No methods are overridden; the message comes from the `Display` impl.
impl std::error::Error for InterpolationError {}

impl Display for InterpolationError {
    /// Write a message naming every variable that could not be interpolated
    /// and the string they were meant to be interpolated into.
    ///
    /// Each name is rendered once as `` `$name` `` and multiple names are
    /// separated by `", "`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Wrap each name individually. The list as a whole must not be
        // wrapped a second time, otherwise the message starts with a stray
        // "`$" and ends with a stray "`".
        let formatted_names: Vec<_> = self
            .variable_names
            .iter()
            .map(|v| format!("`${}`", v))
            .collect();

        write!(
            f,
            "Unable to interpolate {} into \"{}\"",
            formatted_names.join(", "),
            self.original_string
        )
    }
}

impl Default for WarningPolicy {
fn default() -> WarningPolicy { WarningPolicy::Warn }
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -282,7 +273,6 @@ traverse-parent-directories = true
exclude = ["google\\.com"]
user-agent = "Internet Explorer"
cache-timeout = 3600
warning-policy = "error"

[http-headers]
https = ["accept: html/text", "authorization: Basic $TOKEN"]
Expand All @@ -294,7 +284,6 @@ https = ["accept: html/text", "authorization: Basic $TOKEN"]

let should_be = Config {
follow_web_links: true,
warning_policy: WarningPolicy::Error,
traverse_parent_directories: true,
exclude: vec![HashedRegex::new(r"google\.com").unwrap()],
user_agent: String::from("Internet Explorer"),
Expand All @@ -306,6 +295,7 @@ https = ["accept: html/text", "authorization: Basic $TOKEN"]
],
)]),
cache_timeout: 3600,
error_handling: ErrorHandling::default(),
};

let got: Config = toml::from_str(CONFIG).unwrap();
Expand Down Expand Up @@ -338,4 +328,40 @@ https = ["accept: html/text", "authorization: Basic $TOKEN"]

assert_eq!(got, should_be);
}

/// A known variable is replaced with the value the lookup returns.
#[test]
fn interplate_a_single_variable() {
    // Only "name" is known; every other variable fails to resolve.
    let lookup = |var: &str| match var {
        "name" => Some(String::from("World!")),
        _ => None,
    };

    let got = interpolate_env("Hello, $name", lookup).unwrap();

    assert_eq!(got, "Hello, World!");
}

/// An escaped dollar sign is left alone and the lookup is never called.
#[test]
fn you_can_skip_interpolation_by_escaping_the_dollar_sign() {
    let escaped = r"Hello, \$name";

    // The lookup panics if it is ever asked for a variable.
    let header = interpolate_env(escaped, |_| unreachable!()).unwrap();

    assert_eq!(header, escaped);
}

/// A variable the lookup cannot resolve is reported by name in the error.
#[test]
fn not_having_the_requested_variable_is_an_error() {
    let err = interpolate_env(r"Hello, $name", |_: &str| None).unwrap_err();

    let InterpolationError { variable_names, .. } =
        err.downcast::<InterpolationError>().unwrap();
    assert_eq!(variable_names, vec![String::from("name")]);
}
}
Loading