/*!
A simple implementation of congruence closure, [`CongruenceClosure`], which drives modifications to an arbitrary union-find ADT implementing the [`UnionFind`] trait.
Based on Nieuwenhuis and Oliveras [Proof-producing Congruence Closure](https://www.cs.upc.edu/~oliveras/rta05.pdf).

A minimal [disjoint set forest](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) implementation with dense `usize` nodes, [`DisjointSetForest`], is also provided.
*/
#![forbid(missing_docs, unsafe_code, missing_debug_implementations)]
use core::{
    cmp::Ordering,
    fmt::{self, Debug, Formatter},
    hash::{BuildHasher, Hash},
};
use hashbrown::{hash_map::Entry, HashMap};

/// A simple implementation of congruence closure, parametric over an arbitrary disjoint set forest implementation.
///
/// Based on Nieuwenhuis and Oliveras [Proof-producing Congruence Closure](https://www.cs.upc.edu/~oliveras/rta05.pdf).
///
/// Equations are binary, of the form `left * right = result`. The equivalence relation itself
/// lives in a caller-supplied [`UnionFind`]; this structure records the registered equations
/// and, whenever two classes are merged, propagates the equalities implied by congruence:
/// if `a * b = x` and `c * d = y` are registered with `a ~ c` and `b ~ d`, then `x ~ y`.
///
/// # Examples
/// ```rust
/// # use congruence::{DisjointSetForest, CongruenceClosure, UnionFind, CongruenceState};
/// let mut dsf = DisjointSetForest::with_capacity(6);
/// let mut cc = CongruenceClosure::new();
/// let mut st = CongruenceState::default();
/// let a = 0;
/// let b = 1;
/// let c = 2;
/// let d = 3;
/// let e = 4;
/// let f = 5;
/// // Set a * b = c
/// cc.equation(a, b, c, &mut dsf, &mut st);
/// assert!(dsf.is_empty());
/// // Set d * e = f
/// cc.equation(d, e, f, &mut dsf, &mut st);
/// assert!(dsf.is_empty());
/// // Set a = d
/// cc.union(a, d, &mut dsf, &mut st);
/// assert!(!dsf.is_empty());
/// // Only `a` and `d` are related so far (checked over every node, including `f`)
/// for i in 0..6 {
///     for j in 0..6 {
///         assert_eq!(dsf.node_eq(i, j), i == j || (i == a || i == d) && (j == a || j == d))
///     }
/// }
/// // Set b = e
/// cc.union(b, e, &mut dsf, &mut st);
/// assert!(dsf.node_eq(b, e));
/// // By congruence, we now have that c = a * b = d * e = f
/// assert!(dsf.node_eq(c, f));
/// assert!(dsf.node_eq(a, d));
/// assert!(!dsf.node_eq(a, b));
/// assert!(!dsf.node_eq(a, c));
/// assert!(!dsf.node_eq(a, e));
/// assert!(!dsf.node_eq(a, f));
/// assert!(!dsf.node_eq(b, c));
/// assert!(!dsf.node_eq(b, d));
/// assert!(!dsf.node_eq(b, f));
/// assert!(!dsf.node_eq(c, d));
/// assert!(!dsf.node_eq(c, e));
/// assert!(!dsf.node_eq(d, e));
/// assert!(!dsf.node_eq(d, f));
/// assert!(!dsf.node_eq(e, f));
/// ```
#[derive(Debug, Clone)]
pub struct CongruenceClosure<I = usize, S = hashbrown::hash_map::DefaultHashBuilder> {
    /// The use-lists: for each representative `r`, a list of registered equations
    /// `x_1 * x_2 = x` (stored as `(x_1, x_2, x)`) having some argument `x_i ~ r`.
    /// Equations found congruent to an already-known equation may be dropped from these lists.
    use_lists: HashMap<I, Vec<(I, I, I)>, S>,
    /// Lookup table from a pair of representatives `(r_1, r_2)` to a registered equation
    /// `x_1 * x_2 = x` (stored as `(x_1, x_2, x)`) with `x_1 ~ r_1` and `x_2 ~ r_2`, if one exists.
    /// Entries are never removed: entries keyed by nodes that have since stopped being
    /// representatives are left in place, since lookups are always made with current
    /// representatives.
    lookup: HashMap<(I, I), (I, I, I), S>,
}

