Skip to content
This repository has been archived by the owner on Apr 9, 2024. It is now read-only.

Commit

Permalink
[WIP] many-to-many connections support
Browse files Browse the repository at this point in the history
This patch is an attempt at adding many-to-many connection support
without overhauling the library and minimal breaking changes to the
user.

The core of the idea is to replace Graph::connections from
SecondaryMap<InputId, OutputId> to [..]<InputId, HashSet<OutputId>> and
roll from that.

Ports are parametrized by an argument specifying the maximum non-zero
number of connections (not yet enforced), and if that number is greater
than 1 a so called wide-port is drawn -- it's characterized by being
longer than a typical port and connections made to it are spread out and
can be disconnected individually.

Known limitations:
* Again, max connections per node is not enforced.
* HashSet isn't an adequate data structure for storing the connected
  outputs -- we have no control over ordering, let alone the end user.
* ...?
  • Loading branch information
kamirr committed May 10, 2023
1 parent 0594981 commit 400f6a4
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 118 deletions.
175 changes: 116 additions & 59 deletions egui_node_graph/src/editor_ui.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashSet;
use std::num::NonZeroU32;

use crate::color_hex_utils::*;
use crate::utils::ColorUtils;
Expand All @@ -8,6 +9,7 @@ use egui::epaint::{CubicBezierShape, RectShape};
use egui::*;

pub type PortLocations = std::collections::HashMap<AnyParameterId, Pos2>;
pub type ConnLocations = std::collections::HashMap<(InputId, OutputId), Pos2>;
pub type NodeRects = std::collections::HashMap<NodeId, Rect>;

const DISTANCE_TO_CONNECT: f32 = 10.0;
Expand Down Expand Up @@ -76,6 +78,7 @@ pub struct GraphNodeWidget<'a, NodeData, DataType, ValueType> {
pub position: &'a mut Pos2,
pub graph: &'a mut Graph<NodeData, DataType, ValueType>,
pub port_locations: &'a mut PortLocations,
pub conn_locations: &'a mut ConnLocations,
pub node_rects: &'a mut NodeRects,
pub node_id: NodeId,
pub ongoing_drag: Option<(NodeId, AnyParameterId)>,
Expand Down Expand Up @@ -126,6 +129,9 @@ where
let mut port_locations = PortLocations::new();
let mut node_rects = NodeRects::new();

// actual dest location of each connection
let mut conn_locations = ConnLocations::default();

// The responses returned from node drawing have side effects that are best
// executed at the end of this function.
let mut delayed_responses: Vec<NodeResponse<UserResponse, NodeData>> = vec![];
Expand Down Expand Up @@ -161,6 +167,7 @@ where
position: self.node_positions.get_mut(node_id).unwrap(),
graph: &mut self.graph,
port_locations: &mut port_locations,
conn_locations: &mut conn_locations,
node_rects: &mut node_rects,
node_id,
ongoing_drag: self.connection_in_progress,
Expand Down Expand Up @@ -212,7 +219,7 @@ where
self.node_finder = None;
}

/* Draw connections */
// draw in-progress connections
if let Some((_, ref locator)) = self.connection_in_progress {
let port_type = self.graph.any_param_type(*locator).unwrap();
let connection_color = port_type.data_type_color(user_state);
Expand Down Expand Up @@ -281,15 +288,18 @@ where
draw_connection(ui.painter(), src_pos, dst_pos, connection_color);
}

