From 7d120e13709821832adb9c2e29c426e38f17ce2c Mon Sep 17 00:00:00 2001 From: Romain Beauxis Date: Wed, 11 Dec 2024 07:43:50 +0100 Subject: [PATCH] SRT cleanup. --- src/core/io/srt_io.ml | 177 +++++++++++++++++++++------------------- src/core/tools/utils.ml | 2 + 2 files changed, 96 insertions(+), 83 deletions(-) diff --git a/src/core/io/srt_io.ml b/src/core/io/srt_io.ml index 54bc148a30..e3ce1b53c7 100644 --- a/src/core/io/srt_io.ml +++ b/src/core/io/srt_io.ml @@ -368,6 +368,7 @@ module Poll = struct Duppy.Async.wake_up task let remove_socket socket = + Hashtbl.remove t.handlers socket; if List.mem socket (Srt.Poll.sockets t.p) then Srt.Poll.remove_usock t.p socket end @@ -495,28 +496,31 @@ class virtual caller ~enforced_encryption ~pbkeylen ~passphrase ~streamid try Srt.setsockflag s Srt.sndsyn true; Srt.setsockflag s Srt.rcvsyn true; - ignore - (Option.map (fun id -> Srt.(setsockflag s streamid id)) streamid); - ignore - (Option.map - (fun b -> Srt.(setsockflag s enforced_encryption b)) - enforced_encryption); - ignore - (Option.map (fun len -> Srt.(setsockflag s pbkeylen len)) pbkeylen); - ignore - (Option.map (fun p -> Srt.(setsockflag s passphrase p)) passphrase); - ignore - (Option.map - (fun v -> Srt.(setsockflag s conntimeo v)) - connection_timeout); - ignore - (Option.map (fun v -> Srt.(setsockflag s sndtimeo v)) write_timeout); - ignore - (Option.map (fun v -> Srt.(setsockflag s rcvtimeo v)) read_timeout); + Utils.optional_apply + (fun id -> Srt.(setsockflag s streamid id)) + streamid; + Utils.optional_apply + (fun b -> Srt.(setsockflag s enforced_encryption b)) + enforced_encryption; + Utils.optional_apply + (fun len -> Srt.(setsockflag s pbkeylen len)) + pbkeylen; + Utils.optional_apply + (fun p -> Srt.(setsockflag s passphrase p)) + passphrase; + Utils.optional_apply + (fun v -> Srt.(setsockflag s conntimeo v)) + connection_timeout; + Utils.optional_apply + (fun v -> Srt.(setsockflag s sndtimeo v)) + write_timeout; + Utils.optional_apply + (fun v -> Srt.(setsockflag s rcvtimeo v)) + read_timeout; Srt.connect s sockaddr; - Atomic.set socket (Some (sockaddr, s)); self#log#important "Client connected!"; !on_connect (); + Atomic.set socket (Some (sockaddr, s)); -1. with exn -> let bt = Printexc.get_raw_backtrace () in @@ -572,76 +576,83 @@ class virtual listener ~enforced_encryption ~pbkeylen ~passphrase ~max_clients method private listening_socket = match Atomic.get listening_socket with | Some s -> s - | None -> + | None -> ( let s = mk_socket ~payload_size ~messageapi () in - Srt.bind s bind_address; - let max_clients_callback = - Option.map - (fun n _ _ _ _ -> - self#mutexify (fun () -> List.length client_sockets < n) ()) - max_clients - in - let listen_callback = - List.fold_left - (fun cur v -> - match (cur, v) with - | None, _ -> v - | Some _, None -> cur - | Some cur, Some fn -> - Some - (fun s hs_version peeraddr streamid -> - cur s hs_version peeraddr streamid - && fn s hs_version peeraddr streamid)) - None - [max_clients_callback; listen_callback] - in - ignore - (Option.map (fun fn -> Srt.listen_callback s fn) listen_callback); - ignore - (Option.map - (fun b -> Srt.(setsockflag s enforced_encryption b)) - enforced_encryption); - ignore - (Option.map - (fun len -> Srt.(setsockflag s pbkeylen len)) - pbkeylen); - ignore - (Option.map - (fun p -> Srt.(setsockflag s passphrase p)) - passphrase); - Srt.listen s (Option.value ~default:1 max_clients); - self#log#info "Setting up socket to listen at %s" - (string_of_address bind_address); - Atomic.set listening_socket (Some s); - s + try + Srt.bind s bind_address; + let max_clients_callback = + Option.map + (fun n _ _ _ _ -> + self#mutexify (fun () -> List.length client_sockets < n) ()) + max_clients + in + let listen_callback = + List.fold_left + (fun cur v -> + match (cur, v) with + | None, _ -> v + | Some _, None -> cur + | Some cur, Some fn -> + Some + (fun s hs_version peeraddr streamid -> + cur s hs_version peeraddr streamid + && fn s hs_version peeraddr streamid)) + None + [max_clients_callback; listen_callback] + in + Utils.optional_apply + (fun fn -> Srt.listen_callback s fn) + listen_callback; + Utils.optional_apply + (fun b -> Srt.(setsockflag s enforced_encryption b)) + enforced_encryption; + Utils.optional_apply + (fun len -> Srt.(setsockflag s pbkeylen len)) + pbkeylen; + Utils.optional_apply + (fun p -> Srt.(setsockflag s passphrase p)) + passphrase; + Srt.listen s (Option.value ~default:1 max_clients); + self#log#info "Setting up socket to listen at %s" + (string_of_address bind_address); + Atomic.set listening_socket (Some s); + s + with exn -> + let bt = Printexc.get_raw_backtrace () in + Srt.close s; + Printexc.raise_with_backtrace exn bt) method private connect = let rec accept_connection s = try let client, origin = Srt.accept s in - Poll.add_socket ~mode:`Read s accept_connection; - (try self#log#info "New connection from %s" (string_of_address origin) - with exn -> - self#log#important "Error while fetching connection source: %s" - (Printexc.to_string exn)); - Srt.(setsockflag client sndsyn true); - Srt.(setsockflag client rcvsyn true); - ignore - (Option.map - (fun v -> Srt.(setsockflag client sndtimeo v)) - write_timeout); - ignore - (Option.map - (fun v -> Srt.(setsockflag client rcvtimeo v)) - read_timeout); - if self#should_stop then ( - close_socket client; - raise Done); - self#mutexify - (fun () -> - client_sockets <- (origin, client) :: client_sockets; - !on_connect ()) - () + try + Poll.add_socket ~mode:`Read s accept_connection; + (try + self#log#info "New connection from %s" (string_of_address origin) + with exn -> + self#log#important "Error while fetching connection source: %s" + (Printexc.to_string exn)); + Srt.(setsockflag client sndsyn true); + Srt.(setsockflag client rcvsyn true); + Utils.optional_apply + (fun v -> Srt.(setsockflag client sndtimeo v)) + write_timeout; + Utils.optional_apply + (fun v -> Srt.(setsockflag client rcvtimeo v)) + read_timeout; + if self#should_stop then ( + close_socket client; + raise Done); + self#mutexify + (fun () -> + client_sockets <- (origin, client) :: client_sockets; + !on_connect ()) + () + with exn -> + let bt = Printexc.get_raw_backtrace () in + Srt.close client; + Printexc.raise_with_backtrace exn bt with exn -> self#log#debug "Failed to connect: %s" (Printexc.to_string exn) in diff --git a/src/core/tools/utils.ml b/src/core/tools/utils.ml index b341f3b0ca..f78f0450b1 100644 --- a/src/core/tools/utils.ml +++ b/src/core/tools/utils.ml @@ -521,3 +521,5 @@ let is_docker = Lazy.from_fun (fun () -> Sys.unix && Sys.command "grep 'docker\\|lxc' /proc/1/cgroup >/dev/null 2>&1" = 0) + +let optional_apply fn = function None -> () | Some v -> fn v