Skip to content

Commit

Permalink
Merge pull request #1441 from cberkhoff/rscli
Browse files Browse the repository at this point in the history
Reverting #1431
  • Loading branch information
cberkhoff authored Nov 16, 2024
2 parents 5fdabfd + 5904d98 commit 4e0608e
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 670 deletions.
156 changes: 48 additions & 108 deletions ipa-core/src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,17 @@ use std::{
};

use clap::{self, Parser, Subcommand};
use futures::future::join;
use hyper::http::uri::Scheme;
use ipa_core::{
cli::{
client_config_setup, keygen, test_setup, ConfGenArgs, KeygenArgs, LoggingHandle,
TestSetupArgs, Verbosity,
},
config::{
hpke_registry, sharded_server_from_toml_str, HpkeServerConfig, ServerConfig, TlsConfig,
},
config::{hpke_registry, HpkeServerConfig, NetworkConfig, ServerConfig, TlsConfig},
error::BoxError,
executor::IpaRuntime,
helpers::HelperIdentity,
net::{
ClientIdentity, ConnectionFlavor, IpaHttpClient, MpcHttpTransport, Shard,
ShardHttpTransport,
},
net::{ClientIdentity, IpaHttpClient, MpcHttpTransport, ShardHttpTransport},
sharding::ShardIndex,
AppConfig, AppSetup, NonZeroU32PowerOfTwo,
};
Expand Down Expand Up @@ -61,32 +55,16 @@ struct ServerArgs {
#[arg(short, long, required = true)]
identity: Option<usize>,

#[arg(default_value = "0")]
shard_index: Option<u32>,

#[arg(default_value = "1")]
shard_count: Option<u32>,

/// Port to listen on
#[arg(short, long, default_value = "3000")]
port: Option<u16>,

/// Port to use for shard-to-shard communication, if sharded MPC is used
#[arg(default_value = "6000")]
shard_port: Option<u16>,

/// Use the supplied prebound socket instead of binding a new socket for mpc
/// Use the supplied prebound socket instead of binding a new socket
///
/// This is only intended for avoiding port conflicts in tests.
#[arg(hide = true, long)]
server_socket_fd: Option<RawFd>,

/// Use the supplied prebound socket instead of binding a new socket for shard server
///
/// This is only intended for avoiding port conflicts in tests.
#[arg(hide = true, long)]
shard_server_socket_fd: Option<RawFd>,

/// Use insecure HTTP
#[arg(short = 'k', long)]
disable_https: bool,
Expand All @@ -95,7 +73,7 @@ struct ServerArgs {
#[arg(long, required = true)]
network: Option<PathBuf>,

/// TLS certificate for helper-to-helper and shard-to-shard communication
/// TLS certificate for helper-to-helper communication
#[arg(
long,
visible_alias("cert"),
Expand All @@ -104,7 +82,7 @@ struct ServerArgs {
)]
tls_cert: Option<PathBuf>,

/// TLS key for helper-to-helper and shard-to-shard communication
/// TLS key for helper-to-helper communication
#[arg(long, visible_alias("key"), requires = "tls_cert")]
tls_key: Option<PathBuf>,

Expand Down Expand Up @@ -136,58 +114,24 @@ fn read_file(path: &Path) -> Result<BufReader<fs::File>, BoxError> {
.map_err(|e| format!("failed to open file {}: {e:?}", path.display()))?)
}

