Skip to content

Commit

Permalink
Teach the comm manager how to handle info requests (#548)
Browse files Browse the repository at this point in the history
And use an info request in `handle_comm_info_request()` rather than trying to maintain our own copy of `open_comms`, which can get out of sync with the manager's list of open comms
  • Loading branch information
DavisVaughan authored Oct 1, 2024
1 parent 90632ca commit af8f1f3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 79 deletions.
54 changes: 26 additions & 28 deletions crates/amalthea/src/comm/comm_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ use stdext::result::ResultOrLog;
use stdext::spawn;

use crate::comm::comm_channel::CommMsg;
use crate::comm::event::CommInfo;
use crate::comm::event::CommManagerEvent;
use crate::comm::event::CommShellEvent;
use crate::comm::event::CommManagerInfoReply;
use crate::comm::event::CommManagerRequest;
use crate::socket::comm::CommInitiator;
use crate::socket::comm::CommSocket;
use crate::socket::iopub::IOPubMessage;
Expand All @@ -29,7 +31,6 @@ pub struct CommManager {
open_comms: Vec<CommSocket>,
iopub_tx: Sender<IOPubMessage>,
comm_event_rx: Receiver<CommManagerEvent>,
comm_shell_tx: Sender<CommShellEvent>,
pending_rpcs: HashMap<String, JupyterHeader>,
}

Expand All @@ -43,32 +44,22 @@ impl CommManager {
* - `comm_event_rx`: The channel to receive messages about changes to the set
* (or state) of open comms.
*/
pub fn start(
iopub_tx: Sender<IOPubMessage>,
comm_event_rx: Receiver<CommManagerEvent>,
) -> Receiver<CommShellEvent> {
let (comm_changed_tx, comm_changed_rx) = crossbeam::channel::unbounded();
pub fn start(iopub_tx: Sender<IOPubMessage>, comm_event_rx: Receiver<CommManagerEvent>) {
spawn!("comm-manager", move || {
let mut comm_manager = CommManager::new(iopub_tx, comm_event_rx, comm_changed_tx);
let mut comm_manager = CommManager::new(iopub_tx, comm_event_rx);
loop {
comm_manager.execution_thread();
}
});
return comm_changed_rx;
}

/**
* Create a new CommManager.
*/
pub fn new(
iopub_tx: Sender<IOPubMessage>,
comm_event_rx: Receiver<CommManagerEvent>,
comm_shell_tx: Sender<CommShellEvent>,
) -> Self {
pub fn new(iopub_tx: Sender<IOPubMessage>, comm_event_rx: Receiver<CommManagerEvent>) -> Self {
Self {
iopub_tx,
comm_event_rx,
comm_shell_tx,
open_comms: Vec::<CommSocket>::new(),
pending_rpcs: HashMap::<String, JupyterHeader>::new(),
}
Expand Down Expand Up @@ -107,17 +98,8 @@ impl CommManager {
return;
}
match comm_event.unwrap() {
// A Comm was opened; notify everyone
// A Comm was opened
CommManagerEvent::Opened(comm_socket, val) => {
// Notify the shell handler; it maintains a list of open
// comms so that the frontend can query for comm state
self.comm_shell_tx
.send(CommShellEvent::Added(
comm_socket.comm_id.clone(),
comm_socket.comm_name.clone(),
))
.unwrap();

// Notify the frontend, if this request originated from the back end
if comm_socket.initiator == CommInitiator::BackEnd {
self.iopub_tx
Expand Down Expand Up @@ -182,10 +164,9 @@ impl CommManager {
.send(CommMsg::Close)
.or_log_error("Failed to send comm_close to comm.");

// Remove it from our list of open comms
self.open_comms.remove(index);
self.comm_shell_tx
.send(CommShellEvent::Removed(comm_id))
.unwrap();

info!(
"Comm channel closed; there are now {} open comms",
self.open_comms.len()
Expand All @@ -197,6 +178,23 @@ impl CommManager {
);
}
},

// A comm manager request
CommManagerEvent::Request(req) => match req {
// Requesting information about the open comms
CommManagerRequest::Info(tx) => {
let comms: Vec<CommInfo> = self
.open_comms
.iter()
.map(|comm| CommInfo {
id: comm.comm_id.clone(),
name: comm.comm_name.clone(),
})
.collect();

tx.send(CommManagerInfoReply { comms }).unwrap();
},
},
}
} else {
// Otherwise, the message was received on one of the open comms.
Expand Down
25 changes: 16 additions & 9 deletions crates/amalthea/src/comm/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
*
*/

use crossbeam::channel::Sender;
use serde_json::Value;

use crate::comm::comm_channel::CommMsg;
Expand All @@ -27,18 +28,24 @@ pub enum CommManagerEvent {

/// A Comm was closed
Closed(String),

/// A comm manager request
Request(CommManagerRequest),
}

/**
* Enumeration of events that can be sent by the comm manager. These notify
* other parts of the application that a comm was opened or closed, so that they
* can update their state.
* Enumeration of requests that can be received by the comm manager.
*/
pub enum CommShellEvent {
/// A new comm was opened. The first value is the comm ID, and the second
/// value is the comm name.
Added(String, String),
pub enum CommManagerRequest {
/// Open comm information
Info(Sender<CommManagerInfoReply>),
}

pub struct CommManagerInfoReply {
pub comms: Vec<CommInfo>,
}

/// A comm was removed. The value is the comm ID.
Removed(String),
pub struct CommInfo {
pub id: String,
pub name: String,
}
6 changes: 1 addition & 5 deletions crates/amalthea/src/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use stdext::unwrap;

use crate::comm::comm_manager::CommManager;
use crate::comm::event::CommManagerEvent;
use crate::comm::event::CommShellEvent;
use crate::connection_file::ConnectionFile;
use crate::error::Error;
use crate::language::control_handler::ControlHandler;
Expand Down Expand Up @@ -126,7 +125,7 @@ impl Kernel {
// Create the comm manager thread
let iopub_tx = self.create_iopub_tx();
let comm_manager_rx = self.comm_manager_rx.clone();
let comm_changed_rx = CommManager::start(iopub_tx, comm_manager_rx);
CommManager::start(iopub_tx, comm_manager_rx);

// Create the Shell ROUTER/DEALER socket and start a thread to listen
// for client messages.
Expand All @@ -149,7 +148,6 @@ impl Kernel {
shell_socket,
iopub_tx_clone,
comm_manager_tx_clone,
comm_changed_rx,
shell_clone,
lsp_handler_clone,
dap_handler_clone,
Expand Down Expand Up @@ -311,7 +309,6 @@ impl Kernel {
socket: Socket,
iopub_tx: Sender<IOPubMessage>,
comm_manager_tx: Sender<CommManagerEvent>,
comm_changed_rx: Receiver<CommShellEvent>,
shell_handler: Arc<Mutex<dyn ShellHandler>>,
lsp_handler: Option<Arc<Mutex<dyn ServerHandler>>>,
dap_handler: Option<Arc<Mutex<dyn ServerHandler>>>,
Expand All @@ -320,7 +317,6 @@ impl Kernel {
socket,
iopub_tx.clone(),
comm_manager_tx,
comm_changed_rx,
shell_handler,
lsp_handler,
dap_handler,
Expand Down
56 changes: 19 additions & 37 deletions crates/amalthea/src/socket/shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use stdext::result::ResultOrLog;
use crate::comm::comm_channel::Comm;
use crate::comm::comm_channel::CommMsg;
use crate::comm::event::CommManagerEvent;
use crate::comm::event::CommShellEvent;
use crate::comm::event::CommManagerInfoReply;
use crate::comm::event::CommManagerRequest;
use crate::comm::server_comm::ServerComm;
use crate::error::Error;
use crate::language::server_handler::ServerHandler;
Expand Down Expand Up @@ -70,14 +71,8 @@ pub struct Shell {
/// Language-provided DAP handler object
dap_handler: Option<Arc<Mutex<dyn ServerHandler>>>,

/// Set of open comm channels; vector of (comm_id, target_name)
open_comms: Vec<(String, String)>,

/// Channel used to deliver comm events to the comm manager
comm_manager_tx: Sender<CommManagerEvent>,

/// Channel used to receive comm events from the comm manager
comm_shell_rx: Receiver<CommShellEvent>,
}

impl Shell {
Expand All @@ -93,7 +88,6 @@ impl Shell {
socket: Socket,
iopub_tx: Sender<IOPubMessage>,
comm_manager_tx: Sender<CommManagerEvent>,
comm_shell_rx: Receiver<CommShellEvent>,
shell_handler: Arc<Mutex<dyn ShellHandler>>,
lsp_handler: Option<Arc<Mutex<dyn ServerHandler>>>,
dap_handler: Option<Arc<Mutex<dyn ServerHandler>>>,
Expand All @@ -104,9 +98,7 @@ impl Shell {
shell_handler,
lsp_handler,
dap_handler,
open_comms: Vec::new(),
comm_manager_tx,
comm_shell_rx,
}
}

Expand All @@ -124,9 +116,6 @@ impl Shell {
},
};

// Process any comm changes before handling the message
self.process_comm_changes();

// Handle the message; any failures while handling the messages are
// delivered to the client instead of reported up the stack, so the
// only errors likely here are "can't deliver to client"
Expand Down Expand Up @@ -283,16 +272,28 @@ impl Shell {
) -> Result<(), Error> {
log::info!("Received request for open comms: {req:?}");

// Convert our internal map of open comms to a JSON object
// One off sender/receiver pair for this request
let (tx, rx) = crossbeam::channel::bounded(1);

// Request the list of open comms from the comm manager
self.comm_manager_tx
.send(CommManagerEvent::Request(CommManagerRequest::Info(tx)))
.unwrap();

// Wait on the reply
let CommManagerInfoReply { comms } = rx.recv().unwrap();

// Convert to a JSON object
let mut info = serde_json::Map::new();
for (comm_id, target_name) in &self.open_comms {

for comm in comms.into_iter() {
// Only include comms that match the target name, if one was specified
if req.content.target_name.is_empty() || &req.content.target_name == target_name {
if req.content.target_name.is_empty() || req.content.target_name == comm.name {
let comm_info_target = CommInfoTargetName {
target_name: target_name.clone(),
target_name: comm.name,
};
let comm_info = serde_json::to_value(comm_info_target).unwrap();
info.insert(comm_id.clone(), comm_info);
info.insert(comm.id, comm_info);
}
}

Expand Down Expand Up @@ -575,23 +576,4 @@ impl Shell {
Err(err) => req.send_error::<InspectReply>(err, &self.socket),
}
}

// Process changes to open comms
fn process_comm_changes(&mut self) {
if let Ok(comm_changed) = self.comm_shell_rx.try_recv() {
match comm_changed {
// Comm was added; add it to the list of open comms
CommShellEvent::Added(comm_id, target_name) => {
self.open_comms.push((comm_id, target_name));
},

// Comm was removed; remove it from the list of open comms
CommShellEvent::Removed(comm_id) => {
self.open_comms.retain(|(id, _)| id != &comm_id);
},
}
}
// No need to log errors; `try_recv` will return an error if there are no
// messages to receive
}
}

0 comments on commit af8f1f3

Please sign in to comment.