diff --git a/src/internal/id.rs b/src/internal/id.rs index 56bccc8..47fe226 100644 --- a/src/internal/id.rs +++ b/src/internal/id.rs @@ -1,4 +1,7 @@ -use std::fmt::{Display, Formatter}; +use std::{ + fmt::{Display, Formatter}, + num::NonZeroU32, +}; use crate::{internal::arena::ArenaId, Interner}; @@ -95,32 +98,24 @@ impl From for u32 { #[repr(transparent)] #[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)] -pub(crate) struct ClauseId(u32); +pub(crate) struct ClauseId(NonZeroU32); impl ClauseId { - /// There is a guarentee that ClauseId(0) will always be + /// There is a guarentee that ClauseId(1) will always be /// "Clause::InstallRoot". This assumption is verified by the solver. pub(crate) fn install_root() -> Self { - Self(0) - } - - pub(crate) fn is_null(self) -> bool { - self.0 == u32::MAX - } - - pub(crate) fn null() -> ClauseId { - ClauseId(u32::MAX) + Self(unsafe { NonZeroU32::new_unchecked(1) }) } } impl ArenaId for ClauseId { fn from_usize(x: usize) -> Self { - assert!(x < u32::MAX as usize, "clause id too big"); - Self(x as u32) + // SAFETY: Safe because we always add 1 to the index + Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) }) } fn to_usize(self) -> usize { - self.0 as usize + (self.0.get() - 1) as usize } } @@ -266,3 +261,17 @@ impl<'i, I: Interner> Display for DisplaySolvableOrRootId<'i, I> { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_clause_id_size() { + // Verify that the size of a ClauseId is the same as an Option. + assert_eq!( + std::mem::size_of::(), + std::mem::size_of::>() + ); + } +} diff --git a/src/internal/mapping.rs b/src/internal/mapping.rs index 26df86b..d0e8ea7 100644 --- a/src/internal/mapping.rs +++ b/src/internal/mapping.rs @@ -65,6 +65,21 @@ impl Mapping { previous_value } + /// Unset a specific value in the mapping, returns the previous value. + pub fn unset(&mut self, id: TId) -> Option { + let idx = id.to_usize(); + let (chunk, offset) = Self::chunk_and_offset(idx); + if chunk >= self.chunks.len() { + return None; + } + + let previous_value = self.chunks[chunk][offset].take(); + if previous_value.is_some() { + self.len -= 1; + } + previous_value + } + /// Get a specific value in the mapping with bound checks pub fn get(&self, id: TId) -> Option<&TValue> { let (chunk, offset) = Self::chunk_and_offset(id.to_usize()); diff --git a/src/solver/clause.rs b/src/solver/clause.rs index 0f23449..4ccc610 100644 --- a/src/solver/clause.rs +++ b/src/solver/clause.rs @@ -325,7 +325,7 @@ pub(crate) struct ClauseState { // The ids of the literals this clause is watching pub watched_literals: [Literal; 2], // The ids of the next clause in each linked list that this clause is part of - pub(crate) next_watches: [ClauseId; 2], + pub(crate) next_watches: [Option; 2], } impl ClauseState { @@ -415,7 +415,7 @@ impl ClauseState { let clause = Self { watched_literals, - next_watches: [ClauseId::null(), ClauseId::null()], + next_watches: [None, None], }; debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]); @@ -438,7 +438,7 @@ impl ClauseState { } #[inline] - pub fn next_watched_clause(&self, solvable_id: VariableId) -> ClauseId { + pub fn next_watched_clause(&self, solvable_id: VariableId) -> Option { if solvable_id == self.watched_literals[0].variable() { self.next_watches[0] } else { @@ -647,7 +647,7 @@ mod test { use super::*; use crate::{internal::arena::ArenaId, solver::decision::Decision}; - fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState { + fn clause(next_clauses: [Option; 2], watch_literals: [Literal; 2]) -> ClauseState { ClauseState { watched_literals: watch_literals, next_watches: next_clauses, @@ -688,21 +688,24 @@ mod test { #[test] fn test_unlink_clause_different() { let clause1 = clause( - [ClauseId::from_usize(2), ClauseId::from_usize(3)], + [ + ClauseId::from_usize(2).into(), + ClauseId::from_usize(3).into(), + ], [ VariableId::from_usize(1596).negative(), VariableId::from_usize(1211).negative(), ], ); let clause2 = clause( - [ClauseId::null(), ClauseId::from_usize(3)], + [None, ClauseId::from_usize(3).into()], [ VariableId::from_usize(1596).negative(), VariableId::from_usize(1208).negative(), ], ); let clause3 = clause( - [ClauseId::null(), ClauseId::null()], + [None, None], [ VariableId::from_usize(1211).negative(), VariableId::from_usize(42).negative(), @@ -720,10 +723,7 @@ mod test { VariableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::null(), ClauseId::from_usize(3)] - ) + assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()]) } // Unlink 1 @@ -737,24 +737,24 @@ mod test { VariableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::from_usize(2), ClauseId::null()] - ) + assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) } } #[test] fn test_unlink_clause_same() { let clause1 = clause( - [ClauseId::from_usize(2), ClauseId::from_usize(2)], + [ + ClauseId::from_usize(2).into(), + ClauseId::from_usize(2).into(), + ], [ VariableId::from_usize(1596).negative(), VariableId::from_usize(1211).negative(), ], ); let clause2 = clause( - [ClauseId::null(), ClauseId::null()], + [None, None], [ VariableId::from_usize(1596).negative(), VariableId::from_usize(1211).negative(), @@ -772,10 +772,7 @@ mod test { VariableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::null(), ClauseId::from_usize(2)] - ) + assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()]) } // Unlink 1 @@ -789,10 +786,7 @@ mod test { VariableId::from_usize(1211).negative() ] ); - assert_eq!( - clause1.next_watches, - [ClauseId::from_usize(2), ClauseId::null()] - ) + assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None]) } } @@ -817,7 +811,10 @@ mod test { // No conflict, still one candidate available decisions - .try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1) + .try_add_decision( + Decision::new(candidate1.into(), false, ClauseId::from_usize(0)), + 1, + ) .unwrap(); let (clause, conflict, _kind) = ClauseState::requires( parent, @@ -831,7 +828,10 @@ mod test { // Conflict, no candidates available decisions - .try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1) + .try_add_decision( + Decision::new(candidate2.into(), false, ClauseId::install_root()), + 1, + ) .unwrap(); let (clause, conflict, _kind) = ClauseState::requires( parent, @@ -845,7 +845,7 @@ mod test { // Panic decisions - .try_add_decision(Decision::new(parent, false, ClauseId::null()), 1) + .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { ClauseState::requires( @@ -875,7 +875,7 @@ mod test { // Conflict, forbidden package installed decisions - .try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1) + .try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1) .unwrap(); let (clause, conflict, _kind) = ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions); @@ -885,7 +885,7 @@ mod test { // Panic decisions - .try_add_decision(Decision::new(parent, false, ClauseId::null()), 1) + .try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1) .unwrap(); let panicked = std::panic::catch_unwind(|| { ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions) diff --git a/src/solver/mod.rs b/src/solver/mod.rs index e2198f7..9fca0bf 100644 --- a/src/solver/mod.rs +++ b/src/solver/mod.rs @@ -803,9 +803,12 @@ impl Solver { { // Find the first candidate that is not yet assigned a value or find the first // value that makes this clause true. - candidate = candidates - .iter() - .try_fold(None, |first_candidate, &candidate| { + candidate = candidates.iter().try_fold( + match candidate { + ControlFlow::Continue(x) => x, + _ => None, + }, + |first_candidate, &candidate| { let assigned_value = self.decision_tracker.assigned_value(candidate); ControlFlow::Continue(match assigned_value { Some(true) => { @@ -850,7 +853,8 @@ impl Solver { } }, }) - }); + }, + ); // Stop searching if we found a candidate that makes the clause true. if candidate.is_break() { @@ -1155,11 +1159,8 @@ impl Solver { // solvable let mut old_predecessor_clause_id: Option; let mut predecessor_clause_id: Option = None; - let mut clause_id = self - .watches - .first_clause_watching_literal(watched_literal) - .unwrap_or(ClauseId::null()); - while !clause_id.is_null() { + let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal); + while let Some(clause_id) = next_clause_id { debug_assert!( predecessor_clause_id != Some(clause_id), "Linked list is circular!" @@ -1186,8 +1187,7 @@ impl Solver { predecessor_clause_id = Some(clause_id); // Configure the next clause to visit - let this_clause_id = clause_id; - clause_id = clause_state.next_watched_clause(watched_literal.variable()); + next_clause_id = clause_state.next_watched_clause(watched_literal.variable()); // Determine which watch turned false. let (watch_index, other_watch_index) = @@ -1210,7 +1210,7 @@ impl Solver { // If the other watch is already true, we can simply skip // this clause. } else if let Some(variable) = clause_state.next_unwatched_literal( - &clauses[this_clause_id.to_usize()], + &clauses[clause_id.to_usize()], &self.learnt_clauses, &self.requirement_to_sorted_candidates, self.decision_tracker.map(), @@ -1219,7 +1219,7 @@ impl Solver { self.watches.update_watched( predecessor_clause_state, clause_state, - this_clause_id, + clause_id, watch_index, watched_literal, variable, @@ -1245,20 +1245,16 @@ impl Solver { Decision::new( remaining_watch.variable(), remaining_watch.satisfying_value(), - this_clause_id, + clause_id, ), level, ) .map_err(|_| { - PropagationError::Conflict( - remaining_watch.variable(), - true, - this_clause_id, - ) + PropagationError::Conflict(remaining_watch.variable(), true, clause_id) })?; if decided { - let clause = &clauses[this_clause_id.to_usize()]; + let clause = &clauses[clause_id.to_usize()]; match clause { // Skip logging for ForbidMultipleInstances, which is so noisy Clause::ForbidMultipleInstances(..) => {} diff --git a/src/solver/watch_map.rs b/src/solver/watch_map.rs index a14b888..07cce36 100644 --- a/src/solver/watch_map.rs +++ b/src/solver/watch_map.rs @@ -25,11 +25,7 @@ impl WatchMap { // Construct a linked list by adding the clause to the start of the linked list // and setting the previous head of the chain as the next element in the linked // list. - let current_head = self - .map - .get(watched_literal) - .copied() - .unwrap_or(ClauseId::null()); + let current_head = self.map.get(watched_literal).copied(); clause.next_watches[watch_index] = current_head; self.map.insert(watched_literal, clause_id); } @@ -49,18 +45,16 @@ impl WatchMap { if let Some(predecessor_clause) = predecessor_clause { // Unlink the clause predecessor_clause.unlink_clause(clause, previous_watch.variable(), watch_index); - } else { + } else if let Some(next_watch) = clause.next_watches[watch_index] { // This was the first clause in the chain - self.map - .insert(previous_watch, clause.next_watches[watch_index]); + self.map.insert(previous_watch, next_watch); + } else { + self.map.unset(previous_watch); } // Set the new watch clause.watched_literals[watch_index] = new_watch; - let previous_clause_id = self - .map - .insert(new_watch, clause_id) - .unwrap_or(ClauseId::null()); + let previous_clause_id = self.map.insert(new_watch, clause_id); clause.next_watches[watch_index] = previous_clause_id; } diff --git a/tests/solver.rs b/tests/solver.rs index 389c1c9..de15d8a 100644 --- a/tests/solver.rs +++ b/tests/solver.rs @@ -1365,6 +1365,17 @@ fn test_snapshot_union_requirements() { )); } +#[test] +fn test_union_empty_requirements() { + let provider = BundleBoxProvider::from_packages(&[("a", 1, vec!["b 1 | c"]), ("b", 1, vec![])]); + + let result = solve_snapshot(provider, &["a"]); + assert_snapshot!(result, @r" + a=1 + b=1 + "); +} + #[test] fn test_root_constraints() { let provider =