From feef8ad81c30780078866a05ee824eb93702efcb Mon Sep 17 00:00:00 2001 From: europeanplaice Date: Mon, 21 Mar 2022 01:08:11 +0900 Subject: [PATCH] Improve powerset --- src/dp_module.rs | 74 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 56 insertions(+), 18 deletions(-) diff --git a/src/dp_module.rs b/src/dp_module.rs index 59960b6..4f8abcc 100644 --- a/src/dp_module.rs +++ b/src/dp_module.rs @@ -1,5 +1,29 @@ pub mod dp { //! This is a module for dynamic programming. + + use itertools::structs::Combinations; + + struct MultiCombination { + combs: Vec>, + } + + impl Iterator for MultiCombination + where I: Iterator, + I::Item: Clone + { + type Item = Vec; + + fn next(&mut self) -> Option { + for comb in &mut self.combs { + if let Some(elt) = comb.next() { + return Some(elt); + } else { + continue + } + } + None + } + } use std::collections::HashMap; use rayon::prelude::*; use std::sync::{Arc, Mutex}; @@ -304,7 +328,7 @@ pub mod dp { } } let a_length: usize = arr.len(); - let mut route: Vec = Vec::with_capacity(max_length); + let mut route: Vec = vec![]; let mut answer: Vec> = vec![]; rec( @@ -420,12 +444,12 @@ pub mod dp { max_target_length: usize, hashmap_fs: &mut Arc, i32), Vec>>>>, n_candidates: usize, - ) { + ) -> (){ use itertools::Itertools; + use std::cmp::min; if answer.lock().unwrap().len() >= n_candidates { return; } - if keys.len() == 0 && targets.len() == 0 { group.sort_by_key(|k| k.0.iter().sum::()); group.sort_by_key(|k| k.0.len()); @@ -440,7 +464,14 @@ pub mod dp { return; } targets.sort(); - (0..keys.len()).powerset().filter(|x| x.len() <= max_key_length).par_bridge().for_each(|i| { + let mut combs = vec![]; + for i in 1..min(max_key_length, keys.len()) + 1{ + combs.push((0..keys.len()).into_iter().combinations(i)) + }; + let mc = MultiCombination{ + combs: combs, + }; + mc.par_bridge().for_each(|i| { let keys2 = keys.clone(); let targets2 = targets.clone(); let group2 = group.clone(); @@ -454,18 +485,23 @@ pub mod dp { if sum_key > targets2.iter().sum() { return; } + if sum_key < *targets2.iter().min().unwrap() { + return; + } if targets2.iter().max().unwrap() == &0 { return; } - let set_ = match hashmap_fs.try_lock() { - Ok(mut v) => v.entry((targets2.clone(), sum_key)) + let mut set_ = match hashmap_fs.try_lock() { + Ok(mut v) => {v.entry((targets2.clone(), sum_key)) .or_insert(find_subset(targets2.clone(), sum_key, max_target_length)) - .clone(), - Err(_) => find_subset(targets2.clone(), sum_key, max_target_length) + .clone()}, + Err(_) => {find_subset(targets2.clone(), sum_key, max_target_length)} }; if set_.len() == 0 { return; } + set_.sort(); + set_.dedup(); set_.par_iter().for_each(|set| { let mut keys3 = keys2.clone(); let mut targets3 = targets2.clone(); @@ -487,8 +523,9 @@ pub mod dp { &mut hashmap_fs.clone(), n_candidates ); - }) + }); }); + () } #[test] @@ -509,26 +546,27 @@ pub mod dp { 2, 200 ); - assert_eq!(answer.len(), 197); + assert_eq!(answer.len(), 195); assert_eq!( answer[0], vec![ - (vec![], vec![-700, 700]), - (vec![200, 300], vec![500]), - (vec![100, 500], vec![600]), - (vec![-700, 400, 600], vec![300]), + (vec![-700], vec![-700]), + (vec![100, 200], vec![300]), + (vec![300, 400], vec![700]), + (vec![500, 600], vec![500, 600]), (vec![800, 900, 1000], vec![2700]), ] ); let answer = sequence_matcher(&mut vec![9, 0, 1, 7, 1], &mut vec![7, 2, 8, 0, 1], 3, 2, 100); - assert_eq!(answer.len(), 37); + assert_eq!(answer.len(), 24); assert_eq!( answer[0], vec![ - (vec![], vec![0]), - (vec![0, 9], vec![1, 8]), - (vec![1, 1, 7], vec![2, 7]), + (vec![0], vec![0]), + (vec![1], vec![1]), + (vec![7], vec![7]), + (vec![1, 9], vec![2, 8]), ] );