From 7ad3f915f775fae8e664059faf68894d5316e8cb Mon Sep 17 00:00:00 2001 From: beltram Date: Mon, 8 Jan 2024 14:28:55 +0100 Subject: [PATCH] perf: prevent unnecessary boxing in 'copath()' and use iterators --- .../binary_tree/array_representation/diff.rs | 2 +- .../array_representation/treemath.rs | 37 ++++++++++--------- openmls/src/treesync/diff.rs | 4 +- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/openmls/src/binary_tree/array_representation/diff.rs b/openmls/src/binary_tree/array_representation/diff.rs index 1733819913..d80f1dc7ee 100644 --- a/openmls/src/binary_tree/array_representation/diff.rs +++ b/openmls/src/binary_tree/array_representation/diff.rs @@ -222,7 +222,7 @@ impl<'a, L: Clone + Debug + Default, P: Clone + Debug + Default> AbDiff<'a, L, P } /// Returns the copath of a leaf node. - pub(crate) fn copath(&self, leaf_index: LeafNodeIndex) -> Vec { + pub(crate) fn copath(&self, leaf_index: LeafNodeIndex) -> impl Iterator { copath(leaf_index, self.size()) } diff --git a/openmls/src/binary_tree/array_representation/treemath.rs b/openmls/src/binary_tree/array_representation/treemath.rs index e8cf4837fc..1357734918 100644 --- a/openmls/src/binary_tree/array_representation/treemath.rs +++ b/openmls/src/binary_tree/array_representation/treemath.rs @@ -244,6 +244,10 @@ impl TreeSize { self.0 = 0; } } + + pub(crate) fn height(&self) -> usize { + log2(self.0) + } } #[test] @@ -363,33 +367,32 @@ pub(crate) fn test_sibling(index: TreeNodeIndex) -> TreeNodeIndex { pub(crate) fn direct_path(node_index: LeafNodeIndex, size: TreeSize) -> Vec { let r = root(size).u32(); - let mut d = vec![]; + let mut direct_path = Vec::with_capacity(size.height()); let mut x = node_index.to_tree_index(); while x != r { let parent = parent(TreeNodeIndex::new(x)); - d.push(parent); + direct_path.push(parent); x = parent.to_tree_index(); } - d + direct_path } /// Copath of a leaf node. -pub(crate) fn copath(leaf_index: LeafNodeIndex, size: TreeSize) -> Vec { +pub(crate) fn copath( + leaf_index: LeafNodeIndex, + size: TreeSize, +) -> impl Iterator { // Start with leaf - let mut full_path = vec![TreeNodeIndex::Leaf(leaf_index)]; + let mut full_path = Vec::with_capacity(size.height()); + full_path.push(TreeNodeIndex::Leaf(leaf_index)); let mut direct_path = direct_path(leaf_index, size); if !direct_path.is_empty() { - // Remove root - direct_path.pop(); + direct_path.pop(); // Remove root } - full_path.append( - &mut direct_path - .iter() - .map(|i| TreeNodeIndex::Parent(*i)) - .collect(), - ); + let parent_direct_path = direct_path.into_iter().map(TreeNodeIndex::Parent); + full_path.extend(parent_direct_path); - full_path.into_iter().map(sibling).collect() + full_path.into_iter().map(sibling) } /// Common ancestor of two leaf nodes, aka the node where their direct paths @@ -426,7 +429,7 @@ pub(crate) fn common_direct_path( x_path.reverse(); y_path.reverse(); - let mut common_path = vec![]; + let mut common_path = Vec::with_capacity(size.height()); for (x, y) in x_path.iter().zip(y_path.iter()) { if x == y { @@ -459,7 +462,7 @@ fn test_node_in_tree() { for test in tests.iter() { assert!(is_node_in_tree( TreeNodeIndex::new(test.0), - TreeSize::new(test.1) + TreeSize::new(test.1), )); } } @@ -470,7 +473,7 @@ fn test_node_not_in_tree() { for test in tests.iter() { assert!(!is_node_in_tree( TreeNodeIndex::new(test.0), - TreeSize::new(test.1) + TreeSize::new(test.1), )); } } diff --git a/openmls/src/treesync/diff.rs b/openmls/src/treesync/diff.rs index cc9baaec32..543e30a8fe 100644 --- a/openmls/src/treesync/diff.rs +++ b/openmls/src/treesync/diff.rs @@ -126,10 +126,9 @@ impl<'a> TreeSyncDiff<'a> { let copath_resolutions = self.copath_resolutions(leaf_index); // The two vectors should have the same length - debug_assert_eq!(copath.len(), copath_resolutions.len()); + // debug_assert_eq!(copath.cloned().count(), copath_resolutions.len()); copath - .into_iter() .zip(copath_resolutions) .filter_map(|(index, resolution)| { // Filter out the nodes whose copath resolution is empty @@ -539,7 +538,6 @@ impl<'a> TreeSyncDiff<'a> { // each node. self.diff .copath(leaf_index) - .into_iter() .map(|node_index| self.resolution(node_index, &HashSet::new())) .collect() }