/// Helper function that creates the client identity; either with certificates if they are provided
/// or just with headers otherwise. This works both for sharded and helper configs.
fn create_client_identity<F: ConnectionFlavor>(
id: F::Identity,
tls_cert: Option<PathBuf>,
tls_key: Option<PathBuf>,
) -> Result<(ClientIdentity<F>, Option<TlsConfig>), BoxError> {
match (tls_cert, tls_key) {
async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> {
let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap();

let (identity, server_tls) = match (args.tls_cert, args.tls_key) {
(Some(cert_file), Some(key_file)) => {
let mut key = read_file(&key_file)?;
let mut certs = read_file(&cert_file)?;
Ok((
ClientIdentity::<F>::from_pkcs8(&mut certs, &mut key)?,
(
ClientIdentity::from_pkcs8(&mut certs, &mut key)?,
Some(TlsConfig::File {
certificate_file: cert_file,
private_key_file: key_file,
}),
))
)
}
(None, None) => Ok((ClientIdentity::Header(id), None)),
_ => Err("should have been rejected by clap".into()),
}
}

/// Creates a [`TcpListener`] from an optional raw file descriptor. Safety notes:
/// 1. The `--server-socket-fd` option is only intended for use in tests, not in production.
/// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has
/// only one owner.
fn create_listener(server_socket_fd: Option<RawFd>) -> Result<Option<TcpListener>, BoxError> {
server_socket_fd
.map(|fd| {
let listener = unsafe { TcpListener::from_raw_fd(fd) };
if listener.local_addr().is_ok() {
info!("adopting fd {fd} as listening socket");
Ok(listener)
} else {
Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket")))
}
})
.transpose()
}

async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), BoxError> {
let my_identity = HelperIdentity::try_from(args.identity.expect("enforced by clap")).unwrap();
let shard_index = ShardIndex::from(args.shard_index.expect("enforced by clap"));
let shard_count = ShardIndex::from(args.shard_count.expect("enforced by clap"));
assert!(shard_index < shard_count);

let (identity, server_tls) =
create_client_identity(my_identity, args.tls_cert.clone(), args.tls_key.clone())?;
let (shard_identity, shard_server_tls) =
create_client_identity(shard_index, args.tls_cert, args.tls_key)?;
(None, None) => (ClientIdentity::Header(my_identity), None),
_ => panic!("should have been rejected by clap"),
};

let mk_encryption = args.mk_private_key.map(|sk_path| HpkeServerConfig::File {
private_key_file: sk_path,
Expand All @@ -205,13 +149,6 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B
port: args.port,
disable_https: args.disable_https,
tls: server_tls,
hpke_config: mk_encryption.clone(),
};

let shard_server_config = ServerConfig {
port: args.shard_port,
disable_https: args.disable_https,
tls: shard_server_tls,
hpke_config: mk_encryption,
};

Expand All @@ -220,48 +157,60 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B
} else {
Scheme::HTTPS
};

let network_config_path = args.network.as_deref().unwrap();
let network_config_string = &fs::read_to_string(network_config_path)?;
let (mut mpc_network, mut shard_network) =
sharded_server_from_toml_str(network_config_string, my_identity, shard_index, shard_count)?;
mpc_network = mpc_network.override_scheme(&scheme);
shard_network = shard_network.override_scheme(&scheme);
let network_config = NetworkConfig::from_toml_str(&fs::read_to_string(network_config_path)?)?
.override_scheme(&scheme);

// TODO: Following is just temporary until Shard Transport is actually used.
let shard_clients_config = network_config.client.clone();
let shard_server_config = server_config.clone();
// ---

let http_runtime = new_http_runtime(&logging_handle);
let clients = IpaHttpClient::from_conf(
&IpaRuntime::from_tokio_runtime(&http_runtime),
&mpc_network,
&network_config,
&identity,
);
let (transport, server) = MpcHttpTransport::new(
IpaRuntime::from_tokio_runtime(&http_runtime),
my_identity,
server_config,
mpc_network,
network_config,
&clients,
Some(handler),
);

let shard_clients = IpaHttpClient::<Shard>::shards_from_conf(
&IpaRuntime::from_tokio_runtime(&http_runtime),
&shard_network,
&shard_identity,
);
let (shard_transport, shard_server) = ShardHttpTransport::new(
// TODO: Following is just temporary until Shard Transport is actually used.
let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config);
let (shard_transport, _shard_server) = ShardHttpTransport::new(
IpaRuntime::from_tokio_runtime(&http_runtime),
shard_index,
shard_count,
ShardIndex::FIRST,
ShardIndex::from(1),
shard_server_config,
shard_network,
shard_clients,
shard_network_config,
vec![],
Some(shard_handler),
);
// ---

let _app = setup.connect(transport.clone(), shard_transport.clone());

let listener = create_listener(args.server_socket_fd)?;
let shard_listener = create_listener(args.shard_server_socket_fd)?;
let listener = args.server_socket_fd
.map(|fd| {
// SAFETY:
// 1. The `--server-socket-fd` option is only intended for use in tests, not in production.
// 2. This must be the only call to from_raw_fd for this file descriptor, to ensure it has
// only one owner.
let listener = unsafe { TcpListener::from_raw_fd(fd) };
if listener.local_addr().is_ok() {
info!("adopting fd {fd} as listening socket");
Ok(listener)
} else {
Err(BoxError::from(format!("the server was asked to listen on fd {fd}, but it does not appear to be a valid socket")))
}
})
.transpose()?;

let (_addr, server_handle) = server
.start_on(
Expand All @@ -271,17 +220,8 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B
None as Option<()>,
)
.await;
let (_saddr, shard_server_handle) = shard_server
.start_on(
&IpaRuntime::from_tokio_runtime(&http_runtime),
shard_listener,
// TODO, trace based on the content of the query.
None as Option<()>,
)
.await;

join(server_handle, shard_server_handle).await;

server_handle.await;
[query_runtime, http_runtime].map(Runtime::shutdown_background);

Ok(())
Expand Down
17 changes: 2 additions & 15 deletions ipa-core/src/cli/clientconf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ pub struct ConfGenArgs {
#[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])]
ports: Vec<u16>,

#[arg(short, long, num_args = 3, value_name = "SHARD_PORTS", default_values = vec!["6000", "6001", "6002"])]
shard_ports: Vec<u16>,

#[arg(long, num_args = 3, default_values = vec!["localhost", "localhost", "localhost"])]
hosts: Vec<String>,

Expand Down Expand Up @@ -57,14 +54,13 @@ pub struct ConfGenArgs {
/// [`ConfGenArgs`]: ConfGenArgs
/// [`Paths`]: crate::cli::paths::PathExt
pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> {
let clients_conf: [_; 3] = zip(args.hosts.iter(), zip(args.ports, args.shard_ports))
let clients_conf: [_; 3] = zip(args.hosts.iter(), args.ports)
.enumerate()
.map(|(id, (host, (port, shard_port)))| {
.map(|(id, (host, port))| {
let id: u8 = u8::try_from(id).unwrap() + 1;
HelperClientConf {
host,
port,
shard_port,
tls_cert_file: args.keys_dir.helper_tls_cert(id),
mk_public_key_file: args.keys_dir.helper_mk_public_key(id),
}
Expand Down Expand Up @@ -100,7 +96,6 @@ pub fn setup(args: ConfGenArgs) -> Result<(), BoxError> {
pub struct HelperClientConf<'a> {
pub(crate) host: &'a str,
pub(crate) port: u16,
pub(crate) shard_port: u16,
pub(crate) tls_cert_file: PathBuf,
pub(crate) mk_public_key_file: PathBuf,
}
Expand Down Expand Up @@ -138,14 +133,6 @@ pub fn gen_client_config<'a>(
port = client_conf.port
)),
);
peer.insert(
String::from("shard_url"),
Value::String(format!(
"{host}:{port}",
host = client_conf.host,
port = client_conf.shard_port
)),
);
peer.insert(String::from("certificate"), Value::String(certificate));
peer.insert(
String::from("hpke"),
Expand Down
8 changes: 2 additions & 6 deletions ipa-core/src/cli/test_setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@ pub struct TestSetupArgs {

#[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])]
ports: Vec<u16>,

#[arg(short, long, num_args = 3, value_name = "SHARD_PORT", default_values = vec!["6000", "6001", "6002"])]
shard_ports: Vec<u16>,
}

/// Prepare a test network of three helpers.
Expand All @@ -59,8 +56,8 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> {

let localhost = String::from("localhost");

let clients_config: [_; 3] = zip([1, 2, 3], zip(args.ports, args.shard_ports))
.map(|(id, (port, shard_port))| {
let clients_config: [_; 3] = zip([1, 2, 3], args.ports)
.map(|(id, port)| {
let keygen_args = KeygenArgs {
name: localhost.clone(),
tls_cert: args.output_dir.helper_tls_cert(id),
Expand All @@ -75,7 +72,6 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> {
Ok(HelperClientConf {
host: &localhost,
port,
shard_port,
tls_cert_file: keygen_args.tls_cert,
mk_public_key_file: keygen_args.mk_public_key,
})
Expand Down
Loading

0 comments on commit 4e0608e

Please sign in to comment.