impl<I, S: Default + BuildHasher> Default for CongruenceClosure<I, S> {
    fn default() -> Self {
        Self {
            use_lists: Default::default(),
            lookup: Default::default(),
        }
    }
}

impl<I> CongruenceClosure<I>
where
    I: Hash + Copy + Eq,
{
    /// Creates an empty congruence closure that hashes with the default hasher.
    #[inline]
    pub fn new() -> CongruenceClosure<I> {
        let use_lists = HashMap::new();
        let lookup = HashMap::new();
        CongruenceClosure { use_lists, lookup }
    }

    /// Creates an empty congruence closure with the default hasher, preallocating room for
    /// `nodes` use-lists and for `pairs` entries in the lookup table.
    #[inline]
    pub fn with_capacity(nodes: usize, pairs: usize) -> CongruenceClosure<I> {
        let use_lists = HashMap::with_capacity(nodes);
        let lookup = HashMap::with_capacity(pairs);
        CongruenceClosure { use_lists, lookup }
    }
}

impl<I, S> CongruenceClosure<I, S>
where
    I: Hash + Copy + Eq,
    S: BuildHasher,
{
    /// Whether this congruence closure is empty, i.e. both of its internal tables are empty.
    ///
    /// Lookup entries are only ever inserted, never removed, so this stays `false` once any
    /// equation has been registered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.lookup.is_empty() && self.use_lists.is_empty()
    }

    /// Creates an empty congruence closure whose tables hash with `hasher`.
    ///
    /// `S: Clone` is needed because each of the two internal tables takes its own copy.
    #[inline]
    pub fn with_hasher(hasher: S) -> CongruenceClosure<I, S>
    where
        S: Clone,
    {
        //TODO: remove Clone bound, store hasher and use raw tables
        let use_lists = HashMap::with_hasher(hasher.clone());
        let lookup = HashMap::with_hasher(hasher);
        CongruenceClosure { use_lists, lookup }
    }

    /// Creates an empty congruence closure hashing with `hasher`, preallocating room for
    /// `nodes` use-lists and `pairs` lookup entries.
    #[inline]
    pub fn with_capacity_and_hasher(
        nodes: usize,
        pairs: usize,
        hasher: S,
    ) -> CongruenceClosure<I, S>
    where
        S: Clone,
    {
        //TODO: remove Clone bound, store hasher and use raw tables
        let use_lists = HashMap::with_capacity_and_hasher(nodes, hasher.clone());
        let lookup = HashMap::with_capacity_and_hasher(pairs, hasher);
        CongruenceClosure { use_lists, lookup }
    }

    /// Registers the equation `left * right = result`.
    ///
    /// If an equation whose arguments are already equivalent to `left` and `right` is known,
    /// the two are congruent: `result` is merged with that equation's result via
    /// [`Self::union`]. Otherwise the equation is stored in the lookup table under the current
    /// representatives of its arguments and appended to the use-list of each of them.
    pub fn equation(
        &mut self,
        left: I,
        right: I,
        result: I,
        union_find: &mut impl UnionFind<I>,
        state: &mut CongruenceState<I>,
    ) {
        let eq = (left, right, result);
        let key = (union_find.find(left), union_find.find(right));
        let congruent_result = match self.lookup.entry(key) {
            Entry::Occupied(known) => Some(known.get().2),
            Entry::Vacant(slot) => {
                slot.insert(eq);
                None
            }
        };
        match congruent_result {
            Some(other) => self.union(result, other, union_find, state),
            None => {
                // When both arguments share a representative the equation is listed twice;
                // the duplicate is harmless during propagation.
                for repr in [key.0, key.1] {
                    self.use_lists.entry(repr).or_default().push(eq);
                }
            }
        }
    }

    /// Merges the equivalence classes of `a` and `b`, then keeps merging every pair of
    /// results that becomes congruent, until `state`'s pending queue is empty.
    #[inline]
    pub fn union(
        &mut self,
        a: I,
        b: I,
        union_find: &mut impl UnionFind<I>,
        state: &mut CongruenceState<I>,
    ) {
        let mut next = Some((a, b));
        while let Some((x, y)) = next {
            let x_repr = union_find.find(x);
            let y_repr = union_find.find(y);
            if x_repr != y_repr {
                let survivor = union_find.union_find(x, y);

                debug_assert!(union_find.find(x) == survivor);
                debug_assert!(union_find.find(y) == survivor);

                // The representative that lost its status after the merge.
                let absorbed = if survivor == x_repr {
                    y_repr
                } else {
                    debug_assert!(survivor == y_repr);
                    x_repr
                };

                self.migrate_use_list(absorbed, survivor, union_find, state);
            }
            next = state.pending.pop();
        }
    }

    /// Re-keys each equation in the use-list of `absorbed` (no longer a representative) under
    /// its arguments' current representatives: a clash with an existing entry queues the two
    /// results for merging, otherwise the equation moves to the use-list of `survivor`.
    fn migrate_use_list(
        &mut self,
        absorbed: I,
        survivor: I,
        union_find: &mut impl UnionFind<I>,
        state: &mut CongruenceState<I>,
    ) {
        let list = match self.use_lists.remove(&absorbed) {
            Some(list) => list,
            None => return,
        };
        for (c_0, c_1, c) in list {
            let key = (union_find.find(c_0), union_find.find(c_1));
            match self.lookup.entry(key) {
                Entry::Occupied(known) => state.pending.push((c, known.get().2)),
                Entry::Vacant(slot) => {
                    slot.insert((c_0, c_1, c));
                    self.use_lists
                        .entry(survivor)
                        .or_default()
                        .push((c_0, c_1, c));
                }
            }
        }
    }
}

