From 8f3ec1b86ce4b14e765c8d0ccceb73c3ea1da288 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 18 Dec 2024 17:09:06 -0800 Subject: [PATCH] lighthouse, manager: support multiple quorum rooms --- proto/torchft.proto | 16 +- src/lib.rs | 2 + src/lighthouse.rs | 386 +++++++++++++++++++++++------------------- src/manager.rs | 57 ++++--- templates/status.html | 18 +- torchft/manager.py | 11 +- torchft/torchft.pyi | 2 +- 7 files changed, 287 insertions(+), 205 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index e84855c..3bf03b3 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -50,7 +50,11 @@ message Quorum { } message LighthouseQuorumRequest { - QuorumMember requester = 1; + // room_id is the specific quorum channel to use. All workers/replicas + // participating in the quorum must specify the same channel. + // Multiple channels can be active simultaneously. + string room_id = 1; + QuorumMember requester = 2; } message LighthouseQuorumResponse { @@ -69,9 +73,13 @@ service LighthouseService { } message ManagerQuorumRequest { - int64 rank = 1; - int64 step = 2; - string checkpoint_server_addr = 3; + // room_id is the specific quorum channel to use. All workers/replicas + // participating in the quorum must specify the same channel. + // Multiple channels can be active simultaneously. + string room_id = 1; + int64 rank = 2; + int64 step = 3; + string checkpoint_server_addr = 4; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 16ae317..e4d84a0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -105,12 +105,14 @@ impl ManagerClient { fn quorum( &mut self, py: Python<'_>, + room_id: String, rank: i64, step: i64, checkpoint_server_addr: String, ) -> PyResult<(i64, i64, i64, String, String, i64, Option, i64, bool)> { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { + room_id: room_id, rank: rank, step: step, checkpoint_server_addr: checkpoint_server_addr, diff --git a/src/lighthouse.rs b/src/lighthouse.rs index 6f76d41..578d16a 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -43,12 +43,16 @@ struct QuorumMemberDetails { member: QuorumMember, } -struct State { +struct RoomState { + room_id: String, channel: broadcast::Sender, participants: HashMap, prev_quorum: Option, quorum_id: i64, +} +struct State { + rooms: HashMap, // heartbeat information // replica_id -> last heartbeat heartbeats: HashMap, @@ -85,86 +89,76 @@ fn quorum_changed(a: &Vec, b: &Vec) -> bool { return a_ids != b_ids; } -impl Lighthouse { - pub async fn new(opt: LighthouseOpt) -> Result> { - let (tx, _) = broadcast::channel(16); - let listener = tokio::net::TcpListener::bind(&opt.bind).await?; - Ok(Arc::new(Self { - state: Mutex::new(State { - participants: HashMap::new(), - channel: tx, - prev_quorum: None, - quorum_id: 0, - heartbeats: HashMap::new(), - }), - opt: opt, - local_addr: listener.local_addr()?, - listener: Mutex::new(Some(listener)), - })) - } - - // Checks whether the quorum is valid and an explanation for the state. - async fn quorum_valid(&self) -> (bool, String) { - let state = self.state.lock().await; +// Checks whether the quorum is valid and an explanation for the state. +fn quorum_valid(state: &RoomState, opt: &LighthouseOpt) -> (bool, String) { + let mut first_joined = Instant::now(); - let mut first_joined = Instant::now(); - - for details in state.participants.values() { - if details.joined < first_joined { - first_joined = details.joined; - } + for details in state.participants.values() { + if details.joined < first_joined { + first_joined = details.joined; } + } - if state.prev_quorum.is_some() { - let mut is_fast_quorum = true; - let prev_quorum = state.prev_quorum.as_ref().unwrap(); - - for prev_member in prev_quorum.participants.iter() { - if !state.participants.contains_key(&prev_member.replica_id) { - is_fast_quorum = false; - } - } + if state.prev_quorum.is_some() { + let mut is_fast_quorum = true; + let prev_quorum = state.prev_quorum.as_ref().unwrap(); - if is_fast_quorum { - return (is_fast_quorum, format!("Fast quorum found!")); + for prev_member in prev_quorum.participants.iter() { + if !state.participants.contains_key(&prev_member.replica_id) { + is_fast_quorum = false; } } - if state.participants.len() < self.opt.min_replicas as usize { - return ( - false, - format!( - "No quorum, only have {} participants, need {}", - state.participants.len(), - self.opt.min_replicas - ), - ); + if is_fast_quorum { + return (is_fast_quorum, format!("Fast quorum found!")); } + } - // Quorum is valid at this point but lets wait for stragglers. + if state.participants.len() < opt.min_replicas as usize { + return ( + false, + format!( + "No quorum, only have {} participants, need {}", + state.participants.len(), + opt.min_replicas + ), + ); + } - if Instant::now().duration_since(first_joined) - < Duration::from_millis(self.opt.join_timeout_ms) - { - return ( - false, - format!( - "Valid quorum with {} participants, waiting for stragglers due to join timeout", - state.participants.len() - ), - ); - } + // Quorum is valid at this point but lets wait for stragglers. + + if Instant::now().duration_since(first_joined) < Duration::from_millis(opt.join_timeout_ms) { + return ( + false, + format!( + "Valid quorum with {} participants, waiting for stragglers due to join timeout", + state.participants.len() + ), + ); + } + + (true, format!("Valid quorum found")) +} - (true, format!("Valid quorum found")) +impl Lighthouse { + pub async fn new(opt: LighthouseOpt) -> Result> { + let listener = tokio::net::TcpListener::bind(&opt.bind).await?; + Ok(Arc::new(Self { + state: Mutex::new(State { + rooms: HashMap::new(), + heartbeats: HashMap::new(), + }), + opt: opt, + local_addr: listener.local_addr()?, + listener: Mutex::new(Some(listener)), + })) } - async fn _quorum_tick(self: Arc) -> Result<()> { - // TODO: these should probably run under the same lock - let (quorum_met, reason) = self.quorum_valid().await; - info!("{}", reason); + fn _quorum_tick(self: Arc, state: &mut RoomState) -> Result<()> { + let (quorum_met, reason) = quorum_valid(state, &self.opt); + info!("{}: {}", state.room_id, reason); if quorum_met { - let mut state = self.state.lock().await; let mut participants: Vec = state .participants .values() @@ -184,8 +178,8 @@ impl Lighthouse { { state.quorum_id += 1; info!( - "Detected quorum change, bumping quorum_id to {}", - state.quorum_id + "{}: Detected quorum change, bumping quorum_id to {}", + state.room_id, state.quorum_id ); } @@ -195,7 +189,7 @@ impl Lighthouse { created: Some(SystemTime::now().into()), }; - info!("Quorum! {:?}", quorum); + info!("{}: Quorum! {:?}", state.room_id, quorum); state.prev_quorum = Some(quorum.clone()); state.participants.clear(); @@ -209,7 +203,12 @@ impl Lighthouse { async fn _run_quorum(self: Arc) -> Result<()> { loop { - self.clone()._quorum_tick().await?; + { + let mut state = self.state.lock().await; + for (_room_id, room) in &mut state.rooms { + self.clone()._quorum_tick(room)?; + } + } sleep(Duration::from_millis(self.opt.quorum_tick_ms)).await; } @@ -277,33 +276,45 @@ impl Lighthouse { } async fn get_status(self: Arc) -> Html { - let (_, quorum_status) = self.quorum_valid().await; - let template = { let state = self.state.lock().await; - let max_step = { - if let Some(quorum) = state.prev_quorum.clone() { - quorum - .participants - .iter() - .map(|p| p.step) - .max() - .unwrap_or(-1) - } else { - -1 - } - }; + let rooms = state + .rooms + .iter() + .map(|(room_id, room)| { + let (_, quorum_status) = quorum_valid(&room, &self.opt); + + let max_step = { + if let Some(quorum) = room.prev_quorum.clone() { + quorum + .participants + .iter() + .map(|p| p.step) + .max() + .unwrap_or(-1) + } else { + -1 + } + }; + + RoomStatus { + room_id: room_id.clone(), + quorum_id: room.quorum_id, + prev_quorum: room.prev_quorum.clone(), + quorum_status: quorum_status, + + max_step: max_step, + } + }) + .collect(); StatusTemplate { - quorum_id: state.quorum_id, - prev_quorum: state.prev_quorum.clone(), + rooms: rooms, heartbeats: state.heartbeats.clone(), - quorum_status: quorum_status, old_age_threshold: Instant::now() .checked_sub(Duration::from_secs(1)) .unwrap_or(Instant::now()), - max_step: max_step, } }; Html(template.render().unwrap()) @@ -312,13 +323,16 @@ impl Lighthouse { async fn kill(self: Arc, Path(replica_id): Path) -> Result<(), AppError> { let addr = 'addr: { let state = self.state.lock().await; - if state.prev_quorum.is_none() { - return Err(AppError(anyhow!("failed to find replica"))); - } - for member in state.prev_quorum.clone().unwrap().participants { - if member.replica_id == replica_id { - break 'addr member.address; + for (_room_id, room) in &state.rooms { + if room.prev_quorum.is_none() { + return Err(AppError(anyhow!("failed to find replica"))); + } + + for member in room.prev_quorum.clone().unwrap().participants { + if member.replica_id == replica_id { + break 'addr member.address; + } } } return Err(AppError(anyhow!("failed to find replica"))); @@ -341,8 +355,9 @@ impl LighthouseService for Arc { &self, request: Request, ) -> Result, Status> { - let requester = request - .into_inner() + let req = request.into_inner(); + let room_id = req.room_id; + let requester = req .requester .ok_or_else(|| return Status::invalid_argument("missing requester"))?; @@ -350,21 +365,40 @@ impl LighthouseService for Arc { let mut rx = { let mut state = self.state.lock().await; - state.participants.insert( + + if !state.rooms.contains_key(&room_id) { + let (tx, _) = broadcast::channel(16); + + state.rooms.insert( + room_id.clone(), + RoomState { + room_id: room_id.clone(), + participants: HashMap::new(), + channel: tx, + prev_quorum: None, + quorum_id: 0, + }, + ); + } + + let room = state.rooms.get_mut(&room_id).unwrap(); + + room.participants.insert( requester.replica_id.clone(), QuorumMemberDetails { joined: Instant::now(), member: requester, }, ); - state.channel.subscribe() - }; + let rx = room.channel.subscribe(); - // proactively run quorum tick - self.clone() - ._quorum_tick() - .await - .map_err(|e| Status::from_error(e.into()))?; + // proactively run quorum tick + self.clone() + ._quorum_tick(room) + .map_err(|e| Status::from_error(e.into()))?; + + rx + }; let quorum = rx.recv().await.map_err(|e| Status::from_error(e.into()))?; @@ -398,13 +432,18 @@ struct IndexTemplate {} #[derive(Template)] #[template(path = "status.html")] struct StatusTemplate { - prev_quorum: Option, - quorum_id: i64, + rooms: Vec, heartbeats: HashMap, - quorum_status: String, // visualization thresholds old_age_threshold: Instant, +} + +struct RoomStatus { + room_id: String, + prev_quorum: Option, + quorum_id: i64, + quorum_status: String, max_step: i64, } @@ -442,16 +481,6 @@ mod tests { use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; - async fn lighthouse_test_new() -> Result> { - let opt = LighthouseOpt { - min_replicas: 1, - bind: "[::]:0".to_string(), - join_timeout_ms: 60 * 60 * 1000, // 1hr - quorum_tick_ms: 10, - }; - Lighthouse::new(opt).await - } - async fn lighthouse_client_new(addr: String) -> Result> { let conn = Endpoint::new(addr)? .connect_timeout(Duration::from_secs(10)) @@ -462,79 +491,95 @@ mod tests { #[tokio::test] async fn test_quorum_join_timeout() -> Result<()> { - let lighthouse = lighthouse_test_new().await?; - assert!(!lighthouse.quorum_valid().await.0); + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + }; - { - let mut state = lighthouse.state.lock().await; - state.participants.insert( - "a".to_string(), - QuorumMemberDetails { - joined: Instant::now(), - member: QuorumMember { - replica_id: "a".to_string(), - address: "".to_string(), - store_address: "".to_string(), - step: 1, - world_size: 1, - }, + let mut state = RoomState { + room_id: "test".to_string(), + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + }; + + assert!(!quorum_valid(&state, &opt).0); + + state.participants.insert( + "a".to_string(), + QuorumMemberDetails { + joined: Instant::now(), + member: QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, }, - ); - } + }, + ); - assert!(!lighthouse.quorum_valid().await.0); + assert!(!quorum_valid(&state, &opt).0); - { - let mut state = lighthouse.state.lock().await; - state.participants.get_mut("a").unwrap().joined = - Instant::now().sub(Duration::from_secs(10 * 60 * 60)); - } + state.participants.get_mut("a").unwrap().joined = + Instant::now().sub(Duration::from_secs(10 * 60 * 60)); - assert!(lighthouse.quorum_valid().await.0); + assert!(quorum_valid(&state, &opt).0); Ok(()) } #[tokio::test] async fn test_quorum_fast_prev_quorum() -> Result<()> { - let lighthouse = lighthouse_test_new().await?; - assert!(!lighthouse.quorum_valid().await.0); + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + }; - { - let mut state = lighthouse.state.lock().await; - state.participants.insert( - "a".to_string(), - QuorumMemberDetails { - joined: Instant::now(), - member: QuorumMember { - replica_id: "a".to_string(), - address: "".to_string(), - store_address: "".to_string(), - step: 1, - world_size: 1, - }, - }, - ); - } + let mut state = RoomState { + room_id: "test".to_string(), + channel: broadcast::channel(16).0, + participants: HashMap::new(), + prev_quorum: None, + quorum_id: 0, + }; - assert!(!lighthouse.quorum_valid().await.0); + assert!(!quorum_valid(&state, &opt).0); - { - let mut state = lighthouse.state.lock().await; - state.prev_quorum = Some(Quorum { - quorum_id: 1, - participants: vec![QuorumMember { + state.participants.insert( + "a".to_string(), + QuorumMemberDetails { + joined: Instant::now(), + member: QuorumMember { replica_id: "a".to_string(), address: "".to_string(), store_address: "".to_string(), step: 1, world_size: 1, - }], - created: Some(SystemTime::now().into()), - }); - } + }, + }, + ); + + assert!(!quorum_valid(&state, &opt).0); + + state.prev_quorum = Some(Quorum { + quorum_id: 1, + participants: vec![QuorumMember { + replica_id: "a".to_string(), + address: "".to_string(), + store_address: "".to_string(), + step: 1, + world_size: 1, + }], + created: Some(SystemTime::now().into()), + }); - assert!(lighthouse.quorum_valid().await.0); + assert!(quorum_valid(&state, &opt).0); Ok(()) } @@ -563,6 +608,7 @@ mod tests { { let request = tonic::Request::new(LighthouseQuorumRequest { + room_id: "test".to_string(), requester: Some(QuorumMember { replica_id: "foo".to_string(), address: "".to_string(), diff --git a/src/manager.rs b/src/manager.rs index 275c6d3..6200b27 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -36,10 +36,14 @@ use log::{info, warn}; #[cfg(test)] use std::{println as info, println as warn}; -struct ManagerState { +struct RoomState { channel: broadcast::Sender, - participants: u64, + participants: HashSet, +} + +struct ManagerState { checkpoint_servers: HashMap, + rooms: HashMap, should_commit_channel: broadcast::Sender, should_commit_failures: HashSet, @@ -81,11 +85,10 @@ impl Manager { store_addr: String, world_size: u64, ) -> Result> { - let (tx, _) = broadcast::channel(16); - let (should_commit_tx, _) = broadcast::channel(16); - let listener = tokio::net::TcpListener::bind(&bind).await?; + let (should_commit_tx, _) = broadcast::channel(16); + Ok(Arc::new(Self { replica_id: replica_id, lighthouse_addr: lighthouse_addr, @@ -93,9 +96,8 @@ impl Manager { store_address: store_addr, world_size: world_size, state: Mutex::new(ManagerState { - channel: tx, - participants: 0, checkpoint_servers: HashMap::new(), + rooms: HashMap::new(), should_commit_channel: should_commit_tx, should_commit_count: HashSet::new(), @@ -180,8 +182,9 @@ impl ManagerService for Arc { ) -> Result, Status> { let req = request.into_inner(); let rank = req.rank; + let room_id = req.room_id; - info!("got quorum request for rank {}", rank); + info!("{}: got quorum request for rank {}", room_id, rank); let mut rx = { let mut state = self.state.lock().await; @@ -192,13 +195,27 @@ impl ManagerService for Arc { .checkpoint_servers .insert(req.rank, req.checkpoint_server_addr.clone()); + if !state.rooms.contains_key(&room_id) { + let (tx, _) = broadcast::channel(16); + + state.rooms.insert( + room_id.clone(), + RoomState { + channel: tx, + participants: HashSet::new(), + }, + ); + } + + let room = state.rooms.get_mut(&room_id).unwrap(); + // TODO check step - state.participants += 1; - let rx = state.channel.subscribe(); + room.participants.insert(rank); + let rx = room.channel.subscribe(); - if state.participants >= self.world_size { - state.participants = 0; - info!("all workers joined -- starting quorum"); + if room.participants.len() as u64 >= self.world_size { + room.participants.clear(); + info!("{}: all workers joined -- starting quorum", room_id); // TODO: don't hold the lock during quorum @@ -208,6 +225,7 @@ impl ManagerService for Arc { .map_err(|e| Status::from_error(e.into()))?; let request = tonic::Request::new(LighthouseQuorumRequest { + room_id: room_id.clone(), requester: Some(QuorumMember { replica_id: self.replica_id.clone(), address: self.address.clone(), @@ -220,10 +238,9 @@ impl ManagerService for Arc { let response = client.quorum(request).await.unwrap(); let resp = response.into_inner(); - info!("got lighthouse quorum {:?}", resp); + info!("{}: got lighthouse quorum {:?}", room_id, resp); - state - .channel + room.channel .send( resp.quorum .ok_or_else(|| Status::internal("missing quorum"))?, @@ -270,8 +287,8 @@ impl ManagerService for Arc { let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id; if heal { info!( - "healing is required step={}, max_step={}", - req.step, max_step + "{}: healing is required step={}, max_step={}", + room_id, req.step, max_step ); } @@ -288,7 +305,7 @@ impl ManagerService for Arc { heal: heal, }; - info!("returning quorum for rank {}", rank); + info!("{}: returning quorum for rank {}", room_id, rank); Ok(Response::new(reply)) } @@ -455,6 +472,7 @@ mod tests { let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?; let request = tonic::Request::new(ManagerQuorumRequest { + room_id: "room".to_string(), rank: 0, step: 123, checkpoint_server_addr: "addr".to_string(), @@ -509,6 +527,7 @@ mod tests { manager_client_new(manager.address(), Duration::from_secs(10)).await?; let request = tonic::Request::new(ManagerQuorumRequest { + room_id: "room".to_string(), rank: 0, step: 0, checkpoint_server_addr: "addr".to_string(), diff --git a/templates/status.html b/templates/status.html index 429419d..bacd340 100644 --- a/templates/status.html +++ b/templates/status.html @@ -1,10 +1,12 @@ -

Quorum Status

-Current quorum_id: {{quorum_id}}
-Next quorum status: {{quorum_status}} +{% for room in rooms %} +

Room Status: {{room.room_id}}

+ +Current quorum_id: {{room.quorum_id}}
+Next quorum status: {{room.quorum_status}}

Previous Quorum

-{% if let Some(prev_quorum) = prev_quorum %} +{% if let Some(prev_quorum) = room.prev_quorum %} Previous quorum id: {{prev_quorum.quorum_id}}
Quorum age: @@ -14,7 +16,7 @@

Previous Quorum

{% for member in prev_quorum.participants %}
{{ member.replica_id }}
Step: {{ member.step }}
@@ -33,7 +35,9 @@

Previous Quorum

{% endif %} -

Heartbeats

+{% endfor %} + +

Heartbeats

    {% for replica_id in heartbeats.keys() %} @@ -47,5 +51,3 @@

    Heartbeats

    {% endfor %}
- - diff --git a/torchft/manager.py b/torchft/manager.py index 1f76729..ac4d64d 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -313,7 +313,7 @@ def callback( self._pending_work.append(cast(torch.futures.Future[object], fut)) return fut - def start_quorum(self, allow_heal: bool = True) -> None: + def start_quorum(self, room_id: str = "default", allow_heal: bool = True) -> None: """ .. note:: We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly. @@ -329,6 +329,8 @@ def start_quorum(self, allow_heal: bool = True) -> None: If allow_heal is set, the manager will attempt to heal either synchronously before returning or asynchronously prior to any network calls. All replicas must pass the same value to allow_heal. + room_id: (experimental) the room id to use for quorum, this allows + for multiple quorums to be used within the same job. """ # wait for previous quorum to complete @@ -342,7 +344,9 @@ def start_quorum(self, allow_heal: bool = True) -> None: # TODO: we should really be wrapping this whole section in a try-except # block to allow gracefully recovering from issues in PG setup and quorum. - self._quorum_future = self._executor.submit(self._async_quorum, allow_heal) + self._quorum_future = self._executor.submit( + self._async_quorum, room_id=room_id, allow_heal=allow_heal + ) if not self._use_async_quorum: self.wait_quorum() @@ -365,7 +369,7 @@ def wait_quorum(self) -> None: ), "must call start_quorum before wait_quorum" self._quorum_future.result() - def _async_quorum(self, allow_heal: bool) -> None: + def _async_quorum(self, room_id: str, allow_heal: bool) -> None: ( quorum_id, replica_rank, @@ -377,6 +381,7 @@ def _async_quorum(self, allow_heal: bool) -> None: max_world_size, heal, ) = self._client.quorum( + room_id=room_id, rank=self._rank, step=self._step, checkpoint_server_addr=self._ckpt_server.address(), diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index c3fc2b3..aee2947 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -4,7 +4,7 @@ from typing import Optional, Tuple class ManagerClient: def __init__(self, addr: str, timeout: timedelta) -> None: ... def quorum( - self, rank: int, step: int, checkpoint_server_addr: str + self, room_id: str, rank: int, step: int, checkpoint_server_addr: str ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ... def checkpoint_address(self, rank: int) -> str: ... def should_commit(self, rank: int, step: int, should_commit: bool) -> bool: ...