Skip to content

Commit

Permalink
Add thread-safe port reservation for tests
Browse files Browse the repository at this point in the history
Implement a port reservation system to prevent concurrent tests from
tryingto use the same ports. This fixes flaky test failures where
integration tests would occasionally fail with AddrInUse errors.
  • Loading branch information
spacebear21 committed Nov 13, 2024
1 parent 974cdca commit 89caf11
Showing 1 changed file with 66 additions and 31 deletions.
97 changes: 66 additions & 31 deletions payjoin/tests/integration.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#[cfg(all(feature = "send", feature = "receive"))]
mod integration {
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::env;
use std::str::FromStr;
use std::sync::Mutex;

use bitcoin::policy::DEFAULT_MIN_RELAY_TX_FEE;
use bitcoin::psbt::{Input as PsbtInput, Psbt};
Expand Down Expand Up @@ -189,6 +190,21 @@ mod integration {

static TESTS_TIMEOUT: Lazy<Duration> = Lazy::new(|| Duration::from_secs(20));
static WAIT_SERVICE_INTERVAL: Lazy<Duration> = Lazy::new(|| Duration::from_secs(3));
/// Global set of TCP ports that are currently reserved by tests.
/// Protected by a mutex for thread-safety in concurrent testing scenarios.
static RESERVED_PORTS: Lazy<Mutex<HashSet<u16>>> = Lazy::new(|| Mutex::new(HashSet::new()));

/// A RAII guard that keeps a port reserved until dropped.
#[derive(Debug)]
struct PortGuard(u16);

impl PortGuard {
fn port(&self) -> u16 { self.0 }
}

impl Drop for PortGuard {
fn drop(&mut self) { RESERVED_PORTS.lock().unwrap().remove(&self.0); }
}

#[tokio::test]
async fn test_bad_ohttp_keys() {
Expand All @@ -197,10 +213,11 @@ mod integration {
.expect("Invalid OhttpKeys");

let (cert, key) = local_cert_key();
let port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", port)).unwrap();
let port_guard = reserve_port();
let directory =
Url::parse(&format!("https://localhost:{}", port_guard.port())).unwrap();
tokio::select!(
err = init_directory(port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = init_directory(port_guard.port(), (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
res = try_request_with_bad_keys(directory, bad_ohttp_keys, cert) => {
assert_eq!(
res.unwrap().headers().get("content-type").unwrap(),
Expand Down Expand Up @@ -231,15 +248,16 @@ mod integration {
async fn test_session_expiration() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_port_guard = reserve_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
Url::parse(&format!("http://localhost:{}", ohttp_port_guard.port())).unwrap();
let directory_port_guard = reserve_port();
let directory =
Url::parse(&format!("https://localhost:{}", directory_port_guard.port())).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = ohttp_relay::listen_tcp(ohttp_port_guard.port(), gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port_guard.port(), (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
res = do_expiration_tests(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -299,15 +317,16 @@ mod integration {
async fn v2_to_v2() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_port_guard = reserve_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
Url::parse(&format!("http://localhost:{}", ohttp_port_guard.port())).unwrap();
let directory_port_guard = reserve_port();
let directory =
Url::parse(&format!("https://localhost:{}", directory_port_guard.port())).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = ohttp_relay::listen_tcp(ohttp_port_guard.port(), gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port_guard.port(), (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -431,15 +450,16 @@ mod integration {
async fn v2_to_v2_mixed_input_script_types() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_port_guard = reserve_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
Url::parse(&format!("http://localhost:{}", ohttp_port_guard.port())).unwrap();
let directory_port_guard = reserve_port();
let directory =
Url::parse(&format!("https://localhost:{}", directory_port_guard.port())).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = ohttp_relay::listen_tcp(ohttp_port_guard.port(), gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port_guard.port(), (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
res = do_v2_send_receive(ohttp_relay, directory, cert) => assert!(res.is_ok(), "v2 send receive failed: {:#?}", res)
);

Expand Down Expand Up @@ -648,15 +668,16 @@ mod integration {
async fn v1_to_v2() {
init_tracing();
let (cert, key) = local_cert_key();
let ohttp_relay_port = find_free_port();
let ohttp_port_guard = reserve_port();
let ohttp_relay =
Url::parse(&format!("http://localhost:{}", ohttp_relay_port)).unwrap();
let directory_port = find_free_port();
let directory = Url::parse(&format!("https://localhost:{}", directory_port)).unwrap();
Url::parse(&format!("http://localhost:{}", ohttp_port_guard.port())).unwrap();
let directory_port_guard = reserve_port();
let directory =
Url::parse(&format!("https://localhost:{}", directory_port_guard.port())).unwrap();
let gateway_origin = http::Uri::from_str(directory.as_str()).unwrap();
tokio::select!(
err = ohttp_relay::listen_tcp(ohttp_relay_port, gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port, (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
err = ohttp_relay::listen_tcp(ohttp_port_guard.port(), gateway_origin) => panic!("Ohttp relay exited early: {:?}", err),
err = init_directory(directory_port_guard.port(), (cert.clone(), key)) => panic!("Directory server exited early: {:?}", err),
res = do_v1_to_v2(ohttp_relay, directory, cert) => assert!(res.is_ok()),
);

Expand Down Expand Up @@ -912,9 +933,23 @@ mod integration {
))
}

fn find_free_port() -> u16 {
let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap();
listener.local_addr().unwrap().port()
fn reserve_port() -> PortGuard {
let mut reserved_ports = RESERVED_PORTS.lock().unwrap();

for _ in 0..100 {
// Try up to 100 times to find a free port
let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap();
let port = listener.local_addr().unwrap().port();

if !reserved_ports.contains(&port) {
reserved_ports.insert(port);
return PortGuard(port);
} else {
println!("port {} is already reserved, trying a different port...", port);
}
}

panic!("Couldn't find a free port");
}

async fn wait_for_service_ready(
Expand Down

0 comments on commit 89caf11

Please sign in to comment.