/// Scratch state used by [`CongruenceClosure`] while propagating congruences.
///
/// It holds the queue of node pairs still waiting to be merged. The queue is always drained
/// before [`CongruenceClosure::union`] returns, so one state can be reused across calls and
/// keeps whatever capacity it has grown to.
#[derive(Clone)]
pub struct CongruenceState<I = usize> {
    /// Pairs of nodes known to be equal whose classes have not been merged yet.
    pending: Vec<(I, I)>,
}

impl<I> CongruenceState<I> {
    /// Creates a new, empty congruence state. Guaranteed not to allocate
    pub const fn new() -> CongruenceState<I> {
        CongruenceState { pending: Vec::new() }
    }

    /// Creates an empty congruence state with room for `cap` pending merges.
    pub fn with_capacity(cap: usize) -> CongruenceState<I> {
        let pending = Vec::with_capacity(cap);
        CongruenceState { pending }
    }
}

impl<I: Debug> Debug for CongruenceState<I> {
    /// Shows the pending queue together with its current allocated capacity.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CongruenceState");
        out.field("pending", &self.pending);
        out.field("capacity", &self.pending.capacity());
        out.finish()
    }
}

impl<I> Default for CongruenceState<I> {
    /// Same as [`CongruenceState::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// A trait implemented by data-structures implementing the union-find ADT.
///
/// Implementors must provide [`union_find`](UnionFind::union_find), [`find`](UnionFind::find)
/// and [`node_eq`](UnionFind::node_eq); [`union`](UnionFind::union) has a default
/// implementation in terms of `union_find`.
///
/// See [`DisjointSetForest`] for an example implementation and usage
pub trait UnionFind<I = usize> {
    /// Take the union of two nodes, returning the new root.
    ///
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// # let mut dsf = DisjointSetForest::with_capacity(3);
    /// let union = dsf.union_find(0, 1);
    /// ```
    /// is equivalent to
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// # let mut dsf = DisjointSetForest::with_capacity(3);
    /// dsf.union(0, 1);
    /// let union = dsf.find(0);
    /// assert_eq!(union, dsf.find(1));
    /// ```
    /// but may be optimized better.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert!(!dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// let union = dsf.union_find(0, 1);
    /// assert!(union == 0 || union == 1);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn union_find(&mut self, left: I, right: I) -> I;

    /// Merge the equivalence classes of two nodes
    ///
    /// The default implementation calls [`union_find`](UnionFind::union_find) and discards
    /// the returned root; override it when merging can be done more cheaply without
    /// reporting the new root.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert!(!dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// dsf.union(0, 1);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(0, 2));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn union(&mut self, left: I, right: I) {
        self.union_find(left, right);
    }

    /// Find the representative of a node
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert_eq!(dsf.find(0), 0);
    /// assert_eq!(dsf.find(1), 1);
    /// dsf.union(0, 1);
    /// assert_eq!(dsf.find(0), dsf.find(1));
    /// ```
    fn find(&mut self, node: I) -> I;

    /// Check whether two nodes are in the same equivalence class
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(3);
    /// assert!(dsf.node_eq(1, 1));
    /// assert!(!dsf.node_eq(0, 1));
    /// dsf.union(0, 1);
    /// assert!(dsf.node_eq(0, 1));
    /// assert!(!dsf.node_eq(1, 2));
    /// ```
    fn node_eq(&mut self, left: I, right: I) -> bool;
}

