diff --git a/Cargo.toml b/Cargo.toml index 9b1e522..cb4e181 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ keywords = ["combinatorial", "optimization", "tree", "graph"] license-file = "LICENSE" readme = "README.md" categories = ["algorithms"] -version = "0.1.1" +version = "0.2.0" edition = "2021" authors = ["Du Shiqiao "] diff --git a/README.md b/README.md index 33adfab..a34464f 100644 --- a/README.md +++ b/README.md @@ -84,10 +84,10 @@ fn main() { let lower_bound_fn = |n: &Node| { let current_profit = total_profit(n); let max_remained_profit: u32 = profits[n.len()..].into_iter().sum(); - u32::MAX - (current_profit + max_remained_profit) + Some(u32::MAX - (current_profit + max_remained_profit)) }; - let cost_fn = |n: &Node| u32::MAX - total_profit(n); + let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n)); let leaf_check_fn = |n: &Node| n.len() == total_items; diff --git a/examples/bbs_knapsack_problem.rs b/examples/bbs_knapsack_problem.rs index db1afbb..97e4c24 100644 --- a/examples/bbs_knapsack_problem.rs +++ b/examples/bbs_knapsack_problem.rs @@ -60,10 +60,10 @@ fn main() { let lower_bound_fn = |n: &Node| { let current_profit = total_profit(n); let max_remained_profit: u32 = profits[n.len()..].into_iter().sum(); - u32::MAX - (current_profit + max_remained_profit) + Some(u32::MAX - (current_profit + max_remained_profit)) }; - let cost_fn = |n: &Node| u32::MAX - total_profit(n); + let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n)); let leaf_check_fn = |n: &Node| n.len() == total_items; diff --git a/src/bbs.rs b/src/bbs.rs index cfcfdf3..5d6ef72 100644 --- a/src/bbs.rs +++ b/src/bbs.rs @@ -16,16 +16,19 @@ where N: Clone, FN: FnMut(&N) -> IN, IN: IntoIterator, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, { type Item = N; fn next(&mut self) -> Option { if let Some(n) = self.to_see.pop() { - if (self.lower_bound_fn)(&n) <= self.current_best_cost { - for s in (self.successor_fn)(&n) { - self.to_see.push(s.clone()); + // get lower bound + if let Some(lb) = (self.lower_bound_fn)(&n) { + if lb <= self.current_best_cost { + for s in (self.successor_fn)(&n) { + self.to_see.push(s.clone()); + } } } Some(n) @@ -40,7 +43,7 @@ where N: Clone, FN: FnMut(&N) -> IN, IN: IntoIterator, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, { } @@ -55,7 +58,7 @@ where N: Clone, FN: FnMut(&N) -> IN, IN: IntoIterator, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, { BbsReachable { @@ -86,8 +89,8 @@ where N: Clone, IN: IntoIterator, FN: FnMut(&N) -> IN, - FC1: Fn(&N) -> C, - FC2: Fn(&N) -> C, + FC1: Fn(&N) -> Option, + FC2: Fn(&N) -> Option, C: Ord + Copy + Bounded, FR: Fn(&N) -> bool, { @@ -100,10 +103,11 @@ where } let n = op_n.unwrap(); if leaf_check_fn(&n) { - let cost = cost_fn(&n); - if res.current_best_cost > cost { - res.current_best_cost = cost; - best_leaf_node = Some(n) + if let Some(cost) = cost_fn(&n) { + if res.current_best_cost > cost { + res.current_best_cost = cost; + best_leaf_node = Some(n) + } } } } @@ -175,10 +179,10 @@ mod test { let lower_bound_fn = |n: &Node| { let current_profit = total_profit(n); let max_remained_profit: u32 = profits[n.len()..].into_iter().sum(); - u32::MAX - (current_profit + max_remained_profit) + Some(u32::MAX - (current_profit + max_remained_profit)) }; - let cost_fn = |n: &Node| u32::MAX - total_profit(n); + let cost_fn = |n: &Node| Some(u32::MAX - total_profit(n)); let leaf_check_fn = |n: &Node| n.len() == total_items; diff --git a/src/bfs.rs b/src/bfs.rs index ab3152d..3ca9ab5 100644 --- a/src/bfs.rs +++ b/src/bfs.rs @@ -22,14 +22,14 @@ where N: Clone, IN: IntoIterator, FN: FnMut(&N) -> IN, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, FR: Fn(&N) -> bool, { bms( start, successor_fn, - |_| C::min_value(), + |_| Some(C::min_value()), usize::MAX, usize::MAX, cost_fn, @@ -95,7 +95,7 @@ mod test { } }) .sum(); - u32::MAX - cost + Some(u32::MAX - cost) }; let leaf_check_fn = |n: &Node| n.len() == total_items; diff --git a/src/bms.rs b/src/bms.rs index d0e46d2..d40452d 100644 --- a/src/bms.rs +++ b/src/bms.rs @@ -48,7 +48,7 @@ where N: Clone, FN: FnMut(&N) -> IN, IN: IntoIterator, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, { type Item = N; @@ -64,7 +64,10 @@ where if let Some(node) = self.to_see.pop_front() { let mut successors: Vec<_> = (self.successor_fn)(&node) .into_iter() - .map(|n| ((self.eval_fn)(&n), n)) + .filter_map(|n| { + let cost = (self.eval_fn)(&n)?; + Some((cost, n)) + }) .collect(); successors.sort_unstable_by_key(|x| x.0); successors @@ -90,7 +93,7 @@ where N: Clone, FN: FnMut(&N) -> IN, IN: IntoIterator, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, { BmsReachable { @@ -127,8 +130,8 @@ where N: Clone, IN: IntoIterator, FN: FnMut(&N) -> IN, - FC1: Fn(&N) -> C, - FC2: Fn(&N) -> C, + FC1: Fn(&N) -> Option, + FC2: Fn(&N) -> Option, C: Ord + Copy + Bounded, FR: Fn(&N) -> bool, { @@ -142,10 +145,11 @@ where } let n = op_n.unwrap(); if leaf_check_fn(&n) { - let cost = cost_fn(&n); - if current_best_cost > cost { - current_best_cost = cost; - best_leaf_node = Some(n) + if let Some(cost) = cost_fn(&n) { + if current_best_cost > cost { + current_best_cost = cost; + best_leaf_node = Some(n) + } } } } @@ -310,12 +314,12 @@ mod test { let eval_fn = |n: &Node| { let (remained_duration, route) = greedy_tsp_solver(n.city, n.children.clone(), &time_func); - n.t + remained_duration + time_func(*route.last().unwrap(), start) + Some(n.t + remained_duration + time_func(*route.last().unwrap(), start)) }; let branch_factor = 10; let beam_width = 5; - let cost_fn = |n: &Node| n.t + time_func(n.city, start); + let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start)); let leaf_check_fn = |n: &Node| n.is_leaf(); let (cost, best_node) = bms( diff --git a/src/dfs.rs b/src/dfs.rs index b4af6fb..983a43d 100644 --- a/src/dfs.rs +++ b/src/dfs.rs @@ -22,14 +22,14 @@ where N: Clone, IN: IntoIterator, FN: FnMut(&N) -> IN, - FC: Fn(&N) -> C, + FC: Fn(&N) -> Option, C: Ord + Copy + Bounded, FR: Fn(&N) -> bool, { bbs( start, successor_fn, - |_| C::min_value(), + |_| Some(C::min_value()), cost_fn, leaf_check_fn, ) @@ -93,7 +93,7 @@ mod test { } }) .sum(); - u32::MAX - cost + Some(u32::MAX - cost) }; let leaf_check_fn = |n: &Node| n.len() == total_items; diff --git a/src/gds.rs b/src/gds.rs index d1c176f..3ba1de5 100644 --- a/src/gds.rs +++ b/src/gds.rs @@ -24,8 +24,8 @@ where N: Clone, IN: IntoIterator, FN: FnMut(&N) -> IN, - FC1: Fn(&N) -> C, - FC2: Fn(&N) -> C, + FC1: Fn(&N) -> Option, + FC2: Fn(&N) -> Option, C: Ord + Copy + Bounded, FR: Fn(&N) -> bool, { @@ -168,9 +168,9 @@ mod test { let time_func = |p: CityId, c: CityId| distance_matrix[p][c]; let successor_fn = |n: &Node| n.generate_child_nodes(&time_func); - let eval_fn = |n: &Node| n.t; + let eval_fn = |n: &Node| Some(n.t); - let cost_fn = |n: &Node| n.t + time_func(n.city, start); + let cost_fn = |n: &Node| Some(n.t + time_func(n.city, start)); let leaf_check_fn = |n: &Node| n.is_leaf(); let (cost, best_node) =