diff --git a/justfile b/justfile index 24c6cd395..48776bce8 100644 --- a/justfile +++ b/justfile @@ -197,6 +197,10 @@ lint: fmt clippy fmt: cargo fmt --all -- --check +# Run Nightly cargo fmt, ordering imports +fmt2: + cargo +nightly fmt -- --config imports_granularity=Module,group_imports=StdExternalCrate + # Run cargo clippy clippy: cargo clippy --workspace --all-targets --all-features --bins --tests --lib --benches -- -D warnings diff --git a/martin-mbtiles/src/mbtiles_queries.rs b/martin-mbtiles/src/mbtiles_queries.rs index d76b56343..ea662dcde 100644 --- a/martin-mbtiles/src/mbtiles_queries.rs +++ b/martin-mbtiles/src/mbtiles_queries.rs @@ -1,6 +1,7 @@ -use crate::errors::MbtResult; use sqlx::{query, SqliteExecutor}; +use crate::errors::MbtResult; + pub async fn is_deduplicated_type(conn: &mut T) -> MbtResult where for<'e> &'e mut T: SqliteExecutor<'e>, diff --git a/martin/src/config.rs b/martin/src/config.rs index 314814156..33168a33f 100644 --- a/martin/src/config.rs +++ b/martin/src/config.rs @@ -14,11 +14,15 @@ use crate::file_config::{resolve_files, FileConfigEnum}; use crate::mbtiles::MbtSource; use crate::pg::PgConfig; use crate::pmtiles::PmtSource; -use crate::source::{IdResolver, Sources}; +use crate::source::Sources; use crate::srv::SrvConfig; -use crate::utils::{OneOrMany, Result}; +use crate::utils::{IdResolver, OneOrMany, Result}; use crate::Error::{ConfigLoadError, ConfigParseError, NoSources}; +pub struct AllSources { + pub sources: Sources, +} + #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct Config { #[serde(flatten)] @@ -73,7 +77,7 @@ impl Config { } } - pub async fn resolve(&mut self, idr: IdResolver) -> Result { + pub async fn resolve(&mut self, idr: IdResolver) -> Result { let create_pmt_src = &mut PmtSource::new_box; let create_mbt_src = &mut MbtSource::new_box; @@ -93,13 +97,15 @@ impl Config { sources.push(Box::pin(val)); } - Ok(try_join_all(sources) - .await? - .into_iter() - .fold(HashMap::new(), |mut acc, hashmap| { - acc.extend(hashmap); - acc - })) + Ok(AllSources { + sources: try_join_all(sources).await?.into_iter().fold( + Sources::default(), + |mut acc, hashmap| { + acc.extend(hashmap); + acc + }, + ), + }) } } diff --git a/martin/src/file_config.rs b/martin/src/file_config.rs index 4d546d1d2..5411eee38 100644 --- a/martin/src/file_config.rs +++ b/martin/src/file_config.rs @@ -10,9 +10,9 @@ use serde_yaml::Value; use crate::config::{copy_unrecognized_config, Unrecognized}; use crate::file_config::FileError::{InvalidFilePath, InvalidSourceFilePath, IoError}; -use crate::utils::sorted_opt_map; +use crate::source::{Source, Sources, Xyz}; +use crate::utils::{sorted_opt_map, Error, IdResolver, OneOrMany}; use crate::OneOrMany::{Many, One}; -use crate::{Error, IdResolver, OneOrMany, Source, Sources, Xyz}; #[derive(thiserror::Error, Debug)] pub enum FileError { @@ -40,6 +40,22 @@ pub enum FileConfigEnum { Config(FileConfig), } +impl FileConfigEnum { + pub fn extract_file_config(&mut self) -> FileConfig { + match self { + FileConfigEnum::Path(path) => FileConfig { + paths: Some(One(mem::take(path))), + ..FileConfig::default() + }, + FileConfigEnum::Paths(paths) => FileConfig { + paths: Some(Many(mem::take(paths))), + ..Default::default() + }, + FileConfigEnum::Config(cfg) => mem::take(cfg), + } + } +} + #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] pub struct FileConfig { /// A list of file paths @@ -69,12 +85,12 @@ pub enum FileConfigSrc { } impl FileConfigSrc { - #[must_use] - pub fn path(&self) -> &PathBuf { - match self { + pub fn abs_path(&self) -> Result { + let path = match self { Self::Path(p) => p, Self::Obj(o) => &o.path, - } + }; + path.canonicalize().map_err(|e| IoError(e, path.clone())) } } @@ -125,27 +141,16 @@ async fn resolve_int( where Fut: Future, FileError>>, { - let cfg = match config { - FileConfigEnum::Path(path) => FileConfig { - paths: Some(One(mem::take(path))), - ..FileConfig::default() - }, - FileConfigEnum::Paths(paths) => FileConfig { - paths: Some(Many(mem::take(paths))), - ..Default::default() - }, - FileConfigEnum::Config(cfg) => mem::take(cfg), - }; + let cfg = config.extract_file_config(); - let mut results = Sources::new(); + let mut results = Sources::default(); let mut configs = HashMap::new(); let mut files = HashSet::new(); let mut directories = Vec::new(); if let Some(sources) = cfg.sources { for (id, source) in sources { - let path = source.path(); - let can = path.canonicalize().map_err(|e| IoError(e, path.clone()))?; + let can = source.abs_path()?; if !can.is_file() { // todo: maybe warn instead? return Err(InvalidSourceFilePath(id.to_string(), can)); @@ -173,7 +178,7 @@ where directories.push(path.clone()); path.read_dir() .map_err(|e| IoError(e, path.clone()))? - .filter_map(std::result::Result::ok) + .filter_map(Result::ok) .filter(|f| { f.path().extension().filter(|e| *e == extension).is_some() && f.path().is_file() diff --git a/martin/src/lib.rs b/martin/src/lib.rs index 6284e7160..a5b1cf648 100644 --- a/martin/src/lib.rs +++ b/martin/src/lib.rs @@ -26,8 +26,10 @@ mod test_utils; #[cfg(test)] pub use crate::args::Env; pub use crate::config::{read_config, Config}; -pub use crate::source::{IdResolver, Source, Sources, Xyz}; -pub use crate::utils::{decode_brotli, decode_gzip, BoolOrObject, Error, OneOrMany, Result}; +pub use crate::source::{Source, Sources, Xyz}; +pub use crate::utils::{ + decode_brotli, decode_gzip, BoolOrObject, Error, IdResolver, OneOrMany, Result, +}; // Ensure README.md contains valid code #[cfg(doctest)] diff --git a/martin/src/pg/config.rs b/martin/src/pg/config.rs index dc4472c2c..a59e5a8eb 100644 --- a/martin/src/pg/config.rs +++ b/martin/src/pg/config.rs @@ -7,8 +7,8 @@ use crate::pg::config_function::FuncInfoSources; use crate::pg::config_table::TableInfoSources; use crate::pg::configurator::PgBuilder; use crate::pg::utils::Result; -use crate::source::{IdResolver, Sources}; -use crate::utils::{sorted_opt_map, BoolOrObject, OneOrMany}; +use crate::source::Sources; +use crate::utils::{sorted_opt_map, BoolOrObject, IdResolver, OneOrMany}; pub trait PgInfo { fn format_id(&self) -> String; diff --git a/martin/src/pg/configurator.rs b/martin/src/pg/configurator.rs index 3f16ee9d5..4b27388a0 100755 --- a/martin/src/pg/configurator.rs +++ b/martin/src/pg/configurator.rs @@ -1,5 +1,5 @@ use std::cmp::Ordering; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use futures::future::join_all; use itertools::Itertools; @@ -14,8 +14,8 @@ use crate::pg::pool::PgPool; use crate::pg::table_source::{calc_srid, get_table_sources, merge_table_info, table_to_query}; use crate::pg::utils::PgError::InvalidTableExtent; use crate::pg::utils::Result; -use crate::source::{IdResolver, Sources}; -use crate::utils::{find_info, normalize_key, BoolOrObject, InfoMap, OneOrMany}; +use crate::source::Sources; +use crate::utils::{find_info, normalize_key, BoolOrObject, IdResolver, InfoMap, OneOrMany}; pub type SqlFuncInfoMapMap = InfoMap>; pub type SqlTableInfoMapMapMap = InfoMap>>; @@ -139,7 +139,7 @@ impl PgBuilder { } } - let mut res: Sources = HashMap::new(); + let mut res = Sources::default(); let mut info_map = TableInfoSources::new(); let pending = join_all(pending).await; for src in pending { @@ -161,7 +161,7 @@ impl PgBuilder { pub async fn instantiate_functions(&self) -> Result<(Sources, FuncInfoSources)> { let mut all_funcs = get_function_sources(&self.pool).await?; - let mut res: Sources = HashMap::new(); + let mut res = Sources::default(); let mut info_map = FuncInfoSources::new(); let mut used = HashSet::<(&str, &str)>::new(); diff --git a/martin/src/source.rs b/martin/src/source.rs index e813f2c6b..3e1154303 100644 --- a/martin/src/source.rs +++ b/martin/src/source.rs @@ -1,10 +1,13 @@ -use std::collections::hash_map::Entry; -use std::collections::{HashMap, HashSet}; -use std::fmt::{Debug, Display, Formatter, Write}; -use std::sync::{Arc, Mutex}; +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt::{Debug, Display, Formatter}; +use actix_web::error::ErrorNotFound; use async_trait::async_trait; +use itertools::Itertools; +use log::debug; use martin_tile_utils::TileInfo; +use serde::{Deserialize, Serialize}; use tilejson::TileJSON; use crate::utils::Result; @@ -28,7 +31,91 @@ impl Display for Xyz { pub type Tile = Vec; pub type UrlQuery = HashMap; -pub type Sources = HashMap>; + +#[derive(Default, Clone)] +pub struct Sources(HashMap>); + +impl Sources { + pub fn insert(&mut self, id: String, source: Box) { + self.0.insert(id, source); + } + + pub fn extend(&mut self, other: Sources) { + self.0.extend(other.0); + } + + #[must_use] + pub fn get_catalog(&self) -> Vec { + self.0 + .iter() + .map(|(id, src)| { + let tilejson = src.get_tilejson(); + let info = src.get_tile_info(); + IndexEntry { + id: id.clone(), + content_type: info.format.content_type().to_string(), + content_encoding: info.encoding.content_encoding().map(ToString::to_string), + name: tilejson.name.filter(|v| v != id), + description: tilejson.description, + attribution: tilejson.attribution, + } + }) + .sorted() + .collect() + } + + pub fn get_source(&self, id: &str) -> actix_web::Result<&dyn Source> { + Ok(self + .0 + .get(id) + .ok_or_else(|| ErrorNotFound(format!("Source {id} does not exist")))? + .as_ref()) + } + + pub fn get_sources( + &self, + source_ids: &str, + zoom: Option, + ) -> actix_web::Result<(Vec<&dyn Source>, bool, TileInfo)> { + let mut sources = Vec::new(); + let mut info: Option = None; + let mut use_url_query = false; + for id in source_ids.split(',') { + let src = self.get_source(id)?; + let src_inf = src.get_tile_info(); + use_url_query |= src.support_url_query(); + + // make sure all sources have the same format + match info { + Some(inf) if inf == src_inf => {} + Some(inf) => Err(ErrorNotFound(format!( + "Cannot merge sources with {inf} with {src_inf}" + )))?, + None => info = Some(src_inf), + } + + // TODO: Use chained-if-let once available + if match zoom { + Some(zoom) if Self::check_zoom(src, id, zoom) => true, + None => true, + _ => false, + } { + sources.push(src); + } + } + + // format is guaranteed to be Some() here + Ok((sources, use_url_query, info.unwrap())) + } + + pub fn check_zoom(src: &dyn Source, id: &str, zoom: u8) -> bool { + let is_valid = src.is_valid_zoom(zoom); + if !is_valid { + debug!("Zoom {zoom} is not valid for source {id}"); + } + is_valid + } +} #[async_trait] pub trait Source: Send + Debug { @@ -51,71 +138,29 @@ impl Clone for Box { } } -#[derive(Debug, Default, Clone)] -pub struct IdResolver { - /// name -> unique name - names: Arc>>, - /// reserved names - reserved: HashSet<&'static str>, +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +pub struct IndexEntry { + pub id: String, + pub content_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content_encoding: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub attribution: Option, } -impl IdResolver { - #[must_use] - pub fn new(reserved_keywords: &[&'static str]) -> Self { - Self { - names: Arc::new(Mutex::new(HashMap::new())), - reserved: reserved_keywords.iter().copied().collect(), - } +impl PartialOrd for IndexEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) } +} - /// If source name already exists in the self.names structure, - /// try appending it with ".1", ".2", etc. until the name is unique. - /// Only alphanumeric characters plus dashes/dots/underscores are allowed. - #[must_use] - pub fn resolve(&self, name: &str, unique_name: String) -> String { - // Ensure name has no prohibited characters like spaces, commas, slashes, or non-unicode etc. - // Underscores, dashes, and dots are OK. All other characters will be replaced with dashes. - let mut name = name.replace( - |c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '.' && c != '-', - "-", - ); - - let mut names = self.names.lock().expect("IdResolver panicked"); - if !self.reserved.contains(name.as_str()) { - match names.entry(name) { - Entry::Vacant(e) => { - let id = e.key().clone(); - e.insert(unique_name); - return id; - } - Entry::Occupied(e) => { - name = e.key().clone(); - if e.get() == &unique_name { - return name; - } - } - } - } - // name already exists, try it with ".1", ".2", etc. until the value matches - // assume that reserved keywords never end in a "dot number", so don't check - let mut index: i32 = 1; - let mut new_name = String::new(); - loop { - new_name.clear(); - write!(&mut new_name, "{name}.{index}").unwrap(); - index = index.checked_add(1).unwrap(); - match names.entry(new_name.clone()) { - Entry::Vacant(e) => { - e.insert(unique_name); - return new_name; - } - Entry::Occupied(e) => { - if e.get() == &unique_name { - return new_name; - } - } - } - } +impl Ord for IndexEntry { + fn cmp(&self, other: &Self) -> Ordering { + (&self.id, &self.name).cmp(&(&other.id, &other.name)) } } @@ -123,22 +168,6 @@ impl IdResolver { mod tests { use super::*; - #[test] - fn id_resolve() { - let r = IdResolver::default(); - assert_eq!(r.resolve("a", "a".to_string()), "a"); - assert_eq!(r.resolve("a", "a".to_string()), "a"); - assert_eq!(r.resolve("a", "b".to_string()), "a.1"); - assert_eq!(r.resolve("a", "b".to_string()), "a.1"); - assert_eq!(r.resolve("b", "a".to_string()), "b"); - assert_eq!(r.resolve("b", "a".to_string()), "b"); - assert_eq!(r.resolve("a.1", "a".to_string()), "a.1.1"); - assert_eq!(r.resolve("a.1", "b".to_string()), "a.1"); - - assert_eq!(r.resolve("a b", "a b".to_string()), "a-b"); - assert_eq!(r.resolve("a b", "ab2".to_string()), "a-b.1"); - } - #[test] fn xyz_format() { let xyz = Xyz { z: 1, x: 2, y: 3 }; diff --git a/martin/src/srv/mod.rs b/martin/src/srv/mod.rs index 087019a89..6496a9230 100644 --- a/martin/src/srv/mod.rs +++ b/martin/src/srv/mod.rs @@ -2,4 +2,4 @@ mod config; mod server; pub use config::{SrvConfig, KEEP_ALIVE_DEFAULT, LISTEN_ADDRESSES_DEFAULT}; -pub use server::{new_server, router, AppState, IndexEntry, RESERVED_KEYWORDS}; +pub use server::{new_server, router, RESERVED_KEYWORDS}; diff --git a/martin/src/srv/server.rs b/martin/src/srv/server.rs index e992e7ccc..0219ea3ea 100755 --- a/martin/src/srv/server.rs +++ b/martin/src/srv/server.rs @@ -1,4 +1,3 @@ -use std::cmp::Ordering; use std::string::ToString; use std::time::Duration; @@ -17,12 +16,12 @@ use actix_web::{ Responder, Result, }; use futures::future::try_join_all; -use itertools::Itertools; -use log::{debug, error}; +use log::error; use martin_tile_utils::{Encoding, Format, TileInfo}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tilejson::{tilejson, TileJSON}; +use crate::config::AllSources; use crate::source::{Source, Sources, UrlQuery, Xyz}; use crate::srv::config::{SrvConfig, KEEP_ALIVE_DEFAULT, LISTEN_ADDRESSES_DEFAULT}; use crate::utils::{decode_brotli, decode_gzip, encode_brotli, encode_gzip}; @@ -41,56 +40,6 @@ static SUPPORTED_ENCODINGS: &[HeaderEnc] = &[ HeaderEnc::identity(), ]; -pub struct AppState { - pub sources: Sources, -} - -impl AppState { - fn get_source(&self, id: &str) -> Result<&dyn Source> { - Ok(self - .sources - .get(id) - .ok_or_else(|| error::ErrorNotFound(format!("Source {id} does not exist")))? - .as_ref()) - } - - fn get_sources( - &self, - source_ids: &str, - zoom: Option, - ) -> Result<(Vec<&dyn Source>, bool, TileInfo)> { - let mut sources = Vec::new(); - let mut info: Option = None; - let mut use_url_query = false; - for id in source_ids.split(',') { - let src = self.get_source(id)?; - let src_inf = src.get_tile_info(); - use_url_query |= src.support_url_query(); - - // make sure all sources have the same format - match info { - Some(inf) if inf == src_inf => {} - Some(inf) => Err(error::ErrorNotFound(format!( - "Cannot merge sources with {inf} with {src_inf}" - )))?, - None => info = Some(src_inf), - } - - // TODO: Use chained-if-let once available - if match zoom { - Some(zoom) if check_zoom(src, id, zoom) => true, - None => true, - _ => false, - } { - sources.push(src); - } - } - - // format is guaranteed to be Some() here - Ok((sources, use_url_query, info.unwrap())) - } -} - #[derive(Deserialize)] struct TileJsonRequest { source_ids: String, @@ -104,32 +53,6 @@ struct TileRequest { y: u32, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct IndexEntry { - pub id: String, - pub content_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub content_encoding: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub attribution: Option, -} - -impl PartialOrd for IndexEntry { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for IndexEntry { - fn cmp(&self, other: &Self) -> Ordering { - (&self.id, &self.name).cmp(&(&other.id, &other.name)) - } -} - fn map_internal_error(e: T) -> Error { error!("{e}"); error::ErrorInternalServerError(e.to_string()) @@ -160,25 +83,8 @@ async fn get_health() -> impl Responder { wrap = "middleware::Compress::default()" )] #[allow(clippy::unused_async)] -async fn get_catalog(state: Data) -> impl Responder { - let info: Vec<_> = state - .sources - .iter() - .map(|(id, src)| { - let tilejson = src.get_tilejson(); - let info = src.get_tile_info(); - IndexEntry { - id: id.clone(), - content_type: info.format.content_type().to_string(), - content_encoding: info.encoding.content_encoding().map(ToString::to_string), - name: tilejson.name.filter(|v| v != id), - description: tilejson.description, - attribution: tilejson.attribution, - } - }) - .sorted() - .collect(); - HttpResponse::Ok().json(info) +async fn get_catalog(sources: Data) -> impl Responder { + HttpResponse::Ok().json(sources.get_catalog()) } #[route( @@ -191,9 +97,9 @@ async fn get_catalog(state: Data) -> impl Responder { async fn git_source_info( req: HttpRequest, path: Path, - state: Data, + sources: Data, ) -> Result { - let sources = state.get_sources(&path.source_ids, None)?.0; + let sources = sources.get_sources(&path.source_ids, None)?.0; let tiles_path = req .headers() @@ -319,7 +225,7 @@ fn merge_tilejson(sources: Vec<&dyn Source>, tiles_url: String) -> TileJSON { async fn get_tile( req: HttpRequest, path: Path, - state: Data, + sources: Data, ) -> Result { let xyz = Xyz { z: path.z, @@ -329,7 +235,7 @@ async fn get_tile( // Optimization for a single-source request. let (tile, info) = if path.source_ids.contains(',') { - let (sources, use_url_query, info) = state.get_sources(&path.source_ids, Some(path.z))?; + let (sources, use_url_query, info) = sources.get_sources(&path.source_ids, Some(path.z))?; if sources.is_empty() { return Err(error::ErrorNotFound("No valid sources found")); } @@ -356,8 +262,8 @@ async fn get_tile( } else { let id = &path.source_ids; let zoom = xyz.z; - let src = state.get_source(id)?; - if !check_zoom(src, id, zoom) { + let src = sources.get_source(id)?; + if !Sources::check_zoom(src, id, zoom) { return Err(error::ErrorNotFound(format!( "Zoom {zoom} is not valid for source {id}", ))); @@ -462,7 +368,7 @@ pub fn router(cfg: &mut web::ServiceConfig) { } /// Create a new initialized Actix `App` instance together with the listening address. -pub fn new_server(config: SrvConfig, sources: Sources) -> crate::Result<(Server, String)> { +pub fn new_server(config: SrvConfig, all_sources: AllSources) -> crate::Result<(Server, String)> { let keep_alive = Duration::from_secs(config.keep_alive.unwrap_or(KEEP_ALIVE_DEFAULT)); let worker_processes = config.worker_processes.unwrap_or_else(num_cpus::get); let listen_addresses = config @@ -470,16 +376,12 @@ pub fn new_server(config: SrvConfig, sources: Sources) -> crate::Result<(Server, .unwrap_or_else(|| LISTEN_ADDRESSES_DEFAULT.to_owned()); let server = HttpServer::new(move || { - let state = AppState { - sources: sources.clone(), - }; - let cors_middleware = Cors::default() .allow_any_origin() .allowed_methods(vec!["GET"]); App::new() - .app_data(Data::new(state)) + .app_data(Data::new(all_sources.sources.clone())) .wrap(cors_middleware) .wrap(middleware::NormalizePath::new(TrailingSlash::MergeOnly)) .wrap(middleware::Logger::default()) @@ -495,14 +397,6 @@ pub fn new_server(config: SrvConfig, sources: Sources) -> crate::Result<(Server, Ok((server, listen_addresses)) } -fn check_zoom(src: &dyn Source, id: &str, zoom: u8) -> bool { - let is_valid = src.is_valid_zoom(zoom); - if !is_valid { - debug!("Zoom {zoom} is not valid for source {id}"); - } - is_valid -} - fn parse_x_rewrite_url(header: &HeaderValue) -> Option { header .to_str() diff --git a/martin/src/utils/id_resolver.rs b/martin/src/utils/id_resolver.rs new file mode 100644 index 000000000..ff1fdd2a4 --- /dev/null +++ b/martin/src/utils/id_resolver.rs @@ -0,0 +1,93 @@ +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::fmt::Write; +use std::sync::{Arc, Mutex}; + +#[derive(Debug, Default, Clone)] +pub struct IdResolver { + /// name -> unique name + names: Arc>>, + /// reserved names + reserved: HashSet<&'static str>, +} + +impl IdResolver { + #[must_use] + pub fn new(reserved_keywords: &[&'static str]) -> Self { + Self { + names: Arc::new(Mutex::new(HashMap::new())), + reserved: reserved_keywords.iter().copied().collect(), + } + } + + /// If source name already exists in the self.names structure, + /// try appending it with ".1", ".2", etc. until the name is unique. + /// Only alphanumeric characters plus dashes/dots/underscores are allowed. + #[must_use] + pub fn resolve(&self, name: &str, unique_name: String) -> String { + // Ensure name has no prohibited characters like spaces, commas, slashes, or non-unicode etc. + // Underscores, dashes, and dots are OK. All other characters will be replaced with dashes. + let mut name = name.replace( + |c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '.' && c != '-', + "-", + ); + + let mut names = self.names.lock().expect("IdResolver panicked"); + if !self.reserved.contains(name.as_str()) { + match names.entry(name) { + Entry::Vacant(e) => { + let id = e.key().clone(); + e.insert(unique_name); + return id; + } + Entry::Occupied(e) => { + name = e.key().clone(); + if e.get() == &unique_name { + return name; + } + } + } + } + // name already exists, try it with ".1", ".2", etc. until the value matches + // assume that reserved keywords never end in a "dot number", so don't check + let mut index: i32 = 1; + let mut new_name = String::new(); + loop { + new_name.clear(); + write!(&mut new_name, "{name}.{index}").unwrap(); + index = index.checked_add(1).unwrap(); + match names.entry(new_name.clone()) { + Entry::Vacant(e) => { + e.insert(unique_name); + return new_name; + } + Entry::Occupied(e) => { + if e.get() == &unique_name { + return new_name; + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn id_resolve() { + let r = IdResolver::default(); + assert_eq!(r.resolve("a", "a".to_string()), "a"); + assert_eq!(r.resolve("a", "a".to_string()), "a"); + assert_eq!(r.resolve("a", "b".to_string()), "a.1"); + assert_eq!(r.resolve("a", "b".to_string()), "a.1"); + assert_eq!(r.resolve("b", "a".to_string()), "b"); + assert_eq!(r.resolve("b", "a".to_string()), "b"); + assert_eq!(r.resolve("a.1", "a".to_string()), "a.1.1"); + assert_eq!(r.resolve("a.1", "b".to_string()), "a.1"); + + assert_eq!(r.resolve("a b", "a b".to_string()), "a-b"); + assert_eq!(r.resolve("a b", "ab2".to_string()), "a-b.1"); + } +} diff --git a/martin/src/utils/mod.rs b/martin/src/utils/mod.rs index 15ebc4811..43aff165d 100644 --- a/martin/src/utils/mod.rs +++ b/martin/src/utils/mod.rs @@ -1,7 +1,9 @@ mod error; +mod id_resolver; mod one_or_many; mod utilities; pub use error::*; +pub use id_resolver::IdResolver; pub use one_or_many::OneOrMany; pub use utilities::*; diff --git a/rustfmt.toml b/rustfmt.toml index 186c1c91d..a0a6602a5 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,4 +1,4 @@ ## These should be enabled in the future, but for now its a manual step to simplify usage. -## Use cargo nightly for these: cargo +nightly fmt +## Use Justfile fmt2 target instead: just fmt2 #imports_granularity = "Module" #group_imports = "StdExternalCrate" diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index e5474f8c7..d15588ea7 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -15,8 +15,8 @@ mod test_utils; #[allow(clippy::wildcard_imports)] pub use test_utils::*; -pub async fn mock_app_data(sources: Sources) -> Data { - Data::new(AppState { sources }) +pub async fn mock_app_data(sources: Sources) -> Data { + Data::new(sources) } #[must_use]