/// A minimal [disjoint set forest](https://en.wikipedia.org/wiki/Disjoint-set_data_structure) with dense `usize` nodes, implementing the [`UnionFind`] trait.
///
/// Has support for up to `2**(64 - 8)` nodes, which should be enough to exhaust the entire address space of modern CPUs.
///
/// Nodes are created lazily: an index that has never taken part in a union is its own
/// singleton class, so no explicit node allocation is needed before use.
///
/// # Examples
/// ```rust
/// # use congruence::{DisjointSetForest, UnionFind};
/// let mut dsf = DisjointSetForest::with_capacity(4);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(dsf.node_eq(i, j), i == j)
///     }
/// }
/// dsf.union(0, 1);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(
///             dsf.node_eq(i, j),
///             i == j || (i == 1 && j == 0) || (i == 0 && j == 1)
///         )
///     }
/// }
/// dsf.union(2, 3);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert_eq!(
///             dsf.node_eq(i, j),
///             i == j
///                 || (i == 1 && j == 0)
///                 || (i == 0 && j == 1)
///                 || (i == 2 && j == 3)
///                 || (i == 3 && j == 2)
///         )
///     }
/// }
/// dsf.union(1, 2);
/// for i in 0..4 {
///     for j in 0..4 {
///         assert!(dsf.node_eq(i, j))
///     }
/// }
/// ```
#[derive(Debug, Default, Clone)]
pub struct DisjointSetForest {
    /// Materialized nodes; index `ix` holds the packed parent/rank of node `ix`.
    nodes: Vec<Node>,
}

impl DisjointSetForest {
    /// Create a new, empty disjoint set forest. Guaranteed not to allocate.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert_eq!(dsf.capacity(), 0);
    /// ```
    #[inline(always)]
    pub const fn new() -> DisjointSetForest {
        DisjointSetForest { nodes: Vec::new() }
    }

    /// Create a new, empty disjoint set forest with room for at least `n` nodes.
    ///
    /// Nodes are still only created on demand by unions; this just preallocates storage.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let dsf = DisjointSetForest::with_capacity(5);
    /// assert!(dsf.is_empty());
    /// assert!(dsf.capacity() >= 5);
    /// ```
    #[inline]
    pub fn with_capacity(n: usize) -> DisjointSetForest {
        DisjointSetForest {
            nodes: Vec::with_capacity(n),
        }
    }

    /// Reserve capacity for `n` additional nodes
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.reserve(5);
    /// assert!(dsf.is_empty());
    /// assert!(dsf.capacity() >= 5);
    /// ```
    #[inline(always)]
    pub fn reserve(&mut self, n: usize) {
        self.nodes.reserve(n);
    }

    /// Get whether this forest is empty, i.e. contains no *relations* between distinct elements
    ///
    /// Runs in time linear in the number of materialized nodes.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// ```
    #[inline]
    pub fn is_empty(&self) -> bool {
        // A relation exists exactly when some node is not its own root.
        self.nodes
            .iter()
            .enumerate()
            .all(|(ix, node)| node.parent() == ix)
    }

    /// Get the current capacity of this forest
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::DisjointSetForest;
    /// let mut dsf = DisjointSetForest::new();
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.reserve(1);
    /// assert!(dsf.capacity() >= 1);
    /// dsf.reserve(5);
    /// assert!(dsf.capacity() >= 6);
    /// ```
    #[inline(always)]
    pub fn capacity(&self) -> usize {
        self.nodes.capacity()
    }

    /// Clear this forest, removing all nodes but preserving its capacity
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::new();
    /// assert!(dsf.is_empty());
    /// assert!(!dsf.node_eq(5, 3));
    /// assert_eq!(dsf.capacity(), 0);
    /// dsf.union(5, 3);
    /// assert!(!dsf.is_empty());
    /// assert!(dsf.node_eq(5, 3));
    /// let capacity = dsf.capacity();
    /// assert!(capacity >= 5);
    /// dsf.clear();
    /// assert!(dsf.is_empty());
    /// assert!(!dsf.node_eq(5, 3));
    /// assert_eq!(dsf.capacity(), capacity);
    /// ```
    #[inline(always)]
    pub fn clear(&mut self) {
        self.nodes.clear()
    }

