Skip to content

Commit

Permalink
Improve powerset
Browse files Browse the repository at this point in the history
  • Loading branch information
europeanplaice committed Mar 20, 2022
1 parent 43b22c8 commit feef8ad
Showing 1 changed file with 56 additions and 18 deletions.
74 changes: 56 additions & 18 deletions src/dp_module.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,29 @@
pub mod dp {
//! This is a module for dynamic programming.
use itertools::structs::Combinations;

struct MultiCombination<I: Iterator> {
combs: Vec<Combinations<I>>,
}

impl<I> Iterator for MultiCombination<I>
where I: Iterator,
I::Item: Clone
{
type Item = Vec<I::Item>;

fn next(&mut self) -> Option<Self::Item> {
for comb in &mut self.combs {
if let Some(elt) = comb.next() {
return Some(elt);
} else {
continue
}
}
None
}
}
use std::collections::HashMap;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -304,7 +328,7 @@ pub mod dp {
}
}
let a_length: usize = arr.len();
let mut route: Vec<u32> = Vec::with_capacity(max_length);
let mut route: Vec<u32> = vec![];
let mut answer: Vec<Vec<u32>> = vec![];

rec(
Expand Down Expand Up @@ -420,12 +444,12 @@ pub mod dp {
max_target_length: usize,
hashmap_fs: &mut Arc<Mutex<HashMap<(Vec<i32>, i32), Vec<Vec<i32>>>>>,
n_candidates: usize,
) {
) -> (){
use itertools::Itertools;
use std::cmp::min;
if answer.lock().unwrap().len() >= n_candidates {
return;
}

if keys.len() == 0 && targets.len() == 0 {
group.sort_by_key(|k| k.0.iter().sum::<i32>());
group.sort_by_key(|k| k.0.len());
Expand All @@ -440,7 +464,14 @@ pub mod dp {
return;
}
targets.sort();
(0..keys.len()).powerset().filter(|x| x.len() <= max_key_length).par_bridge().for_each(|i| {
let mut combs = vec![];
for i in 1..min(max_key_length, keys.len()) + 1{
combs.push((0..keys.len()).into_iter().combinations(i))
};
let mc = MultiCombination{
combs: combs,
};
mc.par_bridge().for_each(|i| {
let keys2 = keys.clone();
let targets2 = targets.clone();
let group2 = group.clone();
Expand All @@ -454,18 +485,23 @@ pub mod dp {
if sum_key > targets2.iter().sum() {
return;
}
if sum_key < *targets2.iter().min().unwrap() {
return;
}
if targets2.iter().max().unwrap() == &0 {
return;
}
let set_ = match hashmap_fs.try_lock() {
Ok(mut v) => v.entry((targets2.clone(), sum_key))
let mut set_ = match hashmap_fs.try_lock() {
Ok(mut v) => {v.entry((targets2.clone(), sum_key))
.or_insert(find_subset(targets2.clone(), sum_key, max_target_length))
.clone(),
Err(_) => find_subset(targets2.clone(), sum_key, max_target_length)
.clone()},
Err(_) => {find_subset(targets2.clone(), sum_key, max_target_length)}
};
if set_.len() == 0 {
return;
}
set_.sort();
set_.dedup();
set_.par_iter().for_each(|set| {
let mut keys3 = keys2.clone();
let mut targets3 = targets2.clone();
Expand All @@ -487,8 +523,9 @@ pub mod dp {
&mut hashmap_fs.clone(),
n_candidates
);
})
});
});
()
}

#[test]
Expand All @@ -509,26 +546,27 @@ pub mod dp {
2,
200
);
assert_eq!(answer.len(), 197);
assert_eq!(answer.len(), 195);
assert_eq!(
answer[0],
vec![
(vec![], vec![-700, 700]),
(vec![200, 300], vec![500]),
(vec![100, 500], vec![600]),
(vec![-700, 400, 600], vec![300]),
(vec![-700], vec![-700]),
(vec![100, 200], vec![300]),
(vec![300, 400], vec![700]),
(vec![500, 600], vec![500, 600]),
(vec![800, 900, 1000], vec![2700]),
]
);

let answer = sequence_matcher(&mut vec![9, 0, 1, 7, 1], &mut vec![7, 2, 8, 0, 1], 3, 2, 100);
assert_eq!(answer.len(), 37);
assert_eq!(answer.len(), 24);
assert_eq!(
answer[0],
vec![
(vec![], vec![0]),
(vec![0, 9], vec![1, 8]),
(vec![1, 1, 7], vec![2, 7]),
(vec![0], vec![0]),
(vec![1], vec![1]),
(vec![7], vec![7]),
(vec![1, 9], vec![2, 8]),
]
);

Expand Down

0 comments on commit feef8ad

Please sign in to comment.