diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 0000000..ca6c72f --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,2 @@ +[target.x86_64-unknown-linux-gnu] +runner = 'sudo -E' diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..dc995b7 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,18 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/rust +{ + "name": "Rust", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/rust:1-1-bullseye", + "features": { + "ghcr.io/devcontainers/features/rust:1": { + "version": "latest", + "profile": "minimal" + } + }, + "runArgs": [ + // TODO: figure out the exact cap-add ? + "--privileged" + ], + "remoteUser": "root" +} \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..f33a02c --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for more information: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +# https://containers.dev/guide/dependabot + +version: 2 +updates: + - package-ecosystem: "devcontainers" + directory: "/" + schedule: + interval: weekly diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 000bb2c..af5881b 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,8 +11,11 @@ env: jobs: build: + strategy: + matrix: + os: [ubuntu-latest, macos-latest] - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 diff --git a/benches/bench_find_process_by_socket.rs b/benches/bench_find_process_by_socket.rs index a11ebc1..85822ac 100644 --- a/benches/bench_find_process_by_socket.rs +++ b/benches/bench_find_process_by_socket.rs @@ -1,12 +1,12 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use sock2proc::{FindProc, FindProcImpl}; +use sock2proc::find_process_name; fn run_find_process_by_socket() { let dst = std::net::SocketAddr::new( std::net::IpAddr::V4(std::net::Ipv4Addr::new(8, 8, 8, 8)), 80, ); - let _process_name = FindProcImpl::resolve(None, Some(dst), libc::IPPROTO_TCP); + let _process_name = find_process_name(None, Some(dst), sock2proc::NetworkProtocol::TCP); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/lib.rs b/src/lib.rs index f6755d2..cbd7ccc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,10 @@ mod platform; mod utils; -pub use libc::{IPPROTO_TCP, IPPROTO_UDP}; -pub use platform::{FindProc, FindProcImpl}; +#[derive(PartialEq, Clone, Copy, Debug)] +#[repr(u8)] +pub enum NetworkProtocol { + TCP = 6, + UDP = 17, +} +pub use platform::find_process_name; diff --git a/src/platform/linux.rs b/src/platform/linux.rs index 138fc9e..7ded159 100644 --- a/src/platform/linux.rs +++ b/src/platform/linux.rs @@ -16,29 +16,14 @@ use netlink_packet_sock_diag::{ }; use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr}; -use super::FindProc; +use crate::{utils::pre_condition, NetworkProtocol}; -pub struct FindProcImpl; - -impl FindProc for FindProcImpl { - fn resolve( - src: Option, - dst: Option, - proto: i32, - ) -> Option { - resolve(src, dst, proto) - } -} - -fn resolve( +pub fn find_process_name( src: Option, dst: Option, - proto: i32, + proto: NetworkProtocol, ) -> Option { - if !crate::utils::check(src, dst) { - return None; - } - if proto != libc::IPPROTO_TCP || proto != libc::IPPROTO_UDP { + if !pre_condition(src, dst) { return None; } @@ -49,7 +34,7 @@ fn resolve( fn resolve_uid_inode( src: Option, dst: Option, - proto: i32, + proto: NetworkProtocol, ) -> Option<(u32, u32)> { let mut socket = Socket::new(NETLINK_SOCK_DIAG).unwrap(); let _port_number = socket.bind_auto().unwrap().port_number(); @@ -90,11 +75,12 @@ fn resolve_uid_inode( // Before calling serialize, it is important to check that the buffer in // which we're emitting is big enough for the packet, other // `serialize()` panics. - assert_eq!(buf.len(), packet.buffer_len()); + assert_eq!(buf.len(), packet.buffer_len(), "Buffer is too small"); packet.serialize(&mut buf[..]); - if let Err(_) = socket.send(&buf[..], 0) { + if let Err(e) = socket.send(&buf[..], 0) { + eprintln!("Failed to send packet: {:?}", e); return None; } diff --git a/src/platform/macos.rs b/src/platform/macos.rs index 3898aa3..adcb288 100644 --- a/src/platform/macos.rs +++ b/src/platform/macos.rs @@ -1,8 +1,7 @@ -use libc::{IPPROTO_TCP, IPPROTO_UDP}; use sysctl::Sysctl; -use crate::utils::{check, is_ipv6}; -use crate::FindProc; +use crate::utils::{is_ipv6, pre_condition}; +use crate::NetworkProtocol; use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -15,28 +14,24 @@ const PROCPIDPATHINFOSIZE: usize = 1024; const PROCCALLNUMPIDINFO: i32 = 0x2; static STRUCT_SIZE: AtomicUsize = AtomicUsize::new(0); -const STRUCT_SIZE_SETTER: Once = Once::new(); - -pub struct FindProcImpl; - -impl FindProc for FindProcImpl { - fn resolve( - src: Option, - dst: Option, - proto: i32, - ) -> Option { - if !check(src, dst) { - return None; - } - find_process_name(src, dst, proto).ok() - } +static STRUCT_SIZE_SETTER: Once = Once::new(); + +pub fn find_process_name( + src: Option, + dst: Option, + proto: NetworkProtocol, +) -> Option { + find_process_name_inner(src, dst, proto).ok() } -fn find_process_name( +fn find_process_name_inner( src: Option, dst: Option, - proto: i32, + proto: NetworkProtocol, ) -> Result { + if !pre_condition(src, dst) { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "Invalid input")); + } STRUCT_SIZE_SETTER.call_once(|| { let default = "".to_string(); let ctl = sysctl::Ctl::new("kern.osrelease").unwrap(); @@ -53,14 +48,8 @@ fn find_process_name( // see: https://github.com/apple-oss-distributions/xnu/blob/94d3b452840153a99b38a3a9659680b2a006908e/bsd/netinet/in_pcblist.c#L292 let spath = match proto { - IPPROTO_TCP => "net.inet.tcp.pcblist_n", - IPPROTO_UDP => "net.inet.udp.pcblist_n", - _ => { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Invalid network", - )) - } + NetworkProtocol::TCP => "net.inet.tcp.pcblist_n", + NetworkProtocol::UDP => "net.inet.udp.pcblist_n", }; let is_ipv4 = !is_ipv6(src, dst); @@ -69,7 +58,12 @@ fn find_process_name( let value = ctl.value().unwrap(); let buf = value.as_struct().unwrap(); let struct_size = STRUCT_SIZE.load(std::sync::atomic::Ordering::Relaxed); - let item_size = struct_size + if proto == IPPROTO_TCP { 208 } else { 0 }; + let item_size = struct_size + + if proto == NetworkProtocol::TCP { + 208 + } else { + 0 + }; // see https://github.com/apple-oss-distributions/xnu/blob/94d3b452840153a99b38a3a9659680b2a006908e/bsd/netinet/in_pcb.h#L451 // offset of flag is 44 @@ -144,7 +138,7 @@ fn find_process_name( fn get_pid(bytes: &[u8]) -> u32 { assert_eq!(bytes.len(), 4); let mut pid_bytes = [0; 4]; - pid_bytes.copy_from_slice(&bytes); + pid_bytes.copy_from_slice(bytes); if cfg!(target_endian = "big") { u32::from_be_bytes(pid_bytes) } else { diff --git a/src/platform/mod.rs b/src/platform/mod.rs index 95be2c5..e2b8404 100644 --- a/src/platform/mod.rs +++ b/src/platform/mod.rs @@ -1,32 +1,39 @@ -pub trait FindProc { - fn resolve( - src: Option, - dst: Option, - proto: i32, - ) -> Option; -} - #[cfg(target_os = "linux")] mod linux; #[cfg(target_os = "linux")] -pub use linux::FindProcImpl; +pub use linux::find_process_name; #[cfg(target_os = "macos")] mod macos; #[cfg(target_os = "macos")] -pub use macos::FindProcImpl; +pub use macos::find_process_name; #[cfg(test)] mod tests { - use super::*; + use std::net::TcpListener; #[test] - fn test_compile() { - let dst = std::net::SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(8, 8, 8, 8)), - 80, - ); - let _process_name = FindProcImpl::resolve(None, Some(dst), libc::IPPROTO_TCP); + fn test_get_find_tcp_socket() { + let socket = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = socket.local_addr().unwrap(); + let path = super::find_process_name(Some(addr), None, crate::NetworkProtocol::TCP); + + assert!(path.is_some()); + + let current_exe = std::env::current_exe().unwrap(); + assert_eq!(path.unwrap(), current_exe.to_str().unwrap()); + } + + #[test] + fn test_get_find_udp_socket() { + let socket = std::net::UdpSocket::bind("127.0.0.1:0").unwrap(); + let addr = socket.local_addr().unwrap(); + let path = super::find_process_name(Some(addr), None, crate::NetworkProtocol::UDP); + + assert!(path.is_some()); + + let current_exe = std::env::current_exe().unwrap(); + assert_eq!(path.unwrap(), current_exe.to_str().unwrap()); } } diff --git a/src/utils.rs b/src/utils.rs index ef9995f..42217ba 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,19 +1,26 @@ -pub(crate) fn check(src: Option, dst: Option) -> bool { - if src.is_none() && dst.is_none() { - false - } else if src.is_some() && dst.is_some() { - let inner1 = src.unwrap(); - let inner2 = dst.unwrap(); - (inner1.is_ipv4() && inner2.is_ipv6()) || (inner2.is_ipv4() && inner1.is_ipv6()) - } else { - true +pub(crate) fn pre_condition( + src: Option, + dst: Option, +) -> bool { + match (src, dst) { + (None, None) => false, + (Some(_), None) => true, + (None, Some(_)) => true, + (Some(left), Some(right)) => { + // it was (inner1.is_ipv4() && inner2.is_ipv6()) || (inner2.is_ipv4() && inner1.is_ipv6()) + (left.is_ipv4() && right.is_ipv4()) || (left.is_ipv6() && right.is_ipv6()) + } } } -pub(crate) fn is_ipv6(src: Option, dst: Option) -> bool { - if src.is_some() { - src.unwrap().is_ipv6() - } else { - dst.unwrap().is_ipv6() +pub(crate) fn is_ipv6( + src: Option, + dst: Option, +) -> bool { + match (src, dst) { + (Some(addr), None) => addr.is_ipv6(), + (None, Some(addr)) => addr.is_ipv6(), + (Some(left), Some(right)) => left.is_ipv6() || right.is_ipv6(), + _ => false, } -} \ No newline at end of file +}