    /// Insert a new node, which starts out as its own (rank 0) root
    fn insert(&mut self) {
        self.nodes.push(Node::new(self.nodes.len()))
    }

    /// Find the representative of a node, without requiring a mutable borrow of this data structure.
    ///
    /// Always returns the same result as `find`, but does not optimize the data structure, leading to worse amortized performance.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(10);
    /// dsf.union(3, 5);
    /// dsf.union(6, 3);
    /// dsf.union(7, 2);
    /// for ix in 0..10 {
    ///     assert_eq!(dsf.find(ix), dsf.find_ref(ix));
    /// }
    /// ```
    #[inline]
    pub fn find_ref(&self, mut ix: usize) -> usize {
        // Nodes that were never materialized are their own singleton class.
        if ix >= self.nodes.len() {
            return ix;
        }
        while self.nodes[ix].parent() != ix {
            ix = self.nodes[ix].parent()
        }
        ix
    }

    /// Check whether two nodes are in the same equivalence class, without requiring a mutable borrow of this data structure.
    ///
    /// Always returns the same result as [`Self::node_eq`], but does not optimize the data structure, leading to worse amortized performance.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf = DisjointSetForest::with_capacity(10);
    /// dsf.union(3, 5);
    /// dsf.union(6, 3);
    /// dsf.union(7, 2);
    /// for i in 0..10 {
    ///     for j in 0..10 {
    ///         assert_eq!(dsf.node_eq(i, j), dsf.node_eq_ref(i, j))
    ///     }
    /// }
    /// ```
    #[inline]
    pub fn node_eq_ref(&self, i: usize, j: usize) -> bool {
        i == j || self.find_ref(i) == self.find_ref(j)
    }

    /// Get whether this disjoint set forest refines another, i.e., if `a ~ b` in `self`, then `a ~ b` in `other`
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf_0 = DisjointSetForest::with_capacity(5);
    /// let mut dsf_1 = DisjointSetForest::with_capacity(5);
    /// assert!(dsf_0.refines(&mut dsf_1));
    /// assert!(dsf_0.refines(&mut dsf_1));
    /// dsf_0.union(3, 4);
    /// dsf_0.union(3, 2);
    /// assert!(!dsf_0.refines(&mut dsf_1));
    /// assert!(dsf_1.refines(&mut dsf_0));
    /// dsf_1.union(3, 4);
    /// assert!(!dsf_0.refines(&mut dsf_1));
    /// assert!(dsf_1.refines(&mut dsf_0));
    /// dsf_1.union(3, 5);
    /// assert!(!dsf_0.refines(&mut dsf_1));
    /// assert!(!dsf_1.refines(&mut dsf_0));
    /// dsf_1.union(3, 2);
    /// assert!(dsf_0.refines(&mut dsf_1));
    /// assert!(!dsf_1.refines(&mut dsf_0));
    /// dsf_0.union(3, 5);
    /// assert!(dsf_0.refines(&mut dsf_1));
    /// assert!(dsf_1.refines(&mut dsf_0));
    /// ```
    pub fn refines(&mut self, other: &mut impl UnionFind) -> bool {
        // It suffices that every node is related in `other` to its root in `self`: two nodes
        // sharing a root in `self` are then related in `other` by transitivity. Nodes past
        // `self.nodes.len()` are singletons in `self` and need no check.
        (0..self.nodes.len()).all(|ix| {
            let root = self.find(ix);
            other.node_eq(ix, root)
        })
    }

    /// Get whether this disjoint set forest refines another, i.e., if `a ~ b` in `self`, then `a ~ b` in `other`, without performing path compression
    ///
    /// Always returns the same result as [`Self::refines`], but does not optimize the data structure, leading to worse amortized performance.
    ///
    /// # Examples
    /// ```rust
    /// # use congruence::{DisjointSetForest, UnionFind};
    /// let mut dsf_0 = DisjointSetForest::with_capacity(5);
    /// let mut dsf_1 = DisjointSetForest::with_capacity(5);
    /// assert!(dsf_0.refines_ref(&dsf_1));
    /// assert!(dsf_0.refines_ref(&dsf_1));
    /// dsf_0.union(3, 4);
    /// dsf_0.union(3, 2);
    /// assert!(!dsf_0.refines_ref(&dsf_1));
    /// assert!(dsf_1.refines_ref(&dsf_0));
    /// dsf_1.union(3, 4);
    /// assert!(!dsf_0.refines_ref(&dsf_1));
    /// assert!(dsf_1.refines_ref(&dsf_0));
    /// dsf_1.union(3, 5);
    /// assert!(!dsf_0.refines_ref(&dsf_1));
    /// assert!(!dsf_1.refines_ref(&dsf_0));
    /// dsf_1.union(3, 2);
    /// assert!(dsf_0.refines_ref(&dsf_1));
    /// assert!(!dsf_1.refines_ref(&dsf_0));
    /// dsf_0.union(3, 5);
    /// assert!(dsf_0.refines_ref(&dsf_1));
    /// assert!(dsf_1.refines_ref(&dsf_0));
    /// ```
    pub fn refines_ref(&self, other: &DisjointSetForest) -> bool {
        (0..self.nodes.len()).all(|ix| other.node_eq_ref(ix, self.find_ref(ix)))
    }