for (input, output) in self.graph.iter_connections() {
let port_type = self
.graph
.any_param_type(AnyParameterId::Output(output))
.unwrap();
let connection_color = port_type.data_type_color(user_state);
let src_pos = port_locations[&AnyParameterId::Output(output)];
let dst_pos = port_locations[&AnyParameterId::Input(input)];
draw_connection(ui.painter(), src_pos, dst_pos, connection_color);
// draw existing connections
for (input, outputs) in self.graph.iter_connection_groups() {
for &output in outputs.iter() {
let port_type = self
.graph
.any_param_type(AnyParameterId::Output(output))
.unwrap();
let connection_color = port_type.data_type_color(user_state);
let src_pos = port_locations[&AnyParameterId::Output(output)];
let dst_pos = conn_locations[&(input, output)];
draw_connection(ui.painter(), src_pos, dst_pos, connection_color);
}
}

/* Handle responses from drawing nodes */
Expand Down Expand Up @@ -334,7 +344,7 @@ where
}
NodeResponse::DisconnectEvent { input, output } => {
let other_node = self.graph.get_output(*output).node;
self.graph.remove_connection(*input);
self.graph.remove_connection(*input, *output);
self.connection_in_progress =
Some((other_node, AnyParameterId::Output(*output)));
}
Expand Down Expand Up @@ -569,38 +579,41 @@ where
for (param_name, param_id) in inputs {
if self.graph[param_id].shown_inline {
let height_before = ui.min_rect().bottom();
// NOTE: We want to pass the `user_data` to
// `value_widget`, but we can't since that would require
// borrowing the graph twice. Here, we make the
// assumption that the value is cheaply replaced, and
// use `std::mem::take` to temporarily replace it with a
// dummy value. This requires `ValueType` to implement
// Default, but results in a totally safe alternative.
let mut value = std::mem::take(&mut self.graph[param_id].value);

if self.graph.connection(param_id).is_some() {
let node_responses = value.value_widget_connected(
&param_name,
self.node_id,
ui,
user_state,
&self.graph[self.node_id].user_data,
);

responses.extend(node_responses.into_iter().map(NodeResponse::User));
} else {
let node_responses = value.value_widget(
&param_name,
self.node_id,
ui,
user_state,
&self.graph[self.node_id].user_data,
);

responses.extend(node_responses.into_iter().map(NodeResponse::User));
}

self.graph[param_id].value = value;
if self.graph[param_id].max_connections == NonZeroU32::new(1) {
// NOTE: We want to pass the `user_data` to
// `value_widget`, but we can't since that would require
// borrowing the graph twice. Here, we make the
// assumption that the value is cheaply replaced, and
// use `std::mem::take` to temporarily replace it with a
// dummy value. This requires `ValueType` to implement
// Default, but results in a totally safe alternative.
let mut value = std::mem::take(&mut self.graph[param_id].value);

if !self.graph.connections(param_id).is_empty() {
let node_responses = value.value_widget_connected(
&param_name,
self.node_id,
ui,
user_state,
&self.graph[self.node_id].user_data,
);

responses.extend(node_responses.into_iter().map(NodeResponse::User));
} else {
let node_responses = value.value_widget(
&param_name,
self.node_id,
ui,
user_state,
&self.graph[self.node_id].user_data,
);

responses.extend(node_responses.into_iter().map(NodeResponse::User));
}

self.graph[param_id].value = value;
}

let height_after = ui.min_rect().bottom();
input_port_heights.push((height_before + height_after) / 2.0);
Expand Down Expand Up @@ -651,16 +664,30 @@ where
responses: &mut Vec<NodeResponse<UserResponse, NodeData>>,
param_id: AnyParameterId,
port_locations: &mut PortLocations,
conn_locations: &mut ConnLocations,
ongoing_drag: Option<(NodeId, AnyParameterId)>,
is_connected_input: bool,
wide_port: bool,
connections: usize,
) where
DataType: DataTypeTrait<UserState>,
UserResponse: UserResponseTrait,
NodeData: NodeDataTrait,
{
let port_type = graph.any_param_type(param_id).unwrap();

let port_rect = Rect::from_center_size(port_pos, egui::vec2(10.0, 10.0));
let port_rect = Rect::from_center_size(
port_pos,
egui::vec2(
10.0,
if wide_port {
5.0 + (7.5 * (connections + 1) as f32).max(5.0)
} else {
10.0
},
),
);

port_locations.insert(param_id, port_rect.center_top() + Vec2::new(2.5, 5.0));

let sense = if ongoing_drag.is_some() {
Sense::hover()
Expand All @@ -672,7 +699,9 @@ where

// Check if the distance between the port and the mouse is the distance to connect
let close_enough = if let Some(pointer_pos) = ui.ctx().pointer_hover_pos() {
port_rect.center().distance(pointer_pos) < DISTANCE_TO_CONNECT
port_rect
.expand(DISTANCE_TO_CONNECT / 2.0)
.contains(pointer_pos)
} else {
false
};
Expand All @@ -682,19 +711,41 @@ where
} else {
port_type.data_type_color(user_state)
};
ui.painter()
.circle(port_rect.center(), 5.0, port_color, Stroke::NONE);

if wide_port {
ui.painter().rect_filled(port_rect, 5.0, port_color);
} else {
ui.painter()
.circle(port_rect.center(), 5.0, port_color, Stroke::NONE);
}

if connections > 0 {
let input = param_id.assume_input();
for (i, output) in graph.connections(input).into_iter().enumerate() {
let dst_pos = port_locations[&AnyParameterId::Input(input)]
+ Vec2::new(0.0, 7.5) * i as f32;
conn_locations.insert((input, output), dst_pos);
}
}

if resp.drag_started() {
if is_connected_input {
let input = param_id.assume_input();
let corresp_output = graph
.connection(input)
.expect("Connection data should be valid");
responses.push(NodeResponse::DisconnectEvent {
input: param_id.assume_input(),
output: corresp_output,
});
if connections > 0 {
if let Some(mouse_pos) = ui.input(|in_state| in_state.pointer.hover_pos()) {
let input = param_id.assume_input();
let outputs = graph.connections(input).into_iter();
let output = outputs
.min_by(|&out1, &out2| {
let out1_dist = conn_locations[&(input, out1)].distance(mouse_pos);
let out2_dist = conn_locations[&(input, out2)].distance(mouse_pos);

out1_dist.partial_cmp(&out2_dist).unwrap()
})
.unwrap();
responses.push(NodeResponse::DisconnectEvent {
input: param_id.assume_input(),
output,
});
}
} else {
responses.push(NodeResponse::ConnectEventStarted(node_id, param_id));
}
Expand All @@ -717,8 +768,6 @@ where
}
}
}

port_locations.insert(param_id, port_rect.center());
}

