From 301b913982f40eabb1c0c0c5adbdf6d2fd9ff89d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20HERGAULT?= Date: Thu, 19 Dec 2024 09:18:07 +0100 Subject: [PATCH] feat: Add listener for server sockets (#15) * feat: Add listener for server sockets * fix: Correct doc tests and ALPN select --------- Signed-off-by: Jeremy HERGAULT --- prosa/Cargo.toml | 4 + prosa/src/inj/proc.rs | 15 +- prosa/src/io.rs | 1013 ++++++++++++++------------------- prosa/src/io/listener.rs | 503 ++++++++++++++++ prosa/src/io/stream.rs | 742 ++++++++++++++++++++++++ prosa_utils/src/config/ssl.rs | 210 +++++-- 6 files changed, 1860 insertions(+), 627 deletions(-) create mode 100644 prosa/src/io/listener.rs create mode 100644 prosa/src/io/stream.rs diff --git a/prosa/Cargo.toml b/prosa/Cargo.toml index 5f0bd92..f64065f 100644 --- a/prosa/Cargo.toml +++ b/prosa/Cargo.toml @@ -35,6 +35,7 @@ tracing = "0.1" tracing-subscriber = {version = "0.3", features = ["std", "env-filter"]} thiserror.workspace = true url = { version = "2", features = ["serde"] } +rlimit = "0.10" aquamarine.workspace = true @@ -58,3 +59,6 @@ opentelemetry-stdout.workspace = true opentelemetry-otlp.workspace = true opentelemetry-appender-log.workspace = true memory-stats = "1" + +[dev-dependencies] +futures-util = { version = "0.3", default-features = false } diff --git a/prosa/src/inj/proc.rs b/prosa/src/inj/proc.rs index 068915c..760c328 100644 --- a/prosa/src/inj/proc.rs +++ b/prosa/src/inj/proc.rs @@ -20,7 +20,7 @@ extern crate self as prosa; /// Inj settings for service and speed parameters #[proc_settings] -#[derive(Default, Debug, Deserialize, Serialize, Clone)] +#[derive(Debug, Deserialize, Serialize, Clone)] pub struct InjSettings { /// Service to inject to service_name: String, @@ -83,6 +83,19 @@ impl InjSettings { } } +#[proc_settings] +impl Default for InjSettings { + fn default() -> InjSettings { + InjSettings { + service_name: Default::default(), + max_speed: InjSettings::default_max_speed(), + timeout_threshold: InjSettings::default_timeout_threshold(), + max_concurrents_send: InjSettings::default_max_concurrents_send(), + speed_interval: InjSettings::default_speed_interval(), + } + } +} + /// Inj processor to inject transactions /// /// ``` diff --git a/prosa/src/io.rs b/prosa/src/io.rs index e9681aa..c981595 100644 --- a/prosa/src/io.rs +++ b/prosa/src/io.rs @@ -1,646 +1,505 @@ //! Module that define IO that could be use by a ProSA processor use std::{ - fmt, io, - os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd}, - pin::Pin, - task::{Context, Poll}, + fmt, + net::{SocketAddrV4, SocketAddrV6}, + path::Path, }; -use openssl::ssl::{self, SslContext}; pub use prosa_macros::io; -use prosa_utils::config::ssl::SslConfig; -use serde::{Deserialize, Serialize}; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - net::{TcpStream, ToSocketAddrs}, -}; -use tokio_openssl::SslStream; use url::Url; -/// ProSA socket object to handle TCP/SSL socket with or without proxy -#[derive(Debug)] -pub enum Stream { - /// TCP socket - Tcp(TcpStream), - /// SSL socket - Ssl(SslStream), - /// TCP socket using Http proxy - TcpHttpProxy(TcpStream), - /// SSL socket using Http proxy - SslHttpProxy(SslStream), -} - -impl Stream { - #[cfg_attr(doc, aquamarine::aquamarine)] - /// Connect a TCP socket to a distant - /// - /// ```mermaid - /// graph LR - /// client[Client] - /// server[Server] - /// - /// client -- TCP --> server - /// ``` - /// - /// ``` - /// use tokio::io; - /// use url::Url; - /// use prosa::io::Stream; - /// - /// async fn connecting() -> Result<(), io::Error> { - /// let stream: Stream = Stream::connect_tcp("worldline.com:80").await?; - /// - /// // Handle the stream like any tokio stream - /// - /// Ok(()) - /// } - /// ``` - pub async fn connect_tcp(addr: A) -> Result - where - A: ToSocketAddrs, - { - Ok(Stream::Tcp(TcpStream::connect(addr).await?)) - } - - /// Method to create an SSL stream from a TCP stream - async fn create_ssl( - tcp_stream: TcpStream, - ssl_context: &ssl::SslContext, - ) -> Result, io::Error> { - let ssl = ssl::Ssl::new(ssl_context).unwrap(); - let mut stream = SslStream::new(ssl, tcp_stream).unwrap(); - if let Err(e) = Pin::new(&mut stream).connect().await { - if e.code() != ssl::ErrorCode::ZERO_RETURN { - return Err(io::Error::new( - io::ErrorKind::Interrupted, - format!("Can't connect the SSL socket `{}`", e), - )); - } - } - - Ok(stream) - } - - #[cfg_attr(doc, aquamarine::aquamarine)] - /// Connect an SSL socket to a distant - /// - /// ```mermaid - /// graph LR - /// client[Client] - /// server[Server] - /// - /// client -- TCP+TLS --> server - /// ``` - /// - /// ``` - /// use tokio::io; - /// use url::Url; - /// use prosa_utils::config::ssl::SslConfig; - /// use prosa::io::Stream; - /// - /// async fn connecting() -> Result<(), io::Error> { - /// let ssl_config = SslConfig::default(); - /// if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { - /// let ssl_context = ssl_context_builder.build(); - /// let stream: Stream = Stream::connect_ssl("worldline.com:443", &ssl_context).await?; - /// - /// // Handle the stream like any tokio stream - /// } - /// - /// Ok(()) - /// } - /// ``` - pub async fn connect_ssl(addr: A, ssl_context: &ssl::SslContext) -> Result - where - A: ToSocketAddrs, - { - Ok(Stream::Ssl( - Self::create_ssl(TcpStream::connect(addr).await?, ssl_context).await?, - )) - } +pub mod listener; +pub mod stream; - /// Method to connect a TCP stream through an HTTP proxy - async fn connect_http_proxy( - host: &str, - port: u16, - proxy: &Url, - ) -> Result { - let proxy_addrs = proxy.socket_addrs(|| proxy.port_or_known_default())?; - let mut tcp_stream = TcpStream::connect(&*proxy_addrs).await?; - if let (username, Some(password)) = (proxy.username(), proxy.password()) { - if let Err(e) = async_http_proxy::http_connect_tokio_with_basic_auth( - &mut tcp_stream, - host, - port, - username, - password, - ) - .await - { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("Can't connect to the http proxy with basic_auth `{}`", e), - )); - } - } else if let Err(e) = - async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port).await - { - return Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - format!("Can't connect to the http proxy `{}`", e), - )); - } +/// Trait to define ProSA IO. +/// Implement with the procedural macro io +pub trait IO { + /// Frame error trigger when the frame operation can't be executed + type Error; - Ok(tcp_stream) - } + /// Method call to parse a frame + fn parse_frame(&mut self) -> std::result::Result, Self::Error>; - #[cfg_attr(doc, aquamarine::aquamarine)] - /// Connect a TCP socket to a distant through an HTTP proxy - /// - /// ```mermaid - /// graph LR - /// client[Client] - /// server[Server] - /// proxy[Proxy] - /// - /// client -- TCP --> proxy - /// proxy --> server - /// ``` - /// - /// ``` - /// use tokio::io; - /// use url::Url; - /// use prosa::io::Stream; - /// - /// async fn connecting() -> Result<(), io::Error> { - /// let proxy_url = Url::parse("http://user:pwd@proxy:3128").unwrap(); - /// let stream: Stream = Stream::connect_tcp_with_http_proxy("worldline.com", 443, &proxy_url).await?; - /// - /// // Handle the stream like any tokio stream - /// - /// Ok(()) - /// } - /// ``` - pub async fn connect_tcp_with_http_proxy( - host: &str, - port: u16, - proxy: &Url, - ) -> Result { - Ok(Stream::TcpHttpProxy( - Self::connect_http_proxy(host, port, proxy).await?, - )) - } + /// Method to wait a complete frame + fn read_frame( + &mut self, + ) -> impl std::future::Future, Self::Error>> + Send; + /// Method to write a frame and wait for completion + fn write_frame( + &mut self, + frame: F, + ) -> impl std::future::Future> + Send; +} - #[cfg_attr(doc, aquamarine::aquamarine)] - /// Connect an SSL socket to a distant through an HTTP proxy - /// - /// ```mermaid - /// graph LR - /// client[Client] - /// server[Server] - /// proxy[Proxy] - /// - /// client -- TCP+TLS --> proxy - /// proxy --> server - /// ``` - /// - /// ``` - /// use tokio::io; - /// use url::Url; - /// use prosa_utils::config::ssl::SslConfig; - /// use prosa::io::Stream; - /// - /// async fn connecting() -> Result<(), io::Error> { - /// let proxy_url = Url::parse("http://user:pwd@proxy:3128").unwrap(); - /// let ssl_config = SslConfig::default(); - /// if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { - /// let ssl_context = ssl_context_builder.build(); - /// let stream: Stream = Stream::connect_ssl_with_http_proxy("worldline.com", 443, &ssl_context, &proxy_url).await?; - /// - /// // Handle the stream like any tokio stream - /// } - /// - /// Ok(()) - /// } - /// ``` - pub async fn connect_ssl_with_http_proxy( - host: &str, - port: u16, - ssl_context: &ssl::SslContext, - proxy: &Url, - ) -> Result { - Ok(Stream::SslHttpProxy( - Self::create_ssl( - Self::connect_http_proxy(host, port, proxy).await?, - ssl_context, - ) - .await?, - )) +/// Method to known if the url indicate an SSL protocol +/// +/// ``` +/// use url::Url; +/// use prosa::io::url_is_ssl; +/// +/// assert!(!url_is_ssl(&Url::parse("http://localhost").unwrap())); +/// assert!(url_is_ssl(&Url::parse("https://localhost").unwrap())); +/// ``` +pub fn url_is_ssl(url: &Url) -> bool { + let scheme = url.scheme(); + if scheme.ends_with("+ssl") || scheme.ends_with("+tls") { + true + } else { + matches!(url.scheme(), "ssl" | "tls" | "https" | "wss") } +} - #[cfg_attr(doc, aquamarine::aquamarine)] - /// Accept an SSL socket from a TcpListener - /// - /// ```mermaid - /// graph RL - /// clients[Clients] - /// server[Server] - /// - /// clients --> server - /// ``` - /// - /// ``` - /// use tokio::io; - /// use tokio::net::TcpListener; - /// use prosa_utils::config::ssl::SslConfig; - /// use prosa::io::Stream; - /// - /// async fn listenning() -> Result<(), io::Error> { - /// let ssl_context = SslConfig::default().init_tls_server_context().unwrap().build(); - /// let listener = TcpListener::bind("0.0.0.0:4443").await?; - /// - /// loop { - /// let (stream, cli_addr) = listener.accept().await?; - /// let stream = Stream::accept_ssl(stream, &ssl_context).await?; - /// - /// // Use stream ... - /// } - /// - /// Ok(()) - /// } - /// ``` - pub async fn accept_ssl(stream: TcpStream, context: &SslContext) -> Result { - let ssl = ssl::Ssl::new(context)?; - let mut ssl_stream = SslStream::new(ssl, stream).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("Can't create SslStream: {}", e), - ) - })?; - if let Err(e) = Pin::new(&mut ssl_stream).accept().await { - if e.code() != ssl::ErrorCode::ZERO_RETURN { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("Can't accept the client: {}", e), - )); - } - } - - Ok(Stream::Ssl(ssl_stream)) - } +/// Internal Socket adress enum to define IPv4, IPv6 and unix socket. +#[derive(Debug)] +pub enum SocketAddr { + #[cfg(target_family = "unix")] + /// UNIX socket address + Unix(tokio::net::unix::SocketAddr), + /// IPv4 address + V4(SocketAddrV4), + /// IPv6 address + V6(SocketAddrV6), +} - /// Sets the value of the TCP_NODELAY option on the ProSA socket - pub fn set_nodelay(&self, nodelay: bool) -> Result<(), io::Error> { +impl SocketAddr { + /// Returns true if this is a loopback address (IPv4: 127.0.0.0/8, IPv6: ::1). + /// These properties are defined by [IETF RFC 1122](https://tools.ietf.org/html/rfc1122), and [IETF RFC 4291 section 2.5.3](https://tools.ietf.org/html/rfc4291#section-2.5.3). + pub fn is_loopback(&self) -> bool { match self { - Stream::Tcp(s) => s.set_nodelay(nodelay), - Stream::Ssl(s) => s.get_ref().set_nodelay(nodelay), - Stream::TcpHttpProxy(s) => s.set_nodelay(nodelay), - Stream::SslHttpProxy(s) => s.get_ref().set_nodelay(nodelay), + #[cfg(target_family = "unix")] + SocketAddr::Unix(_) => true, + SocketAddr::V4(ipv4) => ipv4.ip().is_loopback(), + SocketAddr::V6(ipv6) => ipv6.ip().is_loopback(), } } - /// Gets the value of the TCP_NODELAY option for the ProSA socket - pub fn nodelay(&self) -> Result { + /// Returns the port number associated with this socket address. + pub const fn port(&self) -> u16 { match self { - Stream::Tcp(s) => s.nodelay(), - Stream::Ssl(s) => s.get_ref().nodelay(), - Stream::TcpHttpProxy(s) => s.nodelay(), - Stream::SslHttpProxy(s) => s.get_ref().nodelay(), + #[cfg(target_family = "unix")] + SocketAddr::Unix(_) => 0u16, + SocketAddr::V4(ipv4) => ipv4.port(), + SocketAddr::V6(ipv6) => ipv6.port(), } } - /// Sets the value for the IP_TTL option on the ProSA socket - pub fn set_ttl(&self, ttl: u32) -> Result<(), io::Error> { + /// Changes the port number associated with this socket address. + pub fn set_port(&mut self, port: u16) { match self { - Stream::Tcp(s) => s.set_ttl(ttl), - Stream::Ssl(s) => s.get_ref().set_ttl(ttl), - Stream::TcpHttpProxy(s) => s.set_ttl(ttl), - Stream::SslHttpProxy(s) => s.get_ref().set_ttl(ttl), + #[cfg(target_family = "unix")] + SocketAddr::Unix(_) => {} + SocketAddr::V4(ipv4) => ipv4.set_port(port), + SocketAddr::V6(ipv6) => ipv6.set_port(port), } } +} - /// Gets the value of the IP_TTL option for the ProSA socket - pub fn ttl(&self) -> Result { - match self { - Stream::Tcp(s) => s.ttl(), - Stream::Ssl(s) => s.get_ref().ttl(), - Stream::TcpHttpProxy(s) => s.ttl(), - Stream::SslHttpProxy(s) => s.get_ref().ttl(), +impl PartialEq for SocketAddr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + #[cfg(target_family = "unix")] + (SocketAddr::Unix(s), SocketAddr::Unix(o)) => s.as_pathname() == o.as_pathname(), + (SocketAddr::V4(s), SocketAddr::V4(o)) => s == o, + (SocketAddr::V6(s), SocketAddr::V6(o)) => s == o, + _ => false, } } } -impl AsFd for Stream { - fn as_fd(&self) -> BorrowedFd<'_> { +impl fmt::Display for SocketAddr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Stream::Tcp(s) => s.as_fd(), - Stream::Ssl(s) => s.get_ref().as_fd(), - Stream::TcpHttpProxy(s) => s.as_fd(), - Stream::SslHttpProxy(s) => s.get_ref().as_fd(), + #[cfg(target_family = "unix")] + SocketAddr::Unix(path) => write!( + f, + "{}", + path.as_pathname() + .unwrap_or(Path::new("undefined")) + .display() + ), + SocketAddr::V4(ipv4) => write!(f, "{}", ipv4), + SocketAddr::V6(ipv6) => write!(f, "{}", ipv6), } } } -impl AsRawFd for Stream { - fn as_raw_fd(&self) -> RawFd { - match self { - Stream::Tcp(s) => s.as_raw_fd(), - Stream::Ssl(s) => s.get_ref().as_raw_fd(), - Stream::TcpHttpProxy(s) => s.as_raw_fd(), - Stream::SslHttpProxy(s) => s.get_ref().as_raw_fd(), +impl From for SocketAddr { + fn from(addr: std::net::SocketAddr) -> Self { + match addr { + std::net::SocketAddr::V4(ipv4) => SocketAddr::V4(ipv4), + std::net::SocketAddr::V6(ipv6) => SocketAddr::V6(ipv6), } } } -impl AsyncRead for Stream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - match self.get_mut() { - Stream::Tcp(s) => { - let stream = Pin::new(s); - stream.poll_read(cx, buf) - } - Stream::Ssl(s) => { - let stream = Pin::new(s); - stream.poll_read(cx, buf) - } - Stream::TcpHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_read(cx, buf) - } - Stream::SslHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_read(cx, buf) - } - } +#[cfg(target_family = "unix")] +impl From for SocketAddr { + fn from(addr: tokio::net::unix::SocketAddr) -> Self { + SocketAddr::Unix(addr) } } -impl AsyncWrite for Stream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - match self.get_mut() { - Stream::Tcp(s) => { - let stream = Pin::new(s); - stream.poll_write(cx, buf) - } - Stream::Ssl(s) => { - let stream = Pin::new(s); - stream.poll_write(cx, buf) - } - Stream::TcpHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_write(cx, buf) - } - Stream::SslHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_write(cx, buf) - } - } - } +#[cfg(test)] +mod tests { + use futures_util::future; + use listener::{ListenerSetting, StreamListener}; + use openssl::ssl::SslVerifyMode; + use prosa_utils::config::ssl::{SslConfig, Store}; + use std::{env, os::fd::AsRawFd as _}; + use stream::{Stream, TargetSetting}; + use tokio::{ + fs::File, + io::{AsyncReadExt as _, AsyncWriteExt}, + }; + + use super::*; + + #[cfg(target_family = "unix")] + #[tokio::test] + async fn unix_client_server() { + let addr = "/tmp/prosa_unix_client_server_test.sock"; + let listener = StreamListener::Unix(tokio::net::UnixListener::bind(addr).unwrap()); + assert!(listener.as_raw_fd() > 0); + assert!( + format!("{:?}", listener).contains("UnixListener"), + "listener `{:?}` don't contain UnixListener", + listener + ); + assert!( + format!("{:?}", listener).contains(addr), + "listener `{:?}` don't contain {}", + listener, + addr + ); + assert_eq!( + "unix:///tmp/prosa_unix_client_server_test.sock", + &listener.to_string() + ); + + let server = async move { + let (mut client_stream, client_addr) = listener.accept().await.unwrap(); + assert!(client_addr.is_loopback()); + + let mut buf = [0; 5]; + client_stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ProSA"); + + client_stream.write_all(b"Worldline").await.unwrap(); + }; - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - match self.get_mut() { - Stream::Tcp(s) => { - let stream = Pin::new(s); - stream.poll_write_vectored(cx, bufs) - } - Stream::Ssl(s) => { - let stream = Pin::new(s); - stream.poll_write_vectored(cx, bufs) - } - Stream::TcpHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_write_vectored(cx, bufs) - } - Stream::SslHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_write_vectored(cx, bufs) - } - } - } + let client = async { + let mut stream = Stream::connect_unix(addr).await.unwrap(); + assert!(stream.as_raw_fd() > 0); + assert!( + format!("{:?}", stream).contains("UnixStream"), + "stream `{:?}` don't contain UnixStream", + stream + ); + assert!( + format!("{:?}", stream).contains(addr), + "stream `{:?}` don't contain {}", + stream, + addr + ); + + stream.write_all(b"ProSA").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"Worldline"); + + let _ = stream.shutdown().await; + }; - fn is_write_vectored(&self) -> bool { - match self { - Stream::Tcp(s) => s.is_write_vectored(), - Stream::Ssl(s) => s.is_write_vectored(), - Stream::TcpHttpProxy(s) => s.is_write_vectored(), - Stream::SslHttpProxy(s) => s.is_write_vectored(), - } + future::join(server, client).await; + std::fs::remove_file(addr).unwrap(); } - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Stream::Tcp(s) => { - let stream = Pin::new(s); - stream.poll_flush(cx) - } - Stream::Ssl(s) => { - let stream = Pin::new(s); - stream.poll_flush(cx) - } - Stream::TcpHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_flush(cx) - } - Stream::SslHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_flush(cx) - } - } - } + #[tokio::test] + async fn tcp_client_server() { + let addr = "localhost:41800"; + let listener = StreamListener::bind(addr).await.unwrap(); + assert!(listener.as_raw_fd() > 0); + assert!( + format!("{:?}", listener).contains("Tcp"), + "listener `{:?}` don't contain Tcp", + listener + ); + assert!( + format!("{:?}", listener).contains("TcpListener"), + "listener `{:?}` don't contain TcpListener", + listener + ); + assert!(listener.to_string().starts_with("tcp://")); + + let server = async move { + let (mut client_stream, client_addr) = listener.accept().await.unwrap(); + assert!(client_addr.is_loopback()); + + let mut buf = [0; 5]; + client_stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ProSA"); + + // Should do nothing + client_stream = listener.handshake(client_stream).await.unwrap(); + + client_stream.write_all(b"Worldline").await.unwrap(); + }; - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - Stream::Tcp(s) => { - let stream = Pin::new(s); - stream.poll_shutdown(cx) - } - Stream::Ssl(s) => { - let stream = Pin::new(s); - stream.poll_shutdown(cx) - } - Stream::TcpHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_shutdown(cx) - } - Stream::SslHttpProxy(s) => { - let stream = Pin::new(s); - stream.poll_shutdown(cx) - } - } - } -} + let client = async { + let mut stream = Stream::connect_tcp(addr).await.unwrap(); + assert!(stream.as_raw_fd() > 0); + assert!( + format!("{:?}", stream).contains("Tcp"), + "stream `{:?}` don't contain Tcp", + stream + ); + assert!( + format!("{:?}", stream).contains("TcpStream"), + "stream `{:?}` don't contain TcpStream", + stream + ); + assert!(stream.to_string().starts_with("tcp://")); + + stream.write_all(b"ProSA").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"Worldline"); + + let _ = stream.shutdown().await; + }; -impl From for Stream { - fn from(stream: TcpStream) -> Self { - Stream::Tcp(stream) + future::join(server, client).await; } -} -/// Configuration struct of an network target -/// -/// ``` -/// use tokio::io; -/// use url::Url; -/// use prosa::io::{TargetSetting, Stream}; -/// -/// async fn connecting() -> Result<(), io::Error> { -/// let wl_target = TargetSetting::new(Url::parse("https://worldline.com").unwrap(), None, None); -/// let stream: Stream = wl_target.connect().await?; -/// -/// // Handle the stream like any tokio stream -/// -/// Ok(()) -/// } -/// ``` -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct TargetSetting { - /// Url of the target destination - pub url: Url, - /// SSL configuration for target destination - pub ssl: Option, - /// Optional proxy use to reach the target - pub proxy: Option, - #[serde(skip)] - /// SSL configuration for target destination - pub ssl_context: Option, -} + #[tokio::test] + async fn ssl_client_server() { + let addr = "localhost:41443"; + let addr_url = Url::parse(format!("tls://{}", addr).as_str()).unwrap(); + + let ssl_config = SslConfig::default(); + let ssl_acceptor = ssl_config + .init_tls_server_context(addr_url.domain()) + .unwrap() + .build(); + let listener = StreamListener::bind(addr) + .await + .unwrap() + .ssl_acceptor(ssl_acceptor, Some(ssl_config.get_ssl_timeout())); + assert!(listener.as_raw_fd() > 0); + assert!( + format!("{:?}", listener).contains("Ssl"), + "listener `{:?}` don't contain Ssl", + listener + ); + assert!( + format!("{:?}", listener).contains("TcpListener"), + "listener `{:?}` don't contain TcpListener", + listener + ); + assert!(listener.to_string().starts_with("ssl://")); + + let server = async move { + let (mut client_stream, client_addr) = listener.accept().await.unwrap(); + assert!(client_addr.is_loopback()); + + let mut buf = [0; 5]; + client_stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ProSA"); + + client_stream.write_all(b"Worldline").await.unwrap(); + }; -impl TargetSetting { - /// Method to create manually a target - pub fn new(url: Url, ssl: Option, proxy: Option) -> TargetSetting { - let mut target = TargetSetting { - url, - ssl, - proxy, - ssl_context: None, + let client = async { + let mut ssl_client_context = ssl_config.init_tls_client_context().unwrap(); + ssl_client_context.set_verify(SslVerifyMode::NONE); + + let mut stream = Stream::connect_ssl(&addr_url, &ssl_client_context.build()) + .await + .unwrap(); + assert!(stream.as_raw_fd() > 0); + assert!( + format!("{:?}", stream).contains("Ssl"), + "stream `{:?}` don't contain Ssl", + stream + ); + assert!(stream.to_string().starts_with("ssl://")); + + stream.write_all(b"ProSA").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"Worldline"); + + let _ = stream.shutdown().await; }; - target.init_ssl_context(); - target + future::join(server, client).await; } - /// Method to known if the url indicate an SSL protocol - pub fn url_is_ssl(url: &Url) -> bool { - let scheme = url.scheme(); - if scheme.ends_with("+ssl") || scheme.ends_with("+tls") { - true - } else { - matches!(url.scheme(), "ssl" | "tls" | "https") - } - } + #[tokio::test] + async fn ssl_client_server_raw() { + let addr = "localhost:41453"; + let addr_url = Url::parse(format!("tls://{}", addr).as_str()).unwrap(); + + let ssl_config = SslConfig::default(); + let ssl_acceptor = ssl_config + .init_tls_server_context(addr_url.domain()) + .unwrap() + .build(); + let listener = StreamListener::bind(addr) + .await + .unwrap() + .ssl_acceptor(ssl_acceptor, Some(ssl_config.get_ssl_timeout())); + assert!(listener.as_raw_fd() > 0); + assert!( + format!("{:?}", listener).contains("Ssl"), + "listener `{:?}` don't contain Ssl", + listener + ); + assert!( + format!("{:?}", listener).contains("TcpListener"), + "listener `{:?}` don't contain TcpListener", + listener + ); + assert!(listener.to_string().starts_with("ssl://")); + + let server = async move { + let (mut client_stream, client_addr) = listener.accept_raw().await.unwrap(); + assert!(client_addr.is_loopback()); + client_stream = listener.handshake(client_stream).await.unwrap(); + + let mut buf = [0; 5]; + client_stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ProSA"); + + client_stream.write_all(b"Worldline").await.unwrap(); + }; - /// Method to init the ssl context out of the ssl target configuration. - /// Must be call when the configuration is retrieved - pub fn init_ssl_context(&mut self) { - if let Some(ssl_config) = &self.ssl { - if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { - self.ssl_context = Some(ssl_context_builder.build()); - } - } - } + let client = async { + let mut ssl_client_context = ssl_config.init_tls_client_context().unwrap(); + ssl_client_context.set_verify(SslVerifyMode::NONE); - /// Method to connect a ProSA stream to the remote target using the configuration - pub async fn connect(&self) -> Result { - let ssl_context = if self.ssl_context.is_some() { - self.ssl_context.clone() - } else if let Some(ssl_config) = &self.ssl { - if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { - Some(ssl_context_builder.build()) - } else { - None - } - } else if Self::url_is_ssl(&self.url) { - let ssl_config = SslConfig::default(); - if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { - Some(ssl_context_builder.build()) - } else { - None - } - } else { - None + let mut stream = Stream::connect_ssl(&addr_url, &ssl_client_context.build()) + .await + .unwrap(); + assert!(stream.as_raw_fd() > 0); + assert!( + format!("{:?}", stream).contains("Ssl"), + "stream `{:?}` don't contain Ssl", + stream + ); + assert!(stream.to_string().starts_with("ssl://")); + + stream.write_all(b"ProSA").await.unwrap(); + + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"Worldline"); + + let _ = stream.shutdown().await; }; - if let Some(proxy_url) = &self.proxy { - if let Some(ssl_cx) = ssl_context { - Stream::connect_ssl_with_http_proxy( - self.url.host_str().unwrap_or_default(), - self.url.port_or_known_default().unwrap_or_default(), - &ssl_cx, - proxy_url, - ) + future::join(server, client).await; + } + + #[tokio::test] + async fn ssl_client_server_with_config() { + let temp_cert_dir = env::temp_dir(); + let addr_str = "tls://localhost:41463"; + let addr = Url::parse(addr_str).unwrap(); + + let mut server_ssl_config = SslConfig::default(); + server_ssl_config.set_alpn(vec!["prosa/1".into(), "h2".into()]); + + let listener_settings = ListenerSetting::new(addr.clone(), Some(server_ssl_config)); + assert!( + format!("{:?}", listener_settings).contains("tls") + && format!("{:?}", listener_settings).contains("localhost") + && format!("{:?}", listener_settings).contains("41463"), + "`{:?}` Not contain the address {}", + listener_settings, + addr_str + ); + assert!( + listener_settings.to_string().starts_with(addr_str), + "`{}` Not start with the address {}", + listener_settings, + addr_str + ); + assert!(listener_settings.to_string().starts_with(addr_str)); + + let listener = listener_settings.bind().await.unwrap(); + if let StreamListener::Ssl(_, acceptor, _) = &listener { + let server_cert = acceptor.context().certificate().unwrap(); + let mut server_cert_file = File::create(temp_cert_dir.join("prosa_test_server.pem")) .await - } else { - Stream::connect_tcp_with_http_proxy( - self.url.host_str().unwrap_or_default(), - self.url.port_or_known_default().unwrap_or_default(), - proxy_url, - ) + .unwrap(); + server_cert_file + .write_all(&server_cert.to_pem().unwrap()) .await - } - } else { - let addrs = self.url.socket_addrs(|| self.url.port_or_known_default())?; - if let Some(ssl_cx) = ssl_context { - Stream::connect_ssl(&*addrs, &ssl_cx).await - } else { - Stream::connect_tcp(&*addrs).await - } + .unwrap(); } - } -} + assert!(listener.as_raw_fd() > 0); + assert!( + format!("{:?}", listener).contains("Ssl"), + "listener `{:?}` don't contain Ssl", + listener + ); + assert!( + format!("{:?}", listener).contains("TcpListener"), + "listener `{:?}` don't contain TcpListener", + listener + ); + + let server = async move { + let (mut client_stream, client_addr) = listener.accept().await.unwrap(); + assert!(client_addr.is_loopback()); + + let mut buf = [0; 5]; + client_stream.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"ProSA"); + + // Should do nothing + client_stream = listener.handshake(client_stream).await.unwrap(); + + client_stream.write_all(b"Worldline").await.unwrap(); + }; -impl fmt::Display for TargetSetting { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let mut url = self.url.clone(); - if self.ssl.is_some() { - let url_scheme = url.scheme(); - if url_scheme.is_empty() { - let _ = url.set_scheme("ssl"); - } else if !url_scheme.ends_with("ssl") - && !url_scheme.ends_with("tls") - && !url_scheme.ends_with("https") - && !url_scheme.ends_with("wss") - { - let _ = url.set_scheme(format!("{}+ssl", url_scheme).as_str()); + let mut client_ssl_config = SslConfig::default(); + client_ssl_config.set_alpn(vec!["http/1.1".into(), "prosa/1".into()]); + let ssl_store = Store::new(temp_cert_dir.to_str().unwrap().to_string() + "/"); + client_ssl_config.set_store(ssl_store); + let target_settings = TargetSetting::new(addr, Some(client_ssl_config), None); + assert_eq!(addr_str, target_settings.to_string()); + + let client = async { + let mut stream = target_settings.connect().await.unwrap(); + assert!(stream.as_raw_fd() > 0); + assert!( + format!("{:?}", stream).contains("Ssl"), + "stream `{:?}` don't contain Ssl", + stream + ); + if let Stream::Ssl(s) = &stream { + assert_eq!( + Some(b"prosa/1".as_slice()), + s.ssl().selected_alpn_protocol() + ); + } else { + panic!("Should be an SSL stream for client"); } - } - if let Some(proxy_url) = &self.proxy { - writeln!(f, "{} -proxy {}", url, proxy_url) - } else { - writeln!(f, "{}", url) - } - } -} + stream.write_all(b"ProSA").await.unwrap(); -/// Trait to define ProSA IO. -/// Implement with the procedural macro io -pub trait IO { - /// Frame error trigger when the frame operation can't be executed - type Error; + let mut buf = vec![]; + stream.read_to_end(&mut buf).await.unwrap(); + assert_eq!(buf, b"Worldline"); - /// Method call to parse a frame - fn parse_frame(&mut self) -> std::result::Result, Self::Error>; + let _ = stream.shutdown().await; + }; - /// Method to wait a complete frame - fn read_frame( - &mut self, - ) -> impl std::future::Future, Self::Error>> + Send; - /// Method to write a frame and wait for completion - fn write_frame( - &mut self, - frame: F, - ) -> impl std::future::Future> + Send; + future::join(server, client).await; + } } diff --git a/prosa/src/io/listener.rs b/prosa/src/io/listener.rs new file mode 100644 index 0000000..d26282d --- /dev/null +++ b/prosa/src/io/listener.rs @@ -0,0 +1,503 @@ +//! Module that define listener IO that could be use by a ProSA processor +use std::{ + fmt, io, + net::{Ipv4Addr, SocketAddrV4}, + os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd}, + pin::Pin, + time::Duration, +}; + +use openssl::ssl::SslAcceptor; +use prosa_utils::config::ssl::SslConfig; +use serde::{Deserialize, Serialize}; + +pub use prosa_macros::io; +use tokio::{ + net::{TcpListener, ToSocketAddrs, UnixListener}, + time::timeout, +}; +use url::Url; + +use super::{stream::Stream, url_is_ssl, SocketAddr}; + +/// ProSA socket object to handle TCP/SSL server socket +pub enum StreamListener { + #[cfg(target_family = "unix")] + /// Unix server socket (only on unix systems) + Unix(tokio::net::UnixListener), + /// TCP server socket + Tcp(TcpListener), + /// SSL server socket + Ssl(TcpListener, SslAcceptor, Duration), +} + +impl fmt::Debug for StreamListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(l) => f.debug_struct("Unix").field("listener", &l).finish(), + StreamListener::Tcp(l) => f.debug_struct("Tcp").field("listener", &l).finish(), + StreamListener::Ssl(l, a, t) => f + .debug_struct("Ssl") + .field("listener", &l) + .field("ssl_timeout", &t) + .field( + "certificate", + &a.context().certificate().map(|c| c.to_text()), + ) + .finish(), + } + } +} + +impl StreamListener { + /// Default SSL handshake timeout + pub const DEFAULT_SSL_TIMEOUT: Duration = Duration::new(3, 0); + + /// Returns the local address that this listener is bound to. + /// + /// This can be useful, for example, when binding to port 0 to figure out + /// which port was actually bound. + /// + /// ``` + /// use tokio::io; + /// use prosa::io::listener::StreamListener; + /// use prosa::io::SocketAddr; + /// use std::net::{Ipv4Addr, SocketAddrV4}; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let stream_listener: StreamListener = StreamListener::bind("0.0.0.0:10000").await?; + /// + /// assert_eq!(stream_listener.local_addr()?, + /// SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 10000))); + /// + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> Result { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(listener) => listener.local_addr().map(|addr| addr.into()), + StreamListener::Tcp(listener) => listener.local_addr().map(|addr| addr.into()), + StreamListener::Ssl(listener, _, _) => listener.local_addr().map(|addr| addr.into()), + } + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Accept TCP connections from clients + /// + /// ```mermaid + /// graph LR + /// clients[Clients] + /// server[Server] + /// + /// clients -- TCP --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use prosa::io::listener::StreamListener; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let stream_listener: StreamListener = StreamListener::bind("0.0.0.0:10000").await?; + /// + /// loop { + /// let (stream, addr) = stream_listener.accept().await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn bind(addr: A) -> Result { + Ok(StreamListener::Tcp(TcpListener::bind(addr).await?)) + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Set an OpenSSL acceptor to accept SSL connections from clients + /// By default, the SSL connect timeout is 3 seconds + /// + /// ```mermaid + /// graph LR + /// clients[Clients] + /// server[Server] + /// + /// clients -- TLS --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use prosa_utils::config::ssl::SslConfig; + /// use prosa::io::listener::StreamListener; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let ssl_acceptor = SslConfig::default().init_tls_server_context(None).unwrap().build(); + /// let stream_listener: StreamListener = StreamListener::bind("0.0.0.0:10000").await?.ssl_acceptor(ssl_acceptor, None); + /// + /// loop { + /// // The client SSL handshake will happen here + /// let (stream, addr) = stream_listener.accept().await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn ssl_acceptor( + self, + ssl_acceptor: SslAcceptor, + ssl_timeout: Option, + ) -> StreamListener { + match self { + StreamListener::Tcp(listener) => StreamListener::Ssl( + listener, + ssl_acceptor, + ssl_timeout.unwrap_or(Self::DEFAULT_SSL_TIMEOUT), + ), + StreamListener::Ssl(listener, _, _) => StreamListener::Ssl( + listener, + ssl_acceptor, + ssl_timeout.unwrap_or(Self::DEFAULT_SSL_TIMEOUT), + ), + _ => self, + } + } + + /// Method to accept a client after a bind + /// + /// ``` + /// use tokio::io; + /// use prosa_utils::config::ssl::SslConfig; + /// use prosa::io::listener::StreamListener; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let ssl_acceptor = SslConfig::default().init_tls_server_context(None).unwrap().build(); + /// let stream_listener: StreamListener = StreamListener::bind("0.0.0.0:10000").await?.ssl_acceptor(ssl_acceptor, None); + /// + /// loop { + /// // The client SSL handshake will happen here + /// let (stream, addr) = stream_listener.accept().await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(l) => l.accept().await.map(|s| (Stream::Unix(s.0), s.1.into())), + StreamListener::Tcp(l) => l.accept().await.map(|s| (Stream::Tcp(s.0), s.1.into())), + StreamListener::Ssl(l, ssl_acceptor, ssl_timeout) => { + let ssl = openssl::ssl::Ssl::new(ssl_acceptor.context()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + let (stream, addr) = l.accept().await?; + let mut stream = tokio_openssl::SslStream::new(ssl, stream) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + if let Err(e) = timeout(*ssl_timeout, Pin::new(&mut stream).accept()) + .await + .map_err(|_| { + io::Error::new( + io::ErrorKind::TimedOut, + format!( + "SSL timeout[{} ms] for {:?}", + ssl_timeout.as_millis(), + stream + ), + ) + })? + { + if e.code() != openssl::ssl::ErrorCode::ZERO_RETURN { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Can't accept the client: {}", e), + )); + } + } + + Ok((Stream::Ssl(stream), addr.into())) + } + } + } + + /// Method to accept a client after a bind without SSL handshake (must be done with handshake after) + /// + /// ``` + /// use tokio::io; + /// use prosa_utils::config::ssl::SslConfig; + /// use prosa::io::listener::StreamListener; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let ssl_acceptor = SslConfig::default().init_tls_server_context(None).unwrap().build(); + /// let stream_listener: StreamListener = StreamListener::bind("0.0.0.0:10000").await?.ssl_acceptor(ssl_acceptor, None); + /// + /// loop { + /// let (stream, addr) = stream_listener.accept_raw().await?; + /// + /// // The client SSL handshake will happen here + /// let stream = stream_listener.handshake(stream).await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn accept_raw(&self) -> Result<(Stream, SocketAddr), io::Error> { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(l) => l.accept().await.map(|s| (Stream::Unix(s.0), s.1.into())), + StreamListener::Tcp(l) => l.accept().await.map(|s| (Stream::Tcp(s.0), s.1.into())), + StreamListener::Ssl(l, _ssl_acceptor, _ssl_timeout) => { + l.accept().await.map(|s| (Stream::Tcp(s.0), s.1.into())) + } + } + } + + /// Method to do an handshake with a client after an accept (Do nothing if the handshake is already done) + pub async fn handshake(&self, stream: Stream) -> Result { + match stream { + Stream::Tcp(tcp_stream) => { + if let StreamListener::Ssl(_l, ssl_acceptor, ssl_timeout) = self { + let ssl = openssl::ssl::Ssl::new(ssl_acceptor.context()) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + let mut stream = tokio_openssl::SslStream::new(ssl, tcp_stream) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + if let Err(e) = timeout(*ssl_timeout, Pin::new(&mut stream).accept()) + .await + .map_err(|_| { + io::Error::new( + io::ErrorKind::TimedOut, + format!( + "SSL timeout[{} ms] for {:?}", + ssl_timeout.as_millis(), + stream + ), + ) + })? + { + if e.code() != openssl::ssl::ErrorCode::ZERO_RETURN { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("Can't accept the client: {}", e), + )); + } + } + + Ok(Stream::Ssl(stream)) + } else { + Ok(Stream::Tcp(tcp_stream)) + } + } + s => Ok(s), + } + } +} + +impl AsFd for StreamListener { + fn as_fd(&self) -> BorrowedFd<'_> { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(l) => l.as_fd(), + StreamListener::Tcp(l) => l.as_fd(), + StreamListener::Ssl(l, _, _) => l.as_fd(), + } + } +} + +impl AsRawFd for StreamListener { + fn as_raw_fd(&self) -> RawFd { + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(l) => l.as_raw_fd(), + StreamListener::Tcp(l) => l.as_raw_fd(), + StreamListener::Ssl(l, _, _) => l.as_raw_fd(), + } + } +} + +impl fmt::Display for StreamListener { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self + .local_addr() + .unwrap_or(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(0, 0, 0, 0), + 0, + ))); + match self { + #[cfg(target_family = "unix")] + StreamListener::Unix(_) => write!(f, "unix://{}", addr), + StreamListener::Tcp(_) => write!(f, "tcp://{}", addr), + StreamListener::Ssl(_, _, _) => write!(f, "ssl://{}", addr), + } + } +} + +#[cfg(target_family = "unix")] +impl From for StreamListener { + fn from(listener: tokio::net::UnixListener) -> Self { + StreamListener::Unix(listener) + } +} + +impl From for StreamListener { + fn from(listener: TcpListener) -> Self { + StreamListener::Tcp(listener) + } +} + +/// Configuration struct of an network listener +/// +/// ``` +/// use tokio::io; +/// use url::Url; +/// use prosa::io::stream::Stream; +/// use prosa::io::listener::{ListenerSetting, StreamListener}; +/// +/// async fn accepting() -> Result<(), io::Error> { +/// let wl_target = ListenerSetting::new(Url::parse("https://[::]").unwrap(), None); +/// let stream: StreamListener = wl_target.bind().await?; +/// +/// // Use the StreamListener object to accept clients +/// +/// Ok(()) +/// } +/// ``` +#[derive(Deserialize, Serialize, Clone)] +pub struct ListenerSetting { + /// Url of the listening + pub url: Url, + /// SSL configuration for target destination + pub ssl: Option, + #[serde(skip)] + /// OpenSSL configuration for target destination + ssl_context: Option, + #[serde(skip_serializing)] + #[serde(default = "ListenerSetting::default_max_socket")] + /// Maximum number of socket + pub max_socket: u64, +} + +impl ListenerSetting { + #[cfg(target_family = "unix")] + fn default_max_socket() -> u64 { + rlimit::Resource::NOFILE + .get_soft() + .unwrap_or(u32::MAX as u64) + - 1 + } + + #[cfg(target_family = "windows")] + fn default_max_socket() -> u64 { + (rlimit::getmaxstdio() as u64) - 1 + } + + #[cfg(all(not(target_family = "unix"), not(target_family = "windows")))] + fn default_max_socket() -> u64 { + (u32::MAX as u64) - 1 + } + + /// Method to create manually a target + pub fn new(url: Url, ssl: Option) -> ListenerSetting { + let mut target = ListenerSetting { + url: url.clone(), + ssl, + ssl_context: None, + max_socket: Self::default_max_socket(), + }; + + target.init_ssl_context(url.domain()); + target + } + + /// Method to init the ssl context out of the ssl target configuration. + /// Must be call when the configuration is retrieved + pub fn init_ssl_context(&mut self, domain: Option<&str>) { + if let Some(ssl_config) = &self.ssl { + if let Ok(ssl_context_builder) = ssl_config.init_tls_server_context(domain) { + self.ssl_context = Some(ssl_context_builder.build()); + } + } + } + + /// Method to connect a ProSA stream to the remote target using the configuration + pub async fn bind(&self) -> Result { + #[cfg(target_family = "unix")] + if self.url.scheme() == "unix" || self.url.scheme() == "file" { + return Ok(StreamListener::Unix(UnixListener::bind(self.url.path())?)); + } + + let addrs = self.url.socket_addrs(|| self.url.port_or_known_default())?; + let mut stream_listener = StreamListener::bind(&*addrs).await?; + + if let Some(ssl_acceptor) = &self.ssl_context { + stream_listener = stream_listener.ssl_acceptor( + ssl_acceptor.clone(), + self.ssl.as_ref().map(|c| c.get_ssl_timeout()), + ); + } else if let Some(ssl_config) = &self.ssl { + if let Ok(ssl_acceptor_builder) = ssl_config.init_tls_server_context(self.url.domain()) + { + stream_listener = stream_listener.ssl_acceptor( + ssl_acceptor_builder.build(), + Some(ssl_config.get_ssl_timeout()), + ); + } + } else if url_is_ssl(&self.url) { + let ssl_config = SslConfig::default(); + if let Ok(ssl_acceptor_builder) = ssl_config.init_tls_server_context(self.url.domain()) + { + stream_listener = stream_listener.ssl_acceptor( + ssl_acceptor_builder.build(), + Some(ssl_config.get_ssl_timeout()), + ); + } + } + + Ok(stream_listener) + } +} + +impl From for ListenerSetting { + fn from(url: Url) -> Self { + ListenerSetting { + url, + ssl: None, + ssl_context: None, + max_socket: Self::default_max_socket(), + } + } +} + +impl fmt::Debug for ListenerSetting { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ListenerSetting") + .field("url", &self.url) + .field("ssl", &self.ssl) + .field("max_socket", &self.max_socket) + .finish() + } +} + +impl fmt::Display for ListenerSetting { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut url = self.url.clone(); + if self.ssl.is_some() { + let url_scheme = url.scheme(); + if url_scheme.is_empty() { + let _ = url.set_scheme("ssl"); + } else if !url_scheme.ends_with("ssl") + && !url_scheme.ends_with("tls") + && !url_scheme.ends_with("https") + && !url_scheme.ends_with("wss") + { + let _ = url.set_scheme(format!("{}+ssl", url_scheme).as_str()); + } + } + + write!(f, "{} -max_socket {}", url, self.max_socket) + } +} diff --git a/prosa/src/io/stream.rs b/prosa/src/io/stream.rs new file mode 100644 index 0000000..9157f5e --- /dev/null +++ b/prosa/src/io/stream.rs @@ -0,0 +1,742 @@ +//! Module that define stream IO that could be use by a ProSA processor +use std::{ + fmt, io, + net::{Ipv4Addr, SocketAddrV4}, + os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd}, + path::Path, + pin::Pin, + task::{Context, Poll}, +}; + +use openssl::ssl::{self, SslConnector}; +use prosa_utils::config::ssl::SslConfig; +use serde::{Deserialize, Serialize}; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::{TcpStream, ToSocketAddrs}, +}; +use tokio_openssl::SslStream; +use url::Url; + +use super::{url_is_ssl, SocketAddr}; + +/// ProSA socket object to handle TCP/SSL socket with or without proxy +#[derive(Debug)] +pub enum Stream { + #[cfg(target_family = "unix")] + /// Unix socket (only on unix systems) + Unix(tokio::net::UnixStream), + /// TCP socket + Tcp(TcpStream), + /// SSL socket + Ssl(SslStream), + /// TCP socket using Http proxy + TcpHttpProxy(TcpStream), + /// SSL socket using Http proxy + SslHttpProxy(SslStream), +} + +impl Stream { + /// Returns the local address that this stream is bound to. + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa::io::stream::Stream; + /// use prosa::io::SocketAddr; + /// use std::net::{Ipv4Addr, SocketAddrV4}; + /// + /// async fn accepting() -> Result<(), io::Error> { + /// let stream: Stream = Stream::connect_tcp("127.0.0.1:80").await?; + /// + /// assert_eq!(stream.local_addr()?, + /// SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 80))); + /// + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> Result { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(s) => s.local_addr().map(|addr| addr.into()), + Stream::Tcp(s) => s.local_addr().map(|addr| addr.into()), + Stream::Ssl(s) => s.get_ref().local_addr().map(|addr| addr.into()), + Stream::TcpHttpProxy(s) => s.local_addr().map(|addr| addr.into()), + Stream::SslHttpProxy(s) => s.get_ref().local_addr().map(|addr| addr.into()), + } + } + + #[cfg(target_family = "unix")] + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Connect a UNIX socket on a path + /// + /// ```mermaid + /// graph LR + /// client[Client] + /// server[Server] + /// + /// client -- UNIX --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa::io::stream::Stream; + /// + /// async fn connecting() -> Result<(), io::Error> { + /// let stream: Stream = Stream::connect_unix("/var/run/prosa.socket").await?; + /// + /// // Handle the stream like any tokio stream + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect_unix

(path: P) -> Result + where + P: AsRef, + { + Ok(Stream::Unix(tokio::net::UnixStream::connect(path).await?)) + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Connect a TCP socket to a distant + /// + /// ```mermaid + /// graph LR + /// client[Client] + /// server[Server] + /// + /// client -- TCP --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa::io::stream::Stream; + /// + /// async fn connecting() -> Result<(), io::Error> { + /// let stream: Stream = Stream::connect_tcp("worldline.com:80").await?; + /// + /// // Handle the stream like any tokio stream + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect_tcp(addr: A) -> Result + where + A: ToSocketAddrs, + { + Ok(Stream::Tcp(TcpStream::connect(addr).await?)) + } + + /// Method to create an SSL stream from a TCP stream + async fn create_ssl( + tcp_stream: TcpStream, + ssl_connector: &ssl::SslConnector, + domain: &str, + ) -> Result, io::Error> { + let ssl = ssl_connector.configure()?.into_ssl(domain)?; + let mut stream = SslStream::new(ssl, tcp_stream).unwrap(); + if let Err(e) = Pin::new(&mut stream).connect().await { + if e.code() != ssl::ErrorCode::ZERO_RETURN { + return Err(io::Error::new( + io::ErrorKind::Interrupted, + format!("Can't connect the SSL socket `{}`", e), + )); + } + } + + Ok(stream) + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Connect an SSL socket to a distant + /// + /// ```mermaid + /// graph LR + /// client[Client] + /// server[Server] + /// + /// client -- TCP+TLS --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa_utils::config::ssl::SslConfig; + /// use prosa::io::stream::Stream; + /// + /// async fn connecting() -> Result<(), io::Error> { + /// let ssl_config = SslConfig::default(); + /// if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { + /// let ssl_context = ssl_context_builder.build(); + /// let stream: Stream = Stream::connect_ssl(&Url::parse("worldline.com:443").unwrap(), &ssl_context).await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect_ssl( + url: &Url, + ssl_context: &ssl::SslConnector, + ) -> Result { + let addrs = url.socket_addrs(|| url.port_or_known_default())?; + Ok(Stream::Ssl( + Self::create_ssl( + TcpStream::connect(&*addrs).await?, + ssl_context, + url.domain().ok_or(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Can't retrieve domain name from url `{}`", url), + ))?, + ) + .await?, + )) + } + + /// Method to connect a TCP stream through an HTTP proxy + async fn connect_http_proxy( + host: &str, + port: u16, + proxy: &Url, + ) -> Result { + let proxy_addrs = proxy.socket_addrs(|| proxy.port_or_known_default())?; + let mut tcp_stream = TcpStream::connect(&*proxy_addrs).await?; + if let (username, Some(password)) = (proxy.username(), proxy.password()) { + if let Err(e) = async_http_proxy::http_connect_tokio_with_basic_auth( + &mut tcp_stream, + host, + port, + username, + password, + ) + .await + { + return Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + format!("Can't connect to the http proxy with basic_auth `{}`", e), + )); + } + } else if let Err(e) = + async_http_proxy::http_connect_tokio(&mut tcp_stream, host, port).await + { + return Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + format!("Can't connect to the http proxy `{}`", e), + )); + } + + Ok(tcp_stream) + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Connect a TCP socket to a distant through an HTTP proxy + /// + /// ```mermaid + /// graph LR + /// client[Client] + /// server[Server] + /// proxy[Proxy] + /// + /// client -- TCP --> proxy + /// proxy --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa::io::stream::Stream; + /// + /// async fn connecting() -> Result<(), io::Error> { + /// let proxy_url = Url::parse("http://user:pwd@proxy:3128").unwrap(); + /// let stream: Stream = Stream::connect_tcp_with_http_proxy("worldline.com", 443, &proxy_url).await?; + /// + /// // Handle the stream like any tokio stream + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect_tcp_with_http_proxy( + host: &str, + port: u16, + proxy: &Url, + ) -> Result { + Ok(Stream::TcpHttpProxy( + Self::connect_http_proxy(host, port, proxy).await?, + )) + } + + #[cfg_attr(doc, aquamarine::aquamarine)] + /// Connect an SSL socket to a distant through an HTTP proxy + /// + /// ```mermaid + /// graph LR + /// client[Client] + /// server[Server] + /// proxy[Proxy] + /// + /// client -- TCP+TLS --> proxy + /// proxy --> server + /// ``` + /// + /// ``` + /// use tokio::io; + /// use url::Url; + /// use prosa_utils::config::ssl::SslConfig; + /// use prosa::io::stream::Stream; + /// + /// async fn connecting() -> Result<(), io::Error> { + /// let proxy_url = Url::parse("http://user:pwd@proxy:3128").unwrap(); + /// let ssl_config = SslConfig::default(); + /// if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { + /// let ssl_context = ssl_context_builder.build(); + /// let stream: Stream = Stream::connect_ssl_with_http_proxy("worldline.com", 443, &ssl_context, &proxy_url).await?; + /// + /// // Handle the stream like any tokio stream + /// } + /// + /// Ok(()) + /// } + /// ``` + pub async fn connect_ssl_with_http_proxy( + host: &str, + port: u16, + ssl_connector: &ssl::SslConnector, + proxy: &Url, + ) -> Result { + Ok(Stream::SslHttpProxy( + Self::create_ssl( + Self::connect_http_proxy(host, port, proxy).await?, + ssl_connector, + host, + ) + .await?, + )) + } + + /// Sets the value of the TCP_NODELAY option on the ProSA socket + pub fn set_nodelay(&self, nodelay: bool) -> Result<(), io::Error> { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(_) => Ok(()), + Stream::Tcp(s) => s.set_nodelay(nodelay), + Stream::Ssl(s) => s.get_ref().set_nodelay(nodelay), + Stream::TcpHttpProxy(s) => s.set_nodelay(nodelay), + Stream::SslHttpProxy(s) => s.get_ref().set_nodelay(nodelay), + } + } + + /// Gets the value of the TCP_NODELAY option for the ProSA socket + pub fn nodelay(&self) -> Result { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(_) => Ok(true), + Stream::Tcp(s) => s.nodelay(), + Stream::Ssl(s) => s.get_ref().nodelay(), + Stream::TcpHttpProxy(s) => s.nodelay(), + Stream::SslHttpProxy(s) => s.get_ref().nodelay(), + } + } + + /// Sets the value for the IP_TTL option on the ProSA socket + pub fn set_ttl(&self, ttl: u32) -> Result<(), io::Error> { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(_) => Ok(()), + Stream::Tcp(s) => s.set_ttl(ttl), + Stream::Ssl(s) => s.get_ref().set_ttl(ttl), + Stream::TcpHttpProxy(s) => s.set_ttl(ttl), + Stream::SslHttpProxy(s) => s.get_ref().set_ttl(ttl), + } + } + + /// Gets the value of the IP_TTL option for the ProSA socket + pub fn ttl(&self) -> Result { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(_) => Ok(0), + Stream::Tcp(s) => s.ttl(), + Stream::Ssl(s) => s.get_ref().ttl(), + Stream::TcpHttpProxy(s) => s.ttl(), + Stream::SslHttpProxy(s) => s.get_ref().ttl(), + } + } +} + +impl AsFd for Stream { + fn as_fd(&self) -> BorrowedFd<'_> { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(s) => s.as_fd(), + Stream::Tcp(s) => s.as_fd(), + Stream::Ssl(s) => s.get_ref().as_fd(), + Stream::TcpHttpProxy(s) => s.as_fd(), + Stream::SslHttpProxy(s) => s.get_ref().as_fd(), + } + } +} + +impl AsRawFd for Stream { + fn as_raw_fd(&self) -> RawFd { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(s) => s.as_raw_fd(), + Stream::Tcp(s) => s.as_raw_fd(), + Stream::Ssl(s) => s.get_ref().as_raw_fd(), + Stream::TcpHttpProxy(s) => s.as_raw_fd(), + Stream::SslHttpProxy(s) => s.get_ref().as_raw_fd(), + } + } +} + +impl AsyncRead for Stream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + #[cfg(target_family = "unix")] + Stream::Unix(s) => { + let stream = Pin::new(s); + stream.poll_read(cx, buf) + } + Stream::Tcp(s) => { + let stream = Pin::new(s); + stream.poll_read(cx, buf) + } + Stream::Ssl(s) => { + let stream = Pin::new(s); + stream.poll_read(cx, buf) + } + Stream::TcpHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_read(cx, buf) + } + Stream::SslHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_read(cx, buf) + } + } + } +} + +impl AsyncWrite for Stream { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + #[cfg(target_family = "unix")] + Stream::Unix(s) => { + let stream = Pin::new(s); + stream.poll_write(cx, buf) + } + Stream::Tcp(s) => { + let stream = Pin::new(s); + stream.poll_write(cx, buf) + } + Stream::Ssl(s) => { + let stream = Pin::new(s); + stream.poll_write(cx, buf) + } + Stream::TcpHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_write(cx, buf) + } + Stream::SslHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_write(cx, buf) + } + } + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + #[cfg(target_family = "unix")] + Stream::Unix(s) => { + let stream = Pin::new(s); + stream.poll_write_vectored(cx, bufs) + } + Stream::Tcp(s) => { + let stream = Pin::new(s); + stream.poll_write_vectored(cx, bufs) + } + Stream::Ssl(s) => { + let stream = Pin::new(s); + stream.poll_write_vectored(cx, bufs) + } + Stream::TcpHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_write_vectored(cx, bufs) + } + Stream::SslHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_write_vectored(cx, bufs) + } + } + } + + fn is_write_vectored(&self) -> bool { + match self { + #[cfg(target_family = "unix")] + Stream::Unix(s) => s.is_write_vectored(), + Stream::Tcp(s) => s.is_write_vectored(), + Stream::Ssl(s) => s.is_write_vectored(), + Stream::TcpHttpProxy(s) => s.is_write_vectored(), + Stream::SslHttpProxy(s) => s.is_write_vectored(), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + #[cfg(target_family = "unix")] + Stream::Unix(s) => { + let stream = Pin::new(s); + stream.poll_flush(cx) + } + Stream::Tcp(s) => { + let stream = Pin::new(s); + stream.poll_flush(cx) + } + Stream::Ssl(s) => { + let stream = Pin::new(s); + stream.poll_flush(cx) + } + Stream::TcpHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_flush(cx) + } + Stream::SslHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_flush(cx) + } + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + #[cfg(target_family = "unix")] + Stream::Unix(s) => { + let stream = Pin::new(s); + stream.poll_shutdown(cx) + } + Stream::Tcp(s) => { + let stream = Pin::new(s); + stream.poll_shutdown(cx) + } + Stream::Ssl(s) => { + let stream = Pin::new(s); + stream.poll_shutdown(cx) + } + Stream::TcpHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_shutdown(cx) + } + Stream::SslHttpProxy(s) => { + let stream = Pin::new(s); + stream.poll_shutdown(cx) + } + } + } +} + +impl fmt::Display for Stream { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let addr = self + .local_addr() + .unwrap_or(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(0, 0, 0, 0), + 0, + ))); + match self { + #[cfg(target_family = "unix")] + Stream::Unix(_) => write!(f, "unix://{}", addr), + Stream::Tcp(_) => write!(f, "tcp://{}", addr), + Stream::Ssl(_) => write!(f, "ssl://{}", addr), + Stream::TcpHttpProxy(_) => write!(f, "tcp+http_proxy://{}", addr), + Stream::SslHttpProxy(_) => write!(f, "ssl+http_proxy://{}", addr), + } + } +} + +#[cfg(target_family = "unix")] +impl From for Stream { + fn from(stream: tokio::net::UnixStream) -> Self { + Stream::Unix(stream) + } +} + +impl From for Stream { + fn from(stream: TcpStream) -> Self { + Stream::Tcp(stream) + } +} + +/// Configuration struct of an network target +/// +/// ``` +/// use tokio::io; +/// use url::Url; +/// use prosa::io::stream::{TargetSetting, Stream}; +/// +/// async fn connecting() -> Result<(), io::Error> { +/// let wl_target = TargetSetting::new(Url::parse("https://worldline.com").unwrap(), None, None); +/// let stream: Stream = wl_target.connect().await?; +/// +/// // Handle the stream like any tokio stream +/// +/// Ok(()) +/// } +/// ``` +#[derive(Deserialize, Serialize, Clone)] +pub struct TargetSetting { + /// Url of the target destination + pub url: Url, + /// SSL configuration for target destination + pub ssl: Option, + /// Optional proxy use to reach the target + pub proxy: Option, + #[serde(skip)] + /// SSL configuration for target destination + ssl_context: Option, + #[serde(skip_serializing)] + #[serde(default = "TargetSetting::get_default_connect_timeout")] + /// Timeout for socket connection in milliseconds + pub connect_timeout: u32, +} + +impl TargetSetting { + fn get_default_connect_timeout() -> u32 { + 5000 + } + + /// Method to create manually a target + pub fn new(url: Url, ssl: Option, proxy: Option) -> TargetSetting { + let mut target = TargetSetting { + url, + ssl, + proxy, + ssl_context: None, + connect_timeout: Self::get_default_connect_timeout(), + }; + + target.init_ssl_context(); + target + } + + /// Method to init the ssl context out of the ssl target configuration. + /// Must be call when the configuration is retrieved + pub fn init_ssl_context(&mut self) { + if let Some(ssl_config) = &self.ssl { + if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { + self.ssl_context = Some(ssl_context_builder.build()); + } + } + } + + /// Method to connect a ProSA stream to the remote target using the configuration + pub async fn connect(&self) -> Result { + #[cfg(target_family = "unix")] + if self.url.scheme() == "unix" || self.url.scheme() == "file" { + return Stream::connect_unix(self.url.path()).await; + } + + let ssl_context = if self.ssl_context.is_some() { + self.ssl_context.clone() + } else if let Some(ssl_config) = &self.ssl { + if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { + Some(ssl_context_builder.build()) + } else { + None + } + } else if url_is_ssl(&self.url) { + let ssl_config = SslConfig::default(); + if let Ok(ssl_context_builder) = ssl_config.init_tls_client_context() { + Some(ssl_context_builder.build()) + } else { + None + } + } else { + None + }; + + if let Some(proxy_url) = &self.proxy { + if let Some(ssl_cx) = ssl_context { + Stream::connect_ssl_with_http_proxy( + self.url.host_str().unwrap_or_default(), + self.url.port_or_known_default().unwrap_or_default(), + &ssl_cx, + proxy_url, + ) + .await + } else { + Stream::connect_tcp_with_http_proxy( + self.url.host_str().unwrap_or_default(), + self.url.port_or_known_default().unwrap_or_default(), + proxy_url, + ) + .await + } + } else if let Some(ssl_cx) = ssl_context { + Stream::connect_ssl(&self.url, &ssl_cx).await + } else { + let addrs = self.url.socket_addrs(|| self.url.port_or_known_default())?; + Stream::connect_tcp(&*addrs).await + } + } +} + +impl From for TargetSetting { + fn from(url: Url) -> Self { + TargetSetting { + url, + ssl: None, + proxy: None, + ssl_context: None, + connect_timeout: Self::get_default_connect_timeout(), + } + } +} + +impl fmt::Debug for TargetSetting { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("TargetSetting") + .field("url", &self.url) + .field("ssl", &self.ssl) + .field("connect_timeout", &self.connect_timeout) + .finish() + } +} + +impl fmt::Display for TargetSetting { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut url = self.url.clone(); + if self.ssl.is_some() { + let url_scheme = url.scheme(); + if url_scheme.is_empty() { + let _ = url.set_scheme("ssl"); + } else if !url_scheme.ends_with("ssl") + && !url_scheme.ends_with("tls") + && !url_scheme.ends_with("https") + && !url_scheme.ends_with("wss") + { + let _ = url.set_scheme(format!("{}+ssl", url_scheme).as_str()); + } + } + + if let Some(proxy_url) = &self.proxy { + write!(f, "{} -proxy {}", url, proxy_url) + } else { + write!(f, "{}", url) + } + } +} diff --git a/prosa_utils/src/config/ssl.rs b/prosa_utils/src/config/ssl.rs index 5776337..c130efb 100644 --- a/prosa_utils/src/config/ssl.rs +++ b/prosa_utils/src/config/ssl.rs @@ -1,5 +1,6 @@ //! Definition of SSL configuration +use bytes::{BufMut, BytesMut}; use glob::glob; use openssl::{ asn1::{Asn1Integer, Asn1Time}, @@ -8,11 +9,17 @@ use openssl::{ hash::MessageDigest, nid::Nid, pkey::PKey, - ssl::{SslFiletype, SslMethod, SslVerifyMode}, - x509::{X509NameBuilder, X509}, + ssl::{AlpnError, SslContextBuilder, SslFiletype, SslMethod, SslVerifyMode}, + x509::{extension::SubjectAlternativeName, X509NameBuilder, X509}, }; use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, fmt, fs, time}; +use std::{ + collections::HashMap, + ffi::OsStr, + fmt, fs, + ops::DerefMut, + time::{self, Duration}, +}; use super::{os_country, ConfigError}; @@ -28,24 +35,22 @@ impl Store { path: &std::path::PathBuf, ) -> Result, ConfigError> { if path.is_file() { - if path.ends_with(".pem") { - match fs::read(path) { + match &path.extension().and_then(OsStr::to_str) { + Some("pem") => match fs::read(path) { Ok(pem_file) => Ok(Some(openssl::x509::X509::from_pem(&pem_file)?)), Err(io) => Err(ConfigError::IoFile( path.to_str().unwrap_or_default().into(), io, )), - } - } else if path.ends_with(".der") { - match fs::read(path) { - Ok(pem_file) => Ok(Some(openssl::x509::X509::from_der(&pem_file)?)), + }, + Some("der") => match fs::read(path) { + Ok(der_file) => Ok(Some(openssl::x509::X509::from_der(&der_file)?)), Err(io) => Err(ConfigError::IoFile( path.to_str().unwrap_or_default().into(), io, )), - } - } else { - Ok(None) + }, + _ => Ok(None), } } else { Ok(None) @@ -67,7 +72,7 @@ impl Store { /// let openssl_store: openssl::x509::store::X509Store = store.get_store().unwrap(); /// ``` pub fn get_store(&self) -> Result { - match glob(&self.path) { + match glob(&(self.path.clone() + "*")) { Ok(certs) => { let mut store = openssl::x509::store::X509StoreBuilder::new()?; for cert_path in certs.flatten() { @@ -94,7 +99,7 @@ impl Store { /// assert!(certs_map.is_empty()); /// ``` pub fn get_certs(&self) -> Result, ConfigError> { - match glob(&self.path) { + match glob(&(self.path.clone() + "*")) { Ok(certs) => { let mut certs_map = HashMap::new(); for cert_path in certs.flatten() { @@ -135,7 +140,7 @@ impl Default for Store { impl fmt::Display for Store { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let certs = self.get_certs().unwrap_or_default(); - writeln!(f, "Store cert path [{}]:\n", self.path)?; + writeln!(f, "Store cert path [{}]:", self.path)?; for (name, cert) in certs { if f.alternate() { writeln!(f, "{}:\n{:#?}", name, cert)?; @@ -164,8 +169,7 @@ impl fmt::Display for Store { /// /// let client_config = SslConfig::default(); /// if let Ok(mut ssl_context_builder) = client_config.init_tls_client_context() { -/// let ssl_context = ssl_context_builder.build(); -/// let ssl = Ssl::new(&ssl_context).unwrap(); +/// let ssl = ssl_context_builder.build().configure().unwrap().into_ssl("localhost").unwrap(); /// let mut stream = SslStream::new(ssl, stream).unwrap(); /// if let Err(e) = Pin::new(&mut stream).connect().await { /// if e.code() != ErrorCode::ZERO_RETURN { @@ -193,13 +197,13 @@ impl fmt::Display for Store { /// let listener = TcpListener::bind("0.0.0.0:4443").await?; /// /// let server_config = SslConfig::new_cert_key("cert.pem".into(), "cert.key".into(), Some("passphrase".into())); -/// if let Ok(mut ssl_context_builder) = server_config.init_tls_server_context() { +/// if let Ok(mut ssl_context_builder) = server_config.init_tls_server_context(None) { /// ssl_context_builder.set_verify(SslVerifyMode::NONE); /// let ssl_context = ssl_context_builder.build(); /// /// loop { /// let (stream, cli_addr) = listener.accept().await?; -/// let ssl = Ssl::new(&ssl_context).unwrap(); +/// let ssl = Ssl::new(&ssl_context.context()).unwrap(); /// let mut stream = SslStream::new(ssl, stream).unwrap(); /// if let Err(e) = Pin::new(&mut stream).accept().await { /// if e.code() != ErrorCode::ZERO_RETURN { @@ -214,7 +218,7 @@ impl fmt::Display for Store { /// Ok(()) /// } /// ``` -#[derive(Default, Debug, Clone, PartialEq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct SslConfig { /// SSL store certificate to verify the remote certificate store: Option, @@ -226,9 +230,28 @@ pub struct SslConfig { key: Option, /// passphrase for private key or pkcs12 passphrase: Option, + #[serde(default)] + /// ALPN list send by the client, or order of ALPN accepted by the server + alpn: Vec, + #[serde(skip_serializing)] + #[serde(default = "SslConfig::default_modern_security")] + /// Security level. If `true`, it'll use the [modern version 5 of Mozilla's](https://wiki.mozilla.org/Security/Server_Side_TLS) TLS recommendations. + modern_security: bool, + #[serde(skip_serializing)] + #[serde(default = "SslConfig::default_ssl_timeout")] + /// SSL operation timeout + ssl_timeout: u64, } impl SslConfig { + fn default_modern_security() -> bool { + true + } + + fn default_ssl_timeout() -> u64 { + 3000 + } + /// Method to create an ssl configuration from a pkcs12 manually /// Should be use with config instead of building it manually pub fn new_pkcs12(pkcs12_path: String) -> SslConfig { @@ -238,6 +261,9 @@ impl SslConfig { cert: None, key: None, passphrase: None, + alpn: Vec::default(), + modern_security: Self::default_modern_security(), + ssl_timeout: Self::default_ssl_timeout(), } } @@ -254,26 +280,37 @@ impl SslConfig { cert: Some(cert_path), key: Some(key_path), passphrase, + alpn: Vec::default(), + modern_security: Self::default_modern_security(), + ssl_timeout: Self::default_ssl_timeout(), } } + /// Getter of the SSL timeout + pub fn get_ssl_timeout(&self) -> Duration { + Duration::from_millis(self.ssl_timeout) + } + /// Setter of the store certificate pub fn set_store(&mut self, store: Store) { self.store = Some(store); } + /// Setter of the ALPN list send by the client, or order of ALPN accepted by the server + pub fn set_alpn(&mut self, alpn: Vec) { + self.alpn = alpn; + } + /// Method to init an SSL context for a socket - pub(crate) fn init_tls_context( + pub(crate) fn init_tls_context( &self, + mut context_builder: B, is_server: bool, - ) -> Result { - let mut ssl_context_builder = openssl::ssl::SslContext::builder(if is_server { - SslMethod::tls_server() - } else { - SslMethod::tls_client() - })?; - ssl_context_builder.set_min_proto_version(Some(openssl::ssl::SslVersion::TLS1_2))?; - + domain: Option<&str>, + ) -> Result + where + B: DerefMut, + { if let Some(pkcs12_path) = &self.pkcs12 { match fs::read(pkcs12_path) { Ok(pkcs12_file) => { @@ -281,23 +318,23 @@ impl SslConfig { .parse2(self.passphrase.as_ref().unwrap_or(&String::from("")))?; if let Some(pkey) = pkcs12.pkey { - ssl_context_builder.set_private_key(&pkey)?; + context_builder.set_private_key(&pkey)?; } if let Some(cert) = pkcs12.cert { - ssl_context_builder.set_certificate(&cert)?; + context_builder.set_certificate(&cert)?; } if let Some(ca) = pkcs12.ca { for cert in ca { - ssl_context_builder.add_extra_chain_cert(cert)?; + context_builder.add_extra_chain_cert(cert)?; } } } Err(io) => return Err(ConfigError::IoFile(pkcs12_path.to_string(), io)), } } else if let (Some(cert_path), Some(key_path)) = (&self.cert, &self.key) { - ssl_context_builder.set_certificate_file(cert_path, SslFiletype::PEM)?; + context_builder.set_certificate_file(cert_path, SslFiletype::PEM)?; match fs::read(key_path) { Ok(key_file) => { @@ -312,7 +349,7 @@ impl SslConfig { PKey::private_key_from_pem(key_file.as_slice())? }; - ssl_context_builder.set_private_key(&pkey)?; + context_builder.set_private_key(&pkey)?; } Err(io) => return Err(ConfigError::IoFile(key_path.to_string(), io)), } @@ -320,7 +357,7 @@ impl SslConfig { let mut group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; group.set_asn1_flag(Asn1Flag::NAMED_CURVE); let pkey = PKey::from_ec_key(EcKey::generate(&group)?)?; - ssl_context_builder.set_private_key(&pkey)?; + context_builder.set_private_key(&pkey)?; let mut cert = X509::builder()?; cert.set_version(2)?; @@ -346,23 +383,69 @@ impl SslConfig { cert.set_subject_name(&x509_name)?; cert.set_issuer_name(&x509_name)?; + // Add DNS subject alternative name if needed to check the certificate + if let Some(dns) = domain { + let mut subject_alternative_name = SubjectAlternativeName::new(); + let x509_extension = subject_alternative_name + .dns(dns) + .build(&cert.x509v3_context(None, None))?; + cert.append_extension2(&x509_extension)?; + } + cert.sign(&pkey, MessageDigest::sha256())?; - ssl_context_builder.set_certificate(&cert.build())?; + context_builder.set_certificate(&cert.build())?; } if let Some(store) = &self.store { - ssl_context_builder.set_cert_store(store.get_store()?); + context_builder.set_cert_store(store.get_store()?); if is_server { - ssl_context_builder.set_verify(SslVerifyMode::PEER); + context_builder.set_verify(SslVerifyMode::PEER); } } else if !is_server { - ssl_context_builder.set_cert_store(Store::default().get_store()?); + context_builder.set_cert_store(Store::default().get_store()?); } else { - ssl_context_builder.set_verify(SslVerifyMode::NONE); + context_builder.set_verify(SslVerifyMode::NONE); + } + + if !self.alpn.is_empty() { + if is_server { + let alpn_list = self.alpn.clone(); + context_builder.set_alpn_select_callback(move |_ssl, alpn| { + let mut cli_alpn = HashMap::new(); + + let mut current_split = alpn; + while let Some(length) = current_split.first() { + if current_split.len() > *length as usize { + let (left, right) = current_split.split_at(*length as usize + 1); + cli_alpn + .insert(String::from_utf8(left[1..].to_vec()).unwrap(), &left[1..]); + current_split = right; + } else { + return Err(AlpnError::ALERT_FATAL); + } + } + + for alpn_name in &alpn_list { + if let Some(alpn) = cli_alpn.get(alpn_name) { + return Ok(alpn); + } + } + + Err(AlpnError::NOACK) + }); + } else { + let mut alpn_bytes = BytesMut::new(); + for alpn in &self.alpn { + alpn_bytes.put_u8(alpn.len() as u8); + alpn_bytes.put(alpn.as_bytes()); + } + + context_builder.set_alpn_protos(&alpn_bytes)?; + } } - Ok(ssl_context_builder) + Ok(context_builder) } /// Method to init an SSL context for a client socket @@ -376,8 +459,14 @@ impl SslConfig { /// let ssl_context = ssl_context_builder.build(); /// } /// ``` - pub fn init_tls_client_context(&self) -> Result { - self.init_tls_context(false) + pub fn init_tls_client_context( + &self, + ) -> Result { + self.init_tls_context( + openssl::ssl::SslConnector::builder(SslMethod::tls_client())?, + false, + None, + ) } /// Method to init an SSL context for a server socket @@ -386,12 +475,35 @@ impl SslConfig { /// use prosa_utils::config::ssl::SslConfig; /// /// let server_config = SslConfig::new_pkcs12("server.pkcs12".into()); - /// if let Ok(mut ssl_context_builder) = server_config.init_tls_server_context() { + /// if let Ok(mut ssl_context_builder) = server_config.init_tls_server_context(None) { /// let ssl_context = ssl_context_builder.build(); /// } /// ``` - pub fn init_tls_server_context(&self) -> Result { - self.init_tls_context(true) + pub fn init_tls_server_context( + &self, + domain: Option<&str>, + ) -> Result { + let ssl_acceptor = if self.modern_security { + openssl::ssl::SslAcceptor::mozilla_modern_v5(SslMethod::tls_server()) + } else { + openssl::ssl::SslAcceptor::mozilla_intermediate_v5(SslMethod::tls_server()) + }?; + self.init_tls_context(ssl_acceptor, true, domain) + } +} + +impl Default for SslConfig { + fn default() -> SslConfig { + SslConfig { + store: None, + pkcs12: None, + cert: None, + key: None, + passphrase: None, + alpn: Vec::default(), + modern_security: Self::default_modern_security(), + ssl_timeout: Self::default_ssl_timeout(), + } } } @@ -402,10 +514,10 @@ mod tests { #[test] fn test_tls_server_context() { let ssl_config = SslConfig::default(); - let ssl_context = ssl_config.init_tls_server_context().unwrap().build(); + let ssl_acceptor = ssl_config.init_tls_server_context(None).unwrap().build(); // Check for self signed certificate - assert!(ssl_context.private_key().is_some()); - assert!(ssl_context.certificate().is_some()); + assert!(ssl_acceptor.context().private_key().is_some()); + assert!(ssl_acceptor.context().certificate().is_some()); } }