    /// Take the union of two representative nodes, returning the new root
    ///
    /// Union by rank: the root with the higher rank becomes the parent; on a tie `left` wins
    /// and its rank is incremented. If `left == right` nothing changes.
    fn union_repr(&mut self, left: usize, right: usize) -> usize {
        if left == right {
            // Already the same class. Without this guard a tie would be detected and the
            // root's rank bumped, inflating it on every redundant union and weakening the
            // union-by-rank balancing of later merges.
            return left;
        }
        let (parent, child, bump) = match self.nodes[left].rank().cmp(&self.nodes[right].rank()) {
            Ordering::Greater => (left, right, 0),
            Ordering::Equal => (left, right, 1),
            Ordering::Less => (right, left, 0),
        };
        self.nodes[child].set_parent(parent);
        self.nodes[parent].bump_rank(bump);
        parent
    }
}

/// Forests compare equal when they induce the same partition of node indices, i.e. when each
/// refines the other; tree shapes, ranks and the number of materialized nodes are ignored.
impl PartialEq for DisjointSetForest {
    fn eq(&self, other: &Self) -> bool {
        let self_refines_other = self.refines_ref(other);
        self_refines_other && other.refines_ref(self)
    }
}

/// Mutual refinement is reflexive, so this equality is a full equivalence relation.
impl Eq for DisjointSetForest {}

impl UnionFind for DisjointSetForest {
    /// Materializes any missing nodes up to `max(left, right)`, then merges the two roots.
    #[inline]
    fn union_find(&mut self, left: usize, right: usize) -> usize {
        let highest = left.max(right);
        if highest >= self.nodes.len() {
            // Grow in one allocation, then push every missing node as its own root.
            self.nodes.reserve(1 + highest - self.nodes.len());
            while self.nodes.len() <= highest {
                self.insert()
            }
        }
        let (left_root, right_root) = (self.find(left), self.find(right));
        self.union_repr(left_root, right_root)
    }

    #[inline]
    fn union(&mut self, left: usize, right: usize) {
        let _ = self.union_find(left, right);
    }

    /// Returns the root of `node`, applying path splitting along the way: every visited node
    /// is re-pointed at its grandparent.
    #[inline]
    fn find(&mut self, node: usize) -> usize {
        if node >= self.nodes.len() {
            // Never-materialized nodes are implicit singletons.
            return node;
        }
        let mut current = node;
        let mut parent = self.nodes[current].parent();
        while parent != current {
            let grandparent = self.nodes[parent].parent();
            self.nodes[current].set_parent(grandparent);
            current = parent;
            parent = self.nodes[current].parent();
        }
        current
    }

    #[inline]
    fn node_eq(&mut self, left: usize, right: usize) -> bool {
        left == right || self.find(left) == self.find(right)
    }
}

/// A node of a disjoint set forest packed into a single `u64`.
///
/// The 8 most significant bits hold the node's rank; the remaining 56 low bits hold the index
/// of its parent (a root is its own parent).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
struct Node(u64);

impl Node {
    /// Bit offset of the rank field: bits at or above it are rank, bits below are the parent.
    const RANK_SHIFT: u32 = u64::BITS - 8;
    /// Selects the parent-index bits.
    const IX_MASK: u64 = !0u64 >> (u64::BITS - Self::RANK_SHIFT);
    /// Selects the rank bits.
    const RANK_MASK: u64 = !Self::IX_MASK;