// Input ports
Expand All @@ -744,8 +793,14 @@ where
&mut responses,
AnyParameterId::Input(*param),
self.port_locations,
self.conn_locations,
self.ongoing_drag,
self.graph.connection(*param).is_some(),
self.graph[*param]
.max_connections
.map(NonZeroU32::get)
.unwrap_or(std::u32::MAX)
> 1,
self.graph.connections(*param).len(),
);
}
}
Expand All @@ -766,8 +821,10 @@ where
&mut responses,
AnyParameterId::Output(*param),
self.port_locations,
self.conn_locations,
self.ongoing_drag,
false,
0,
);
}

Expand Down
6 changes: 5 additions & 1 deletion egui_node_graph/src/graph.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{collections::HashSet, num::NonZeroU32};

use super::*;

#[cfg(feature = "persistence")]
Expand Down Expand Up @@ -55,6 +57,8 @@ pub struct InputParam<DataType, ValueType> {
pub kind: InputParamKind,
/// Back-reference to the node containing this parameter.
pub node: NodeId,
/// How many connections can be made with this input. `None` means no limit.
pub max_connections: Option<NonZeroU32>,
/// When true, the node is shown inline inside the node graph.
#[cfg_attr(feature = "persistence", serde(default = "shown_inline_default"))]
pub shown_inline: bool,
Expand Down Expand Up @@ -87,5 +91,5 @@ pub struct Graph<NodeData, DataType, ValueType> {
pub outputs: SlotMap<OutputId, OutputParam<DataType>>,
// Connects the input of a node, to the output of its predecessor that
// produces it
pub connections: SecondaryMap<InputId, OutputId>,
pub connections: SecondaryMap<InputId, HashSet<OutputId>>,
}
Loading

0 comments on commit 400f6a4

Please sign in to comment.