Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into refactor/variable_map
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra committed Jan 2, 2025
2 parents 401772c + f6537fc commit ac092e0
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 77 deletions.
39 changes: 24 additions & 15 deletions src/internal/id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::fmt::{Display, Formatter};
use std::{
fmt::{Display, Formatter},
num::NonZeroU32,
};

use crate::{internal::arena::ArenaId, Interner};

Expand Down Expand Up @@ -95,32 +98,24 @@ impl From<SolvableId> for u32 {

#[repr(transparent)]
#[derive(Copy, Clone, PartialOrd, Ord, Eq, PartialEq, Debug, Hash)]
pub(crate) struct ClauseId(u32);
pub(crate) struct ClauseId(NonZeroU32);

impl ClauseId {
/// There is a guarentee that ClauseId(0) will always be
/// There is a guarentee that ClauseId(1) will always be
/// "Clause::InstallRoot". This assumption is verified by the solver.
pub(crate) fn install_root() -> Self {
Self(0)
}

pub(crate) fn is_null(self) -> bool {
self.0 == u32::MAX
}

pub(crate) fn null() -> ClauseId {
ClauseId(u32::MAX)
Self(unsafe { NonZeroU32::new_unchecked(1) })
}
}

impl ArenaId for ClauseId {
fn from_usize(x: usize) -> Self {
assert!(x < u32::MAX as usize, "clause id too big");
Self(x as u32)
// SAFETY: Safe because we always add 1 to the index
Self(unsafe { NonZeroU32::new_unchecked((x + 1).try_into().expect("clause id too big")) })
}

fn to_usize(self) -> usize {
self.0 as usize
(self.0.get() - 1) as usize
}
}

Expand Down Expand Up @@ -266,3 +261,17 @@ impl<'i, I: Interner> Display for DisplaySolvableOrRootId<'i, I> {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_clause_id_size() {
// Verify that the size of a ClauseId is the same as an Option<ClauseId>.
assert_eq!(
std::mem::size_of::<ClauseId>(),
std::mem::size_of::<Option<ClauseId>>()
);
}
}
15 changes: 15 additions & 0 deletions src/internal/mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ impl<TId: ArenaId, TValue> Mapping<TId, TValue> {
previous_value
}

/// Unset a specific value in the mapping, returns the previous value.
pub fn unset(&mut self, id: TId) -> Option<TValue> {
let idx = id.to_usize();
let (chunk, offset) = Self::chunk_and_offset(idx);
if chunk >= self.chunks.len() {
return None;
}

let previous_value = self.chunks[chunk][offset].take();
if previous_value.is_some() {
self.len -= 1;
}
previous_value
}

/// Get a specific value in the mapping with bound checks
pub fn get(&self, id: TId) -> Option<&TValue> {
let (chunk, offset) = Self::chunk_and_offset(id.to_usize());
Expand Down
60 changes: 30 additions & 30 deletions src/solver/clause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ pub(crate) struct ClauseState {
// The ids of the literals this clause is watching
pub watched_literals: [Literal; 2],
// The ids of the next clause in each linked list that this clause is part of
pub(crate) next_watches: [ClauseId; 2],
pub(crate) next_watches: [Option<ClauseId>; 2],
}

impl ClauseState {
Expand Down Expand Up @@ -415,7 +415,7 @@ impl ClauseState {

let clause = Self {
watched_literals,
next_watches: [ClauseId::null(), ClauseId::null()],
next_watches: [None, None],
};

debug_assert!(!clause.has_watches() || watched_literals[0] != watched_literals[1]);
Expand All @@ -438,7 +438,7 @@ impl ClauseState {
}

#[inline]
pub fn next_watched_clause(&self, solvable_id: VariableId) -> ClauseId {
pub fn next_watched_clause(&self, solvable_id: VariableId) -> Option<ClauseId> {
if solvable_id == self.watched_literals[0].variable() {
self.next_watches[0]
} else {
Expand Down Expand Up @@ -647,7 +647,7 @@ mod test {
use super::*;
use crate::{internal::arena::ArenaId, solver::decision::Decision};

fn clause(next_clauses: [ClauseId; 2], watch_literals: [Literal; 2]) -> ClauseState {
fn clause(next_clauses: [Option<ClauseId>; 2], watch_literals: [Literal; 2]) -> ClauseState {
ClauseState {
watched_literals: watch_literals,
next_watches: next_clauses,
Expand Down Expand Up @@ -688,21 +688,24 @@ mod test {
#[test]
fn test_unlink_clause_different() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(3)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(3).into(),
],
[
VariableId::from_usize(1596).negative(),
VariableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::from_usize(3)],
[None, ClauseId::from_usize(3).into()],
[
VariableId::from_usize(1596).negative(),
VariableId::from_usize(1208).negative(),
],
);
let clause3 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
VariableId::from_usize(1211).negative(),
VariableId::from_usize(42).negative(),
Expand All @@ -720,10 +723,7 @@ mod test {
VariableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(3)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(3).into()])
}

// Unlink 1
Expand All @@ -737,24 +737,24 @@ mod test {
VariableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

#[test]
fn test_unlink_clause_same() {
let clause1 = clause(
[ClauseId::from_usize(2), ClauseId::from_usize(2)],
[
ClauseId::from_usize(2).into(),
ClauseId::from_usize(2).into(),
],
[
VariableId::from_usize(1596).negative(),
VariableId::from_usize(1211).negative(),
],
);
let clause2 = clause(
[ClauseId::null(), ClauseId::null()],
[None, None],
[
VariableId::from_usize(1596).negative(),
VariableId::from_usize(1211).negative(),
Expand All @@ -772,10 +772,7 @@ mod test {
VariableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::null(), ClauseId::from_usize(2)]
)
assert_eq!(clause1.next_watches, [None, ClauseId::from_usize(2).into()])
}

// Unlink 1
Expand All @@ -789,10 +786,7 @@ mod test {
VariableId::from_usize(1211).negative()
]
);
assert_eq!(
clause1.next_watches,
[ClauseId::from_usize(2), ClauseId::null()]
)
assert_eq!(clause1.next_watches, [ClauseId::from_usize(2).into(), None])
}
}

Expand All @@ -817,7 +811,10 @@ mod test {

// No conflict, still one candidate available
decisions
.try_add_decision(Decision::new(candidate1.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate1.into(), false, ClauseId::from_usize(0)),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -831,7 +828,10 @@ mod test {

// Conflict, no candidates available
decisions
.try_add_decision(Decision::new(candidate2.into(), false, ClauseId::null()), 1)
.try_add_decision(
Decision::new(candidate2.into(), false, ClauseId::install_root()),
1,
)
.unwrap();
let (clause, conflict, _kind) = ClauseState::requires(
parent,
Expand All @@ -845,7 +845,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::requires(
Expand Down Expand Up @@ -875,7 +875,7 @@ mod test {

// Conflict, forbidden package installed
decisions
.try_add_decision(Decision::new(forbidden, true, ClauseId::null()), 1)
.try_add_decision(Decision::new(forbidden, true, ClauseId::install_root()), 1)
.unwrap();
let (clause, conflict, _kind) =
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions);
Expand All @@ -885,7 +885,7 @@ mod test {

// Panic
decisions
.try_add_decision(Decision::new(parent, false, ClauseId::null()), 1)
.try_add_decision(Decision::new(parent, false, ClauseId::install_root()), 1)
.unwrap();
let panicked = std::panic::catch_unwind(|| {
ClauseState::constrains(parent, forbidden, VersionSetId::from_usize(0), &decisions)
Expand Down
36 changes: 16 additions & 20 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,12 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
{
// Find the first candidate that is not yet assigned a value or find the first
// value that makes this clause true.
candidate = candidates
.iter()
.try_fold(None, |first_candidate, &candidate| {
candidate = candidates.iter().try_fold(
match candidate {
ControlFlow::Continue(x) => x,
_ => None,
},
|first_candidate, &candidate| {
let assigned_value = self.decision_tracker.assigned_value(candidate);
ControlFlow::Continue(match assigned_value {
Some(true) => {
Expand Down Expand Up @@ -850,7 +853,8 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
}
},
})
});
},
);

// Stop searching if we found a candidate that makes the clause true.
if candidate.is_break() {
Expand Down Expand Up @@ -1155,11 +1159,8 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// solvable
let mut old_predecessor_clause_id: Option<ClauseId>;
let mut predecessor_clause_id: Option<ClauseId> = None;
let mut clause_id = self
.watches
.first_clause_watching_literal(watched_literal)
.unwrap_or(ClauseId::null());
while !clause_id.is_null() {
let mut next_clause_id = self.watches.first_clause_watching_literal(watched_literal);
while let Some(clause_id) = next_clause_id {
debug_assert!(
predecessor_clause_id != Some(clause_id),
"Linked list is circular!"
Expand All @@ -1186,8 +1187,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
predecessor_clause_id = Some(clause_id);

// Configure the next clause to visit
let this_clause_id = clause_id;
clause_id = clause_state.next_watched_clause(watched_literal.variable());
next_clause_id = clause_state.next_watched_clause(watched_literal.variable());

// Determine which watch turned false.
let (watch_index, other_watch_index) =
Expand All @@ -1210,7 +1210,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
// If the other watch is already true, we can simply skip
// this clause.
} else if let Some(variable) = clause_state.next_unwatched_literal(
&clauses[this_clause_id.to_usize()],
&clauses[clause_id.to_usize()],
&self.learnt_clauses,
&self.requirement_to_sorted_candidates,
self.decision_tracker.map(),
Expand All @@ -1219,7 +1219,7 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
self.watches.update_watched(
predecessor_clause_state,
clause_state,
this_clause_id,
clause_id,
watch_index,
watched_literal,
variable,
Expand All @@ -1245,20 +1245,16 @@ impl<D: DependencyProvider, RT: AsyncRuntime> Solver<D, RT> {
Decision::new(
remaining_watch.variable(),
remaining_watch.satisfying_value(),
this_clause_id,
clause_id,
),
level,
)
.map_err(|_| {
PropagationError::Conflict(
remaining_watch.variable(),
true,
this_clause_id,
)
PropagationError::Conflict(remaining_watch.variable(), true, clause_id)
})?;

if decided {
let clause = &clauses[this_clause_id.to_usize()];
let clause = &clauses[clause_id.to_usize()];
match clause {
// Skip logging for ForbidMultipleInstances, which is so noisy
Clause::ForbidMultipleInstances(..) => {}
Expand Down
18 changes: 6 additions & 12 deletions src/solver/watch_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ impl WatchMap {
// Construct a linked list by adding the clause to the start of the linked list
// and setting the previous head of the chain as the next element in the linked
// list.
let current_head = self
.map
.get(watched_literal)
.copied()
.unwrap_or(ClauseId::null());
let current_head = self.map.get(watched_literal).copied();
clause.next_watches[watch_index] = current_head;
self.map.insert(watched_literal, clause_id);
}
Expand All @@ -49,18 +45,16 @@ impl WatchMap {
if let Some(predecessor_clause) = predecessor_clause {
// Unlink the clause
predecessor_clause.unlink_clause(clause, previous_watch.variable(), watch_index);
} else {
} else if let Some(next_watch) = clause.next_watches[watch_index] {
// This was the first clause in the chain
self.map
.insert(previous_watch, clause.next_watches[watch_index]);
self.map.insert(previous_watch, next_watch);
} else {
self.map.unset(previous_watch);
}

// Set the new watch
clause.watched_literals[watch_index] = new_watch;
let previous_clause_id = self
.map
.insert(new_watch, clause_id)
.unwrap_or(ClauseId::null());
let previous_clause_id = self.map.insert(new_watch, clause_id);
clause.next_watches[watch_index] = previous_clause_id;
}

Expand Down
Loading

0 comments on commit ac092e0

Please sign in to comment.