    /// Creates a rank-0 node pointing at `parent`.
    ///
    /// # Panics
    /// If `parent` does not fit in the parent-index bits.
    fn new(parent: usize) -> Node {
        let raw = parent as u64;
        assert_eq!(
            raw & Self::RANK_MASK,
            0,
            "Overflowed maximum node capacity!"
        );
        Node(raw)
    }

    /// Re-points this node at `ix`, keeping its rank.
    fn set_parent(&mut self, ix: usize) {
        let raw = ix as u64;
        debug_assert_eq!(raw & Self::RANK_MASK, 0, "Set invalid parent {}", ix);
        self.0 = (self.0 & Self::RANK_MASK) | raw
    }

    /// The index of this node's parent.
    fn parent(&self) -> usize {
        (self.0 & Self::IX_MASK) as usize
    }

    /// This node's rank, read from the top 8 bits.
    fn rank(&self) -> u8 {
        (self.0 >> Self::RANK_SHIFT) as u8
    }

    /// Increases the rank by `bump`, saturating at `u8::MAX`; the parent is left untouched.
    fn bump_rank(&mut self, bump: u8) {
        let rank = u64::from(self.rank().saturating_add(bump));
        self.0 = (rank << Self::RANK_SHIFT) | (self.0 & Self::IX_MASK);
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use rand::prelude::SliceRandom;
    use rand::Rng;

    /// Runs one randomized trial over nodes `0..nodes`, with `rounds` rounds of two unions and
    /// two equations each. Checks that congruence closure only coarsens the plain union-find
    /// relation, and that registering the equations interleaved with the unions or only after
    /// all of them (in shuffled order) yields the same final partition.
    fn order_and_join_trial<R: Rng>(rng: &mut R, nodes: usize, rounds: usize) {
        // Receives only the explicit unions.
        let mut dsf = DisjointSetForest::with_capacity(nodes);
        // Driven by `cc`, which sees the equations interleaved with the unions.
        let mut cc_dsf = DisjointSetForest::with_capacity(nodes);
        // Driven by `cc_end`, which sees every union first and the equations only at the end.
        let mut cc_end_dsf = DisjointSetForest::with_capacity(nodes);
        let mut cc = CongruenceClosure::new();
        let mut cc_end = CongruenceClosure::new();
        let mut st = CongruenceState::default();
        let mut equations = Vec::with_capacity(rounds);

        for _ in 0..rounds {
            let a = rng.gen::<usize>() % nodes;
            let b = rng.gen::<usize>() % nodes;
            dsf.union(a, b);
            cc_end.union(a, b, &mut cc_end_dsf, &mut st);
            cc.union(a, b, &mut cc_dsf, &mut st);
            let c = rng.gen::<usize>() % nodes;
            let d = rng.gen::<usize>() % nodes;
            dsf.union(c, d);
            cc_end.union(c, d, &mut cc_end_dsf, &mut st);
            cc.union(c, d, &mut cc_dsf, &mut st);
            let e = rng.gen::<usize>() % nodes;
            let f = rng.gen::<usize>() % nodes;
            cc.equation(a, b, c, &mut cc_dsf, &mut st);
            cc.equation(d, e, f, &mut cc_dsf, &mut st);
            equations.push((a, b, c, d, e, f));
        }
        assert!(dsf.refines(&mut cc_dsf));
        assert!(cc_end_dsf.refines(&mut cc_dsf));
        assert_eq!(dsf, cc_end_dsf);

        equations.shuffle(rng);
        for (a, b, c, d, e, f) in equations {
            // `cc_end` must keep driving its own union-find: mixing in `cc_dsf` would both
            // corrupt `cc_end`'s representative-keyed tables and leave `cc_end_dsf` unchanged.
            cc_end.equation(a, b, c, &mut cc_end_dsf, &mut st);
            cc_end.equation(d, e, f, &mut cc_end_dsf, &mut st);
            assert!(dsf.refines(&mut cc_end_dsf));
            assert!(cc_end_dsf.refines(&mut cc_dsf));
        }
        assert_eq!(cc_dsf, cc_end_dsf);
    }

    #[test]
    fn congruence_closure_order_and_join() {
        let mut rng = rand::thread_rng();
        // Dense trial (the original configuration): ~2048 unions over 32 nodes very likely
        // collapse everything into a single class.
        order_and_join_trial(&mut rng, 32, 1024);
        // Sparse trials: few enough unions that the equations contribute merges of their own.
        for _ in 0..64 {
            order_and_join_trial(&mut rng, 64, 16);
        